"""PyTorch port of Rainer Engelken's GradientFlossing_ExampleCode.jl. This script is intentionally a vanilla-RNN delayed-XOR reproduction scaffold, not an HRM/TRM experiment. It keeps the algorithmic choices that matter for checking gradient flossing itself: - flossing is a separate phase, not mixed with task loss; - flossing optimizes only input/recurrent/bias parameters, not readout; - flossing uses mean((lambda_i - target)^2); - flossing batch size is one, as in the Julia code; - QR is differentiable and participates in the backward pass; - Lyapunov accumulation discards the initial transient. Reference: research/flossing/external/GradientFlossing/GradientFlossing_ExampleCode.jl """ from __future__ import annotations import argparse import json import math import time from dataclasses import asdict, dataclass from pathlib import Path import numpy as np import torch import torch.nn.functional as F def recabs(bits: np.ndarray) -> np.ndarray: out = bits[0] for bit in bits[1:]: out = np.abs(out - bit) return out def generate_input_output( batch_size: int, steps: int, seed: int, input_dim: int, input_scale: float, delay: int, task: int, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: """Port of generateInputOutput in GradientFlossing_XOR.jl. Returns tensors shaped as Julia uses conceptually: s: [T, S, B] target: [T, 1, B] """ rng = np.random.default_rng(seed) raw_steps = steps + delay s = rng.integers(0, 2, size=(raw_steps, input_dim, batch_size), dtype=np.int64).astype(np.float32) target = np.zeros((raw_steps, 1, batch_size), dtype=np.float32) if task == 0: for i in range(delay, raw_steps): target[i, :, :] = s[i - delay, :1, :] else: h = abs(task) span = delay * (2**h) for i in range(span, raw_steps): indices = [i - delay * j for j in range(1, 2**h + 1)] values = np.flip(s[indices, 0, :], axis=0) target[i, 0, :] = recabs(values) sx = torch.from_numpy(s[delay:] * np.float32(input_scale)).to(device) ty = torch.from_numpy(target[delay:]).to(device) return sx, ty @dataclass class Config: hidden_size: int = 80 train_epochs: int = 3001 inter_period: int = 100 inter_epochs: int = 500 pre_epochs: int = 500 max_inter_episodes: int = 2 batch_size: int = 16 input_dim: int = 1 train_steps: int = 300 lyap_steps: int = 55 floss_input_steps: int = 300 seed_ic: int = 1 seed_input: int = 1 seed_net: int = 1 seed_ons: int = 1 lr: float = 1e-3 beta1: float = 0.9 beta2: float = 0.999 init_type: int = 1 recurrent_gain: float = 1.0 recurrent_mean_gain: float = 0.0 input_scale: float = 1.0 delay: int = 10 ws_std: float = 1.0 ws_mean: float = 0.0 wr_std: float = 1.0 wr_mean: float = 0.0 b_std: float = 0.1 b_mean: float = 0.0 n_lyap: int = 75 task: int = -1 lyap_target: float = 0.0 eval_every: int = 100 eval_batches: int = 4 log_every_floss: int = 50 device: str = "cpu" out: str = "research/flossing/engelken_python/run.json" class VanillaRNN: def __init__(self, cfg: Config, device: torch.device): self.cfg = cfg self.device = device torch.manual_seed(cfg.seed_net) n = cfg.hidden_size self.ws = torch.nn.Parameter( cfg.ws_std * torch.randn(n, cfg.input_dim, device=device) + cfg.ws_mean ) self.wr = torch.nn.Parameter( cfg.wr_std * torch.randn(1, n, device=device) + cfg.wr_mean ) self.b = torch.nn.Parameter( cfg.b_mean + cfg.b_std * torch.rand(n, device=device) ) self.offset = torch.nn.Parameter(torch.tensor([0.001], device=device)) if cfg.init_type == 1: j = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n) if cfg.recurrent_mean_gain != 0: j = j + cfg.recurrent_mean_gain / n elif cfg.init_type == 3: j0 = cfg.recurrent_gain * torch.randn(n, n, device=device) / math.sqrt(n) q, _ = torch.linalg.qr(j0) j = cfg.recurrent_gain * q else: raise ValueError(f"unsupported init_type={cfg.init_type}") self.j = torch.nn.Parameter(j) self.x_init = torch.nn.Parameter(self._relax_initial_condition(cfg.seed_ic, 100)) def _relax_initial_condition(self, seed: int, steps: int) -> torch.Tensor: gen = torch.Generator(device=self.device).manual_seed(seed) x = torch.randn(self.cfg.hidden_size, self.cfg.batch_size, generator=gen, device=self.device) with torch.no_grad(): for _ in range(steps): x = self.j.detach() @ torch.tanh(x) + self.b.detach().unsqueeze(1) return x def task_parameters(self) -> list[torch.nn.Parameter]: return [self.ws, self.j, self.wr, self.b, self.offset, self.x_init] def floss_parameters(self) -> list[torch.nn.Parameter]: return [self.ws, self.j, self.b] def task_loss_and_accuracy(self, s: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, float]: cfg = self.cfg first_loss_idx = 2 * cfg.delay if cfg.task > 0 else cfg.delay * (2 ** abs(cfg.task)) # Julia is 1-indexed and starts loss at ti >= firstlossidx. first_loss_idx = max(first_loss_idx - 1, 0) x = self.x_init loss = torch.zeros((), device=self.device) correct = 0.0 count = 0 for ti in range(cfg.train_steps): x = self.j @ torch.tanh(x) + self.b.unsqueeze(1) + (self.ws @ s[ti]) / cfg.input_dim if ti >= first_loss_idx: logits = self.wr @ x + self.offset.unsqueeze(1) loss = loss + F.binary_cross_entropy_with_logits(logits, target[ti], reduction="sum") pred = (torch.sigmoid(logits) >= 0.5).to(target.dtype) correct += (pred == target[ti]).float().sum().item() count += target[ti].numel() denom = cfg.batch_size * max(cfg.train_steps - first_loss_idx, 1) return loss / denom, correct / max(count, 1) def lyapunov_spectrum(self, s: torch.Tensor) -> torch.Tensor: cfg = self.cfg if cfg.n_lyap <= 0: return torch.empty(0, device=self.device) x = self.x_init[:, 0].clone() gen = torch.Generator(device=self.device).manual_seed(cfg.seed_ons) q0 = torch.randn(cfg.hidden_size, cfg.n_lyap, generator=gen, device=self.device) q, _ = torch.linalg.qr(q0) accum = torch.zeros(cfg.n_lyap, device=self.device) tsim = 0 transient = 10 transient_ons = math.ceil(cfg.lyap_steps / 10) diag_idx = torch.arange(cfg.n_lyap, device=self.device) for n in range(1, cfg.lyap_steps + 1): x = self.j @ torch.tanh(x) + self.b + (self.ws @ s[n - 1, :, 0]) / cfg.input_dim phi_prime = 1.0 / torch.cosh(x).square() tangent = phi_prime.unsqueeze(1) * self.j q = tangent @ q q, r = torch.linalg.qr(q) if n > transient + transient_ons: accum = accum + r[diag_idx, diag_idx].abs().clamp_min(1e-30).log() tsim += 1 if tsim == 0: raise ValueError("lyap_steps too short after transient discard") return accum / tsim def evaluate(model: VanillaRNN, cfg: Config, epoch: int) -> dict[str, float]: losses = [] accs = [] with torch.no_grad(): for i in range(cfg.eval_batches): seed = 10_000_000 + epoch * 1000 + i s, target = generate_input_output( cfg.batch_size, cfg.train_steps, seed, cfg.input_dim, cfg.input_scale, cfg.delay, cfg.task, model.device, ) loss, acc = model.task_loss_and_accuracy(s, target) losses.append(float(loss.item())) accs.append(acc) return {"loss": float(np.mean(losses)), "accuracy": float(np.mean(accs))} def run(cfg: Config) -> dict: device = torch.device(cfg.device) if cfg.n_lyap > cfg.hidden_size: raise ValueError("n_lyap must be <= hidden_size") if cfg.lyap_steps <= 10 + math.ceil(cfg.lyap_steps / 10): raise ValueError("lyap_steps too short for official transient discard") model = VanillaRNN(cfg, device) task_optim = torch.optim.AdamW( model.task_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0 ) floss_optim = torch.optim.AdamW( model.floss_parameters(), lr=cfg.lr, betas=(cfg.beta1, cfg.beta2), weight_decay=0.0 ) target = torch.full((cfg.n_lyap,), cfg.lyap_target, device=device) log: dict = {"config": asdict(cfg), "evals": [], "floss": [], "task": []} t0 = time.time() inter_count = 0 for epoch in range(1, cfg.train_epochs + 1): should_prefloss = epoch == 1 and cfg.pre_epochs > 0 should_interfloss = ( epoch % cfg.inter_period == 0 and inter_count < cfg.max_inter_episodes and cfg.inter_epochs > 0 ) if should_prefloss or should_interfloss: n_steps = cfg.pre_epochs if should_prefloss else cfg.inter_epochs kind = "pre" if should_prefloss else "inter" if should_interfloss: inter_count += 1 for floss_step in range(1, n_steps + 1): seed = cfg.train_epochs * cfg.seed_input + epoch + floss_step s, _ = generate_input_output( 1, cfg.floss_input_steps, seed, cfg.input_dim, cfg.input_scale, cfg.delay, cfg.task, device, ) floss_optim.zero_grad(set_to_none=True) spectrum = model.lyapunov_spectrum(s) floss_loss = (spectrum - target).square().mean() floss_loss.backward() floss_optim.step() if floss_step == 1 or floss_step == n_steps or floss_step % cfg.log_every_floss == 0: rec = { "epoch": epoch, "kind": kind, "floss_step": floss_step, "loss": float(floss_loss.item()), "lambda_mean": float(spectrum.detach().mean().item()), "lambda_1": float(spectrum.detach()[0].item()), "elapsed": time.time() - t0, } log["floss"].append(rec) print( f"F {kind} epoch={epoch} step={floss_step}/{n_steps} " f"loss={rec['loss']:.6f} lambda1={rec['lambda_1']:+.4f}", flush=True, ) seed = cfg.train_epochs * cfg.seed_input + epoch s, target_batch = generate_input_output( cfg.batch_size, cfg.train_steps, seed, cfg.input_dim, cfg.input_scale, cfg.delay, cfg.task, device, ) task_optim.zero_grad(set_to_none=True) task_loss, task_acc = model.task_loss_and_accuracy(s, target_batch) task_loss.backward() torch.nn.utils.clip_grad_norm_(model.task_parameters(), 0.03) task_optim.step() if epoch == 1 or epoch == cfg.train_epochs or epoch % cfg.eval_every == 0: eval_metrics = evaluate(model, cfg, epoch) rec = { "epoch": epoch, "train_loss": float(task_loss.item()), "train_accuracy": float(task_acc), "eval_loss": eval_metrics["loss"], "eval_accuracy": eval_metrics["accuracy"], "elapsed": time.time() - t0, } log["evals"].append(rec) print( f"T epoch={epoch}/{cfg.train_epochs} loss={rec['train_loss']:.4f} " f"train_acc={rec['train_accuracy']:.3f} eval_acc={rec['eval_accuracy']:.3f}", flush=True, ) out_path = Path(cfg.out) out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(log, indent=2)) return log def parse_args() -> Config: parser = argparse.ArgumentParser() defaults = Config() for field_name, default in asdict(defaults).items(): arg = "--" + field_name.replace("_", "-") if isinstance(default, bool): parser.add_argument(arg, action=argparse.BooleanOptionalAction, default=default) else: parser.add_argument(arg, type=type(default), default=default) ns = parser.parse_args() return Config(**vars(ns)) if __name__ == "__main__": run(parse_args())