summaryrefslogtreecommitdiff
path: root/research/flossing/engelken_python_flossing.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/engelken_python_flossing.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/engelken_python_flossing.py')
-rw-r--r--research/flossing/engelken_python_flossing.py336
1 files changed, 336 insertions, 0 deletions
diff --git a/research/flossing/engelken_python_flossing.py b/research/flossing/engelken_python_flossing.py
new file mode 100644
index 0000000..b83be2f
--- /dev/null
+++ b/research/flossing/engelken_python_flossing.py
@@ -0,0 +1,336 @@
+"""PyTorch port of Rainer Engelken's GradientFlossing_ExampleCode.jl.
+
+This script is intentionally a vanilla-RNN delayed-XOR reproduction scaffold,
+not an HRM/TRM experiment. It keeps the algorithmic choices that matter for
+checking gradient flossing itself:
+
+ - flossing is a separate phase, not mixed with task loss;
+ - flossing optimizes only input/recurrent/bias parameters, not readout;
+ - flossing uses mean((lambda_i - target)^2);
+ - flossing batch size is one, as in the Julia code;
+ - QR is differentiable and participates in the backward pass;
+ - Lyapunov accumulation discards the initial transient.
+
+Reference:
+ research/flossing/external/GradientFlossing/GradientFlossing_ExampleCode.jl
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import time
+from dataclasses import asdict, dataclass
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
def recabs(bits: np.ndarray) -> np.ndarray:
    """Fold the rows of *bits* from first to last with ``|acc - row|``.

    For rows holding only 0/1 values the absolute difference equals XOR, so
    the result is the parity of all rows. When *bits* has a single row that
    row is returned as-is.
    """
    folded = bits[0]
    for row_idx in range(1, len(bits)):
        folded = np.abs(folded - bits[row_idx])
    return folded
+
+
def generate_input_output(
    batch_size: int,
    steps: int,
    seed: int,
    input_dim: int,
    input_scale: float,
    delay: int,
    task: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Draw one random binary input batch and build its delayed target.

    Port of generateInputOutput in GradientFlossing_XOR.jl.

    ``steps + delay`` time steps of 0/1 inputs are drawn from a NumPy
    generator seeded with *seed*. The first *delay* steps are dropped from
    both arrays before they are returned.

    * ``task == 0``: the target at raw step ``i`` is input channel 0 at raw
      step ``i - delay``.
    * ``task != 0``: with ``h = abs(task)``, the target at raw step ``i`` is
      ``recabs`` over input channel 0 at raw steps ``i - delay * j`` for
      ``j = 2**h, ..., 1`` (oldest sample first). Raw steps before
      ``delay * 2**h`` keep a zero target.

    Returns tensors shaped as Julia uses conceptually:
      s:      [T, S, B]  (float32, multiplied by *input_scale*)
      target: [T, 1, B]  (float32, unscaled)
    """
    total_steps = steps + delay
    generator = np.random.default_rng(seed)
    draws = generator.integers(
        0, 2, size=(total_steps, input_dim, batch_size), dtype=np.int64
    )
    s = draws.astype(np.float32)
    target = np.zeros((total_steps, 1, batch_size), dtype=np.float32)

    if task == 0:
        for t in range(delay, total_steps):
            target[t, :, :] = s[t - delay, :1, :]
    else:
        n_taps = 2 ** abs(task)
        # Lags run from largest to smallest so the oldest sample is folded first.
        lags = [delay * j for j in range(n_taps, 0, -1)]
        for t in range(delay * n_taps, total_steps):
            taps = s[[t - lag for lag in lags], 0, :]
            target[t, 0, :] = recabs(taps)

    scaled_inputs = s[delay:] * np.float32(input_scale)
    sx = torch.from_numpy(scaled_inputs).to(device)
    ty = torch.from_numpy(target[delay:]).to(device)
    return sx, ty
+
+
@dataclass
class Config:
    # Every field becomes a ``--kebab-case`` CLI flag in parse_args and is
    # written to the output JSON under "config"; field order is preserved.

    # --- network size and schedule ---
    hidden_size: int = 80  # number of recurrent units (j is hidden_size x hidden_size)
    train_epochs: int = 3001  # task-training iterations; one batch per epoch
    inter_period: int = 100  # an inter-flossing episode may start when epoch % inter_period == 0
    inter_epochs: int = 500  # flossing steps per inter-flossing episode
    pre_epochs: int = 500  # flossing steps run at epoch 1 before the first task step
    max_inter_episodes: int = 2  # cap on the number of inter-flossing episodes
    batch_size: int = 16  # task batch size; also the number of columns of x_init
    input_dim: int = 1  # number of input channels
    train_steps: int = 300  # time steps per task sequence
    lyap_steps: int = 55  # time steps simulated per Lyapunov-spectrum estimate
    floss_input_steps: int = 300  # length of the input sequence generated for flossing

    # --- seeds ---
    seed_ic: int = 1  # seeds the random state that is relaxed into x_init
    seed_input: int = 1  # scales the base seed of training/flossing inputs
    seed_net: int = 1  # passed to torch.manual_seed before weight initialisation
    seed_ons: int = 1  # seeds the initial orthonormal basis for the spectrum

    # --- optimiser (shared by the task and flossing AdamW instances) ---
    lr: float = 1e-3
    beta1: float = 0.9
    beta2: float = 0.999

    # --- initialisation ---
    init_type: int = 1  # 1: Gaussian j / sqrt(n); 3: recurrent_gain * orthogonal Q; else ValueError
    recurrent_gain: float = 1.0  # scale of the recurrent matrix
    recurrent_mean_gain: float = 0.0  # init_type 1 only: adds recurrent_mean_gain / n to every entry
    input_scale: float = 1.0  # multiplies the 0/1 input stream
    delay: int = 10  # task delay in time steps
    ws_std: float = 1.0  # input weights: ws_std * N(0, 1) + ws_mean
    ws_mean: float = 0.0
    wr_std: float = 1.0  # readout weights: wr_std * N(0, 1) + wr_mean
    wr_mean: float = 0.0
    b_std: float = 0.1  # bias: b_mean + b_std * U[0, 1)  (torch.rand, not randn)
    b_mean: float = 0.0

    # --- flossing / task ---
    n_lyap: int = 75  # number of Lyapunov exponents; run() requires n_lyap <= hidden_size
    task: int = -1  # 0: delayed copy; otherwise recabs over 2**abs(task) lagged inputs
    lyap_target: float = 0.0  # value every exponent is pulled towards by the flossing loss

    # --- logging / output ---
    eval_every: int = 100  # evaluate when epoch % eval_every == 0 (plus first and last epoch)
    eval_batches: int = 4  # fresh batches averaged per evaluation
    log_every_floss: int = 50  # log when floss_step % log_every_floss == 0 (plus first and last step)
    device: str = "cpu"  # passed to torch.device
    out: str = "research/flossing/engelken_python/run.json"  # JSON log path
+
+
class VanillaRNN:
    """Vanilla tanh RNN that keeps its weights as bare ``torch.nn.Parameter``s.

    The hidden state is stored column-wise (``[hidden_size, batch]``) and is
    advanced by

        x <- j @ tanh(x) + b + (ws @ s_t) / input_dim

    Logits are read out from the updated state as ``wr @ x + offset``.

    Attributes:
        ws: input weights, ``[hidden_size, input_dim]``.
        wr: readout weights, ``[1, hidden_size]``.
        b: bias, ``[hidden_size]``.
        offset: readout bias, ``[1]``.
        j: recurrent weights, ``[hidden_size, hidden_size]``.
        x_init: trainable initial state, ``[hidden_size, batch_size]``.
    """

    def __init__(self, cfg: Config, device: torch.device):
        """Initialise all parameters from *cfg* on *device*.

        Seeds torch's global RNG with ``cfg.seed_net``; the order of the
        random draws below therefore determines the initial weights.

        Raises:
            ValueError: if ``cfg.init_type`` is neither 1 nor 3.
        """
        self.cfg = cfg
        self.device = device
        torch.manual_seed(cfg.seed_net)
        n = cfg.hidden_size

        self.ws = torch.nn.Parameter(
            cfg.ws_std * torch.randn(n, cfg.input_dim, device=device) + cfg.ws_mean
        )
        self.wr = torch.nn.Parameter(
            cfg.wr_std * torch.randn(1, n, device=device) + cfg.wr_mean
        )
        # Bias uses a uniform draw (torch.rand), unlike the Gaussian weights.
        self.b = torch.nn.Parameter(
            cfg.b_mean + cfg.b_std * torch.rand(n, device=device)
        )
        self.offset = torch.nn.Parameter(torch.tensor([0.001], device=device))

        if cfg.init_type == 1:
            # Gaussian coupling scaled by 1/sqrt(n), optionally with a mean shift.
            j = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            if cfg.recurrent_mean_gain != 0:
                j = j + cfg.recurrent_mean_gain / n
        elif cfg.init_type == 3:
            # Orthogonal coupling: Q factor of a Gaussian matrix, rescaled.
            j0 = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            q, _ = torch.linalg.qr(j0)
            j = cfg.recurrent_gain * q
        else:
            raise ValueError(f"unsupported init_type={cfg.init_type}")
        self.j = torch.nn.Parameter(j)

        self.x_init = torch.nn.Parameter(self._relax_initial_condition(cfg.seed_ic, 100))

    def _relax_initial_condition(self, seed: int, steps: int) -> torch.Tensor:
        """Return a random state iterated *steps* times under the input-free map.

        The state is drawn from a generator seeded with *seed* and has shape
        ``[hidden_size, batch_size]``. No gradient graph is recorded.
        """
        gen = torch.Generator(device=self.device).manual_seed(seed)
        x = torch.randn(
            self.cfg.hidden_size, self.cfg.batch_size, generator=gen, device=self.device
        )
        with torch.no_grad():
            bias = self.b.unsqueeze(1)
            for _ in range(steps):
                x = self.j @ torch.tanh(x) + bias
        return x

    def task_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by the task optimiser (everything)."""
        return [self.ws, self.j, self.wr, self.b, self.offset, self.x_init]

    def floss_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by flossing: input, recurrent and bias only."""
        return [self.ws, self.j, self.b]

    def task_loss_and_accuracy(self, s: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, float]:
        """Run the RNN over *s* and score its readout against *target*.

        Args:
            s: inputs indexed as ``s[t]`` -> ``[input_dim, batch]``; must hold
                at least ``cfg.train_steps`` time steps.
            target: 0/1 targets indexed as ``target[t]``, broadcast-compatible
                with the ``[1, batch]`` logits.

        Returns:
            ``(loss, accuracy)``: summed binary cross-entropy divided by
            ``cfg.batch_size * scored_steps`` (differentiable), and the
            fraction of correct thresholded predictions as a Python float.

        Raises:
            ValueError: if *s* has fewer than ``cfg.train_steps`` time steps.
        """
        cfg = self.cfg
        if s.shape[0] < cfg.train_steps:
            raise ValueError(
                f"input has {s.shape[0]} time steps but train_steps={cfg.train_steps}"
            )

        first_loss_idx = 2 * cfg.delay if cfg.task > 0 else cfg.delay * (2 ** abs(cfg.task))
        # Julia is 1-indexed and starts loss at ti >= firstlossidx.
        first_loss_idx = max(first_loss_idx - 1, 0)

        # Loop-invariant views, hoisted out of the time loop.
        bias = self.b.unsqueeze(1)
        readout_bias = self.offset.unsqueeze(1)

        x = self.x_init
        loss = torch.zeros((), device=self.device)
        # Kept on-device so the loop never forces a host sync; read once below.
        correct = torch.zeros((), dtype=torch.long, device=self.device)
        count = 0
        for ti in range(cfg.train_steps):
            x = self.j @ torch.tanh(x) + bias + (self.ws @ s[ti]) / cfg.input_dim
            if ti < first_loss_idx:
                continue
            step_target = target[ti]
            logits = self.wr @ x + readout_bias
            loss = loss + F.binary_cross_entropy_with_logits(
                logits, step_target, reduction="sum"
            )
            pred = (torch.sigmoid(logits) >= 0.5).to(target.dtype)
            correct = correct + (pred == step_target).sum()
            count += step_target.numel()

        denom = cfg.batch_size * max(cfg.train_steps - first_loss_idx, 1)
        return loss / denom, correct.item() / max(count, 1)

    def lyapunov_spectrum(self, s: torch.Tensor) -> torch.Tensor:
        """Estimate the first ``cfg.n_lyap`` Lyapunov exponents along one trajectory.

        The trajectory starts from column 0 of ``x_init`` and is driven by
        batch element 0 of *s*. At every step an orthonormal basis is pushed
        through ``diag(sech^2(x)) @ j`` and re-orthonormalised by QR; after a
        transient of ``10 + ceil(lyap_steps / 10)`` steps the logs of
        ``|diag(R)|`` are averaged. The result stays attached to the autograd
        graph.

        Args:
            s: inputs indexed as ``s[t, :, 0]``; must hold at least
                ``cfg.lyap_steps`` time steps.

        Returns:
            Tensor of shape ``[n_lyap]`` (empty when ``cfg.n_lyap <= 0``).

        Raises:
            ValueError: if *s* is shorter than ``cfg.lyap_steps`` or no step
                remains after the transient.
        """
        cfg = self.cfg
        if cfg.n_lyap <= 0:
            return torch.empty(0, device=self.device)
        if s.shape[0] < cfg.lyap_steps:
            raise ValueError(
                f"input has {s.shape[0]} time steps but lyap_steps={cfg.lyap_steps}"
            )

        x = self.x_init[:, 0].clone()
        gen = torch.Generator(device=self.device).manual_seed(cfg.seed_ons)
        q0 = torch.randn(cfg.hidden_size, cfg.n_lyap, generator=gen, device=self.device)
        q, _ = torch.linalg.qr(q0)
        accum = torch.zeros(cfg.n_lyap, device=self.device)
        tsim = 0
        # State transient plus basis-alignment transient, both discarded.
        discard = 10 + math.ceil(cfg.lyap_steps / 10)

        for n in range(1, cfg.lyap_steps + 1):
            x = self.j @ torch.tanh(x) + self.b + (self.ws @ s[n - 1, :, 0]) / cfg.input_dim
            # sech^2 via 1/cosh^2 keeps relative precision for large |x|,
            # which matters because the result ends up inside a log.
            phi_prime = 1.0 / torch.cosh(x).square()
            q, r = torch.linalg.qr((phi_prime.unsqueeze(1) * self.j) @ q)
            if n > discard:
                # clamp guards log(0) when a direction collapses exactly.
                accum = accum + torch.diagonal(r).abs().clamp_min(1e-30).log()
                tsim += 1

        if tsim == 0:
            raise ValueError("lyap_steps too short after transient discard")
        return accum / tsim
+
+
def evaluate(model: VanillaRNN, cfg: Config, epoch: int) -> dict[str, float]:
    """Average task loss and accuracy over ``cfg.eval_batches`` fresh batches.

    Batch ``i`` is generated with seed ``10_000_000 + epoch * 1000 + i`` and
    scored without recording gradients.

    Returns:
        ``{"loss": mean_loss, "accuracy": mean_accuracy}`` as Python floats.
    """
    seed_base = 10_000_000 + epoch * 1000
    batch_losses = []
    batch_accs = []
    with torch.no_grad():
        for batch_idx in range(cfg.eval_batches):
            inputs, targets = generate_input_output(
                cfg.batch_size,
                cfg.train_steps,
                seed_base + batch_idx,
                cfg.input_dim,
                cfg.input_scale,
                cfg.delay,
                cfg.task,
                model.device,
            )
            batch_loss, batch_acc = model.task_loss_and_accuracy(inputs, targets)
            batch_losses.append(float(batch_loss.item()))
            batch_accs.append(batch_acc)
    mean_loss = float(np.mean(batch_losses))
    mean_acc = float(np.mean(batch_accs))
    return {"loss": mean_loss, "accuracy": mean_acc}
+
+
def _is_due(step: int, period: int) -> bool:
    """Return True when *period* is positive and *step* is a multiple of it."""
    return period > 0 and step % period == 0


def _floss_phase(
    model: VanillaRNN,
    cfg: Config,
    optim: torch.optim.Optimizer,
    lyap_target: torch.Tensor,
    device: torch.device,
    epoch: int,
    kind: str,
    n_steps: int,
    log: dict,
    t0: float,
) -> None:
    """Run *n_steps* flossing updates and append progress records to ``log["floss"]``.

    Each step draws a fresh batch-size-one input sequence, estimates the
    Lyapunov spectrum and minimises ``mean((lambda_i - target)^2)`` with *optim*.

    Raises:
        ValueError: if the generated input would be shorter than
            ``cfg.lyap_steps``.
    """
    if cfg.floss_input_steps < cfg.lyap_steps:
        raise ValueError("floss_input_steps must be >= lyap_steps")

    for floss_step in range(1, n_steps + 1):
        seed = cfg.train_epochs * cfg.seed_input + epoch + floss_step
        s, _ = generate_input_output(
            1, cfg.floss_input_steps, seed, cfg.input_dim,
            cfg.input_scale, cfg.delay, cfg.task, device,
        )
        optim.zero_grad(set_to_none=True)
        spectrum = model.lyapunov_spectrum(s)
        floss_loss = (spectrum - lyap_target).square().mean()
        floss_loss.backward()
        optim.step()

        if floss_step in (1, n_steps) or _is_due(floss_step, cfg.log_every_floss):
            rec = {
                "epoch": epoch,
                "kind": kind,
                "floss_step": floss_step,
                "loss": float(floss_loss.item()),
                "lambda_mean": float(spectrum.detach().mean().item()),
                "lambda_1": float(spectrum.detach()[0].item()),
                "elapsed": time.time() - t0,
            }
            log["floss"].append(rec)
            print(
                f"F {kind} epoch={epoch} step={floss_step}/{n_steps} "
                f"loss={rec['loss']:.6f} lambda1={rec['lambda_1']:+.4f}",
                flush=True,
            )


def run(cfg: Config) -> dict:
    """Train the RNN on the delayed task with separate flossing phases.

    Flossing runs for ``cfg.pre_epochs`` steps before epoch 1 and for
    ``cfg.inter_epochs`` steps at up to ``cfg.max_inter_episodes`` epochs that
    are multiples of ``cfg.inter_period``. Every epoch then takes one
    gradient-clipped task step. Flossing is skipped entirely when
    ``cfg.n_lyap <= 0``; a period of 0 (``inter_period``, ``eval_every``,
    ``log_every_floss``) disables the corresponding periodic action.

    Returns:
        The log dict (``config``, ``evals``, ``floss``, ``task``), which is
        also written as JSON to ``cfg.out``.

    Raises:
        ValueError: if ``n_lyap > hidden_size``, if ``lyap_steps`` does not
            outlast the transient discard, or if a flossing phase starts with
            ``floss_input_steps < lyap_steps``.
    """
    device = torch.device(cfg.device)
    if cfg.n_lyap > cfg.hidden_size:
        raise ValueError("n_lyap must be <= hidden_size")
    if cfg.lyap_steps <= 10 + math.ceil(cfg.lyap_steps / 10):
        raise ValueError("lyap_steps too short for official transient discard")

    model = VanillaRNN(cfg, device)
    # Two optimisers with independent moment estimates: flossing only ever
    # touches ws / j / b, the task step touches every parameter.
    task_optim = torch.optim.AdamW(
        model.task_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    floss_optim = torch.optim.AdamW(
        model.floss_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    target = torch.full((cfg.n_lyap,), cfg.lyap_target, device=device)
    log: dict = {"config": asdict(cfg), "evals": [], "floss": [], "task": []}

    # With no exponents requested the spectrum is empty: its mean is NaN and
    # it carries no graph, so a flossing step could only crash.
    floss_enabled = cfg.n_lyap > 0

    t0 = time.time()
    inter_count = 0
    for epoch in range(1, cfg.train_epochs + 1):
        should_prefloss = floss_enabled and epoch == 1 and cfg.pre_epochs > 0
        should_interfloss = (
            floss_enabled
            and cfg.inter_epochs > 0
            and inter_count < cfg.max_inter_episodes
            and _is_due(epoch, cfg.inter_period)
        )
        if should_prefloss or should_interfloss:
            n_steps = cfg.pre_epochs if should_prefloss else cfg.inter_epochs
            kind = "pre" if should_prefloss else "inter"
            if should_interfloss:
                inter_count += 1
            _floss_phase(
                model, cfg, floss_optim, target, device, epoch, kind, n_steps, log, t0
            )

        seed = cfg.train_epochs * cfg.seed_input + epoch
        s, target_batch = generate_input_output(
            cfg.batch_size, cfg.train_steps, seed, cfg.input_dim,
            cfg.input_scale, cfg.delay, cfg.task, device,
        )
        # Also clears gradients that the flossing backward left on non-flossed
        # parameters (wr, offset, x_init).
        task_optim.zero_grad(set_to_none=True)
        task_loss, task_acc = model.task_loss_and_accuracy(s, target_batch)
        task_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.task_parameters(), 0.03)
        task_optim.step()

        if epoch in (1, cfg.train_epochs) or _is_due(epoch, cfg.eval_every):
            eval_metrics = evaluate(model, cfg, epoch)
            rec = {
                "epoch": epoch,
                "train_loss": float(task_loss.item()),
                "train_accuracy": float(task_acc),
                "eval_loss": eval_metrics["loss"],
                "eval_accuracy": eval_metrics["accuracy"],
                "elapsed": time.time() - t0,
            }
            log["evals"].append(rec)
            print(
                f"T epoch={epoch}/{cfg.train_epochs} loss={rec['train_loss']:.4f} "
                f"train_acc={rec['train_accuracy']:.3f} eval_acc={rec['eval_accuracy']:.3f}",
                flush=True,
            )

    out_path = Path(cfg.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(log, indent=2))
    return log
+
+
def parse_args() -> Config:
    """Build a Config from command-line flags.

    One ``--kebab-case`` flag is registered per Config field, using the
    field's default value and the type of that default. Boolean defaults get
    a ``--flag`` / ``--no-flag`` pair.
    """
    parser = argparse.ArgumentParser()
    for name, value in asdict(Config()).items():
        flag = "--" + name.replace("_", "-")
        if isinstance(value, bool):
            options = {"action": argparse.BooleanOptionalAction}
        else:
            options = {"type": type(value)}
        parser.add_argument(flag, default=value, **options)
    parsed = parser.parse_args()
    return Config(**vars(parsed))
+
+
# Script entry point: parse CLI flags into a Config and run the experiment.
if __name__ == "__main__":
    run(parse_args())