# Provenance (patch envelope this file was extracted from):
#   commit 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a
#   From: YurenHao0426, Sat, 13 Jun 2026 12:35:36 -0500
#   Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset,
#            dynamical-analysis pipeline
#   Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
#   trm/hrm pretrain.py carry trajectory-augmentation code
#   (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored;
#   see PROVENANCE.md.
#   Co-Authored-By: Claude Fable 5
#   File: research/flossing/engelken_python_flossing.py (new file, 336 lines)
"""PyTorch port of Rainer Engelken's GradientFlossing_ExampleCode.jl.

This script is intentionally a vanilla-RNN delayed-XOR reproduction scaffold,
not an HRM/TRM experiment. It keeps the algorithmic choices that matter for
checking gradient flossing itself:

  - flossing is a separate phase, not mixed with task loss;
  - flossing optimizes only input/recurrent/bias parameters, not readout;
  - flossing uses mean((lambda_i - target)^2);
  - flossing batch size is one, as in the Julia code;
  - QR is differentiable and participates in the backward pass;
  - Lyapunov accumulation discards the initial transient.

Reference:
  research/flossing/external/GradientFlossing/GradientFlossing_ExampleCode.jl
"""

from __future__ import annotations

import argparse
import json
import math
import time
from dataclasses import asdict, dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

# Fixed part of the transient that is excluded from Lyapunov averaging; the
# variable part is ceil(lyap_steps / 10), see _lyap_discard_steps.
_LYAP_STATE_TRANSIENT = 10


def _lyap_discard_steps(lyap_steps: int) -> int:
    """Return how many leading steps are excluded from Lyapunov averaging."""
    return _LYAP_STATE_TRANSIENT + math.ceil(lyap_steps / 10)


def _is_due(step: int, period: int) -> bool:
    """Return True when ``step`` is a multiple of ``period``.

    A period of 0 means "never" instead of raising ZeroDivisionError, so a
    schedule can be switched off from the command line.
    """
    return period != 0 and step % period == 0


def recabs(bits: np.ndarray) -> np.ndarray:
    """Fold ``bits`` along axis 0 with ``acc = |acc - bit|``.

    For 0/1-valued entries the absolute difference is XOR, so the result is
    the running XOR of the rows of ``bits``.
    """
    out = bits[0]
    for bit in bits[1:]:
        out = np.abs(out - bit)
    return out


def generate_input_output(
    batch_size: int,
    steps: int,
    seed: int,
    input_dim: int,
    input_scale: float,
    delay: int,
    task: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Port of generateInputOutput in GradientFlossing_XOR.jl.

    Draws a random 0/1 input stream of ``steps + delay`` steps and builds the
    target from delayed copies of input channel 0:

      - ``task == 0``: the target is the input delayed by ``delay`` steps;
      - otherwise: the target is the ``recabs`` fold of the ``2**abs(task)``
        inputs at lags ``delay, 2*delay, ...`` (zero until all lags exist).

    The first ``delay`` steps are then dropped from both tensors and the
    input is multiplied by ``input_scale``.

    Returns tensors shaped as Julia uses conceptually:
      s: [T, S, B]
      target: [T, 1, B]
    """
    rng = np.random.default_rng(seed)
    raw_steps = steps + delay
    s = rng.integers(
        0, 2, size=(raw_steps, input_dim, batch_size), dtype=np.int64
    ).astype(np.float32)
    target = np.zeros((raw_steps, 1, batch_size), dtype=np.float32)

    if task == 0:
        for i in range(delay, raw_steps):
            target[i, :, :] = s[i - delay, :1, :]
    else:
        h = abs(task)
        n_lags = 2**h
        span = delay * n_lags
        for i in range(span, raw_steps):
            indices = [i - delay * j for j in range(1, n_lags + 1)]
            # Flip so the fold starts from the oldest input.
            values = np.flip(s[indices, 0, :], axis=0)
            target[i, 0, :] = recabs(values)

    sx = torch.from_numpy(s[delay:] * np.float32(input_scale)).to(device)
    ty = torch.from_numpy(target[delay:]).to(device)
    return sx, ty


@dataclass
class Config:
    """Hyper-parameters of the run; every field is also a CLI flag."""

    # Network size and training schedule.
    hidden_size: int = 80
    train_epochs: int = 3001  # number of task-training steps
    inter_period: int = 100  # start an inter-flossing episode every N epochs
    inter_epochs: int = 500  # flossing steps per inter-flossing episode
    pre_epochs: int = 500  # flossing steps before the first task step
    max_inter_episodes: int = 2
    batch_size: int = 16
    input_dim: int = 1
    train_steps: int = 300  # time steps per task sequence
    lyap_steps: int = 55  # time steps per Lyapunov estimate
    floss_input_steps: int = 300  # length of the input drawn for flossing
    # Seeds: initial condition, input stream, network weights, initial
    # orthonormal system of the Lyapunov computation.
    seed_ic: int = 1
    seed_input: int = 1
    seed_net: int = 1
    seed_ons: int = 1
    # AdamW settings shared by the task and flossing optimizers.
    lr: float = 1e-3
    beta1: float = 0.9
    beta2: float = 0.999
    # Recurrent init: 1 = scaled Gaussian, 3 = scaled orthogonal (via QR).
    init_type: int = 1
    recurrent_gain: float = 1.0
    recurrent_mean_gain: float = 0.0
    input_scale: float = 1.0
    delay: int = 10
    ws_std: float = 1.0
    ws_mean: float = 0.0
    wr_std: float = 1.0
    wr_mean: float = 0.0
    b_std: float = 0.1
    b_mean: float = 0.0
    n_lyap: int = 75  # number of Lyapunov exponents that are flossed
    task: int = -1
    lyap_target: float = 0.0
    eval_every: int = 100
    eval_batches: int = 4
    log_every_floss: int = 50
    device: str = "cpu"
    out: str = "research/flossing/engelken_python/run.json"


class VanillaRNN:
    """Vanilla tanh RNN with a linear readout, held as raw parameters.

    State update used throughout: ``x <- J tanh(x) + b + Ws s / input_dim``.
    """

    def __init__(self, cfg: Config, device: torch.device):
        """Draw all parameters; the global torch RNG is seeded with seed_net."""
        self.cfg = cfg
        self.device = device
        torch.manual_seed(cfg.seed_net)
        n = cfg.hidden_size

        # NOTE: the order of the random draws below fixes the initial weights
        # for a given seed_net; do not reorder.
        self.ws = torch.nn.Parameter(
            cfg.ws_std * torch.randn(n, cfg.input_dim, device=device) + cfg.ws_mean
        )
        self.wr = torch.nn.Parameter(
            cfg.wr_std * torch.randn(1, n, device=device) + cfg.wr_mean
        )
        self.b = torch.nn.Parameter(
            cfg.b_mean + cfg.b_std * torch.rand(n, device=device)
        )
        self.offset = torch.nn.Parameter(torch.tensor([0.001], device=device))

        if cfg.init_type == 1:
            j = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            if cfg.recurrent_mean_gain != 0:
                j = j + cfg.recurrent_mean_gain / n
        elif cfg.init_type == 3:
            j0 = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n)
            q, _ = torch.linalg.qr(j0)
            j = cfg.recurrent_gain * q
        else:
            raise ValueError(f"unsupported init_type={cfg.init_type}")
        self.j = torch.nn.Parameter(j)

        self.x_init = torch.nn.Parameter(self._relax_initial_condition(cfg.seed_ic, 100))

    def _relax_initial_condition(self, seed: int, steps: int) -> torch.Tensor:
        """Return a [hidden, batch] state after ``steps`` input-free updates."""
        gen = torch.Generator(device=self.device).manual_seed(seed)
        x = torch.randn(
            self.cfg.hidden_size, self.cfg.batch_size, generator=gen, device=self.device
        )
        with torch.no_grad():
            for _ in range(steps):
                x = self.j.detach() @ torch.tanh(x) + self.b.detach().unsqueeze(1)
        return x

    def task_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by task training (all of them)."""
        return [self.ws, self.j, self.wr, self.b, self.offset, self.x_init]

    def floss_parameters(self) -> list[torch.nn.Parameter]:
        """Parameters updated by flossing: input, recurrent and bias only."""
        return [self.ws, self.j, self.b]

    def task_loss_and_accuracy(
        self, s: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, float]:
        """Run the RNN on ``s`` and score the readout against ``target``.

        Args:
            s: input of shape [T, input_dim, B] with T >= train_steps.
            target: 0/1 target of shape [T, 1, B].

        Returns:
            ``(loss, accuracy)``: the summed binary cross-entropy divided by
            ``batch_size * scored_steps`` (differentiable), and the fraction
            of correct thresholded predictions as a Python float.
        """
        cfg = self.cfg
        first_loss_idx = 2 * cfg.delay if cfg.task > 0 else cfg.delay * (2 ** abs(cfg.task))
        # Julia is 1-indexed and starts loss at ti >= firstlossidx.
        first_loss_idx = max(first_loss_idx - 1, 0)
        x = self.x_init
        loss = torch.zeros((), device=self.device)
        # Kept as a tensor so the loop does not force a host sync per step.
        correct = torch.zeros((), dtype=torch.long, device=self.device)
        count = 0
        for ti in range(cfg.train_steps):
            x = self.j @ torch.tanh(x) + self.b.unsqueeze(1) + (self.ws @ s[ti]) / cfg.input_dim
            if ti >= first_loss_idx:
                logits = self.wr @ x + self.offset.unsqueeze(1)
                loss = loss + F.binary_cross_entropy_with_logits(
                    logits, target[ti], reduction="sum"
                )
                pred = (torch.sigmoid(logits) >= 0.5).to(target.dtype)
                correct = correct + (pred == target[ti]).sum()
                count += target[ti].numel()
        denom = cfg.batch_size * max(cfg.train_steps - first_loss_idx, 1)
        return loss / denom, int(correct.item()) / max(count, 1)

    def lyapunov_spectrum(self, s: torch.Tensor) -> torch.Tensor:
        """Estimate the first ``n_lyap`` Lyapunov exponents along one trajectory.

        Only batch column 0 of ``x_init`` and of ``s`` is used. An orthonormal
        frame is pushed through the per-step Jacobian and re-orthonormalised
        by QR; the logs of ``|diag(R)|`` are averaged over the steps after
        the discarded transient. The result is differentiable.

        Args:
            s: input of shape [T, input_dim, B] with T >= lyap_steps.

        Returns:
            Tensor of shape [n_lyap] (empty when ``n_lyap <= 0``).

        Raises:
            ValueError: if ``s`` is shorter than ``lyap_steps`` or no step
                remains after the transient discard.
        """
        cfg = self.cfg
        if cfg.n_lyap <= 0:
            return torch.empty(0, device=self.device)
        if s.shape[0] < cfg.lyap_steps:
            raise ValueError(
                f"input has {s.shape[0]} steps but lyap_steps={cfg.lyap_steps}"
            )
        discard = _lyap_discard_steps(cfg.lyap_steps)
        kept = cfg.lyap_steps - discard
        if kept <= 0:
            raise ValueError("lyap_steps too short after transient discard")

        x = self.x_init[:, 0].clone()
        gen = torch.Generator(device=self.device).manual_seed(cfg.seed_ons)
        q0 = torch.randn(cfg.hidden_size, cfg.n_lyap, generator=gen, device=self.device)
        q, _ = torch.linalg.qr(q0)
        accum = torch.zeros(cfg.n_lyap, device=self.device)
        diag_idx = torch.arange(cfg.n_lyap, device=self.device)

        for n in range(1, cfg.lyap_steps + 1):
            x = self.j @ torch.tanh(x) + self.b + (self.ws @ s[n - 1, :, 0]) / cfg.input_dim
            # tanh'(x) = 1 / cosh(x)^2, applied row-wise to J.
            phi_prime = 1.0 / torch.cosh(x).square()
            tangent = phi_prime.unsqueeze(1) * self.j
            q = tangent @ q
            q, r = torch.linalg.qr(q)
            if n > discard:
                # Clamp keeps log() finite if a diagonal entry underflows.
                accum = accum + r[diag_idx, diag_idx].abs().clamp_min(1e-30).log()

        return accum / kept


def evaluate(model: VanillaRNN, cfg: Config, epoch: int) -> dict[str, float]:
    """Average task loss and accuracy over ``eval_batches`` fresh batches.

    Evaluation seeds are offset by 10_000_000 and derived from ``epoch`` and
    the batch index, so they are reproducible per epoch.
    """
    losses = []
    accs = []
    with torch.no_grad():
        for i in range(cfg.eval_batches):
            seed = 10_000_000 + epoch * 1000 + i
            s, target = generate_input_output(
                cfg.batch_size, cfg.train_steps, seed, cfg.input_dim,
                cfg.input_scale, cfg.delay, cfg.task, model.device,
            )
            loss, acc = model.task_loss_and_accuracy(s, target)
            losses.append(float(loss.item()))
            accs.append(acc)
    return {"loss": float(np.mean(losses)), "accuracy": float(np.mean(accs))}


def run(cfg: Config) -> dict:
    """Train on the task with separate flossing phases and write a JSON log.

    Flossing runs for ``pre_epochs`` steps before the first task step and for
    ``inter_epochs`` steps whenever ``epoch`` is a multiple of
    ``inter_period`` (at most ``max_inter_episodes`` times). A period of 0
    for ``inter_period``, ``eval_every`` or ``log_every_floss`` disables the
    corresponding schedule.

    Returns:
        The log dict (keys ``config``, ``evals``, ``floss``, ``task``), which
        is also written to ``cfg.out``.

    Raises:
        ValueError: if ``n_lyap > hidden_size`` or ``lyap_steps`` does not
            outlast the transient discard.
    """
    device = torch.device(cfg.device)
    if cfg.n_lyap > cfg.hidden_size:
        raise ValueError("n_lyap must be <= hidden_size")
    if cfg.lyap_steps <= _lyap_discard_steps(cfg.lyap_steps):
        raise ValueError("lyap_steps too short for official transient discard")

    model = VanillaRNN(cfg, device)
    task_optim = torch.optim.AdamW(
        model.task_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    floss_optim = torch.optim.AdamW(
        model.floss_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0
    )
    target = torch.full((cfg.n_lyap,), cfg.lyap_target, device=device)
    log: dict = {"config": asdict(cfg), "evals": [], "floss": [], "task": []}

    t0 = time.time()
    inter_count = 0
    for epoch in range(1, cfg.train_epochs + 1):
        should_prefloss = epoch == 1 and cfg.pre_epochs > 0
        should_interfloss = (
            _is_due(epoch, cfg.inter_period)
            and inter_count < cfg.max_inter_episodes
            and cfg.inter_epochs > 0
        )
        if should_prefloss or should_interfloss:
            n_steps = cfg.pre_epochs if should_prefloss else cfg.inter_epochs
            kind = "pre" if should_prefloss else "inter"
            if should_interfloss:
                inter_count += 1
            for floss_step in range(1, n_steps + 1):
                seed = cfg.train_epochs * cfg.seed_input + epoch + floss_step
                # Flossing uses batch size one.
                s, _ = generate_input_output(
                    1, cfg.floss_input_steps, seed, cfg.input_dim,
                    cfg.input_scale, cfg.delay, cfg.task, device,
                )
                floss_optim.zero_grad(set_to_none=True)
                spectrum = model.lyapunov_spectrum(s)
                floss_loss = (spectrum - target).square().mean()
                floss_loss.backward()
                floss_optim.step()
                if (
                    floss_step == 1
                    or floss_step == n_steps
                    or _is_due(floss_step, cfg.log_every_floss)
                ):
                    rec = {
                        "epoch": epoch,
                        "kind": kind,
                        "floss_step": floss_step,
                        "loss": float(floss_loss.item()),
                        "lambda_mean": float(spectrum.detach().mean().item()),
                        "lambda_1": float(spectrum.detach()[0].item()),
                        "elapsed": time.time() - t0,
                    }
                    log["floss"].append(rec)
                    print(
                        f"F {kind} epoch={epoch} step={floss_step}/{n_steps} "
                        f"loss={rec['loss']:.6f} lambda1={rec['lambda_1']:+.4f}",
                        flush=True,
                    )

        seed = cfg.train_epochs * cfg.seed_input + epoch
        s, target_batch = generate_input_output(
            cfg.batch_size, cfg.train_steps, seed, cfg.input_dim,
            cfg.input_scale, cfg.delay, cfg.task, device,
        )
        task_optim.zero_grad(set_to_none=True)
        task_loss, task_acc = model.task_loss_and_accuracy(s, target_batch)
        task_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.task_parameters(), 0.03)
        task_optim.step()

        if epoch == 1 or epoch == cfg.train_epochs or _is_due(epoch, cfg.eval_every):
            eval_metrics = evaluate(model, cfg, epoch)
            rec = {
                "epoch": epoch,
                "train_loss": float(task_loss.item()),
                "train_accuracy": float(task_acc),
                "eval_loss": eval_metrics["loss"],
                "eval_accuracy": eval_metrics["accuracy"],
                "elapsed": time.time() - t0,
            }
            log["evals"].append(rec)
            print(
                f"T epoch={epoch}/{cfg.train_epochs} loss={rec['train_loss']:.4f} "
                f"train_acc={rec['train_accuracy']:.3f} eval_acc={rec['eval_accuracy']:.3f}",
                flush=True,
            )

    out_path = Path(cfg.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(log, indent=2))
    return log


def parse_args(argv: list[str] | None = None) -> Config:
    """Build a Config from command-line flags.

    Every Config field ``some_name`` becomes ``--some-name`` with the field's
    default and type. ``argv`` defaults to ``sys.argv[1:]`` (argparse's own
    default) and exists so the parser can be driven programmatically.
    """
    parser = argparse.ArgumentParser()
    defaults = Config()
    for field_name, default in asdict(defaults).items():
        arg = "--" + field_name.replace("_", "-")
        if isinstance(default, bool):
            parser.add_argument(arg, action=argparse.BooleanOptionalAction, default=default)
        else:
            parser.add_argument(arg, type=type(default), default=default)
    ns = parser.parse_args(argv)
    return Config(**vars(ns))


if __name__ == "__main__":
    run(parse_args())