"""SRM-Joint-AOL v1 — Stable Recursive Model (forked from hrm_act_v1.py).

Replaces HRM's separate H_level / L_level transformer stacks with ONE joint
operator T on state z = (h, l) that is designed to be contractive under the
weighted P-norm ||z||²_P = ||h||² + η||l||² with Lipschitz constant
≤ κ ∈ (0.85, 0.95).

Key replacement vs HRM:
  - HierarchicalReasoningModel_ACTV1Block (attn + SwiGLU)
      → StableRecursionModel_ACTV1Block (joint SRM step on (h, l))
  - ReasoningModule wraps n_iters joint updates instead of separate H/L cycles.

Lipschitz analysis (per step, in P-norm):
  Lip_P(T) ≤ (1-α) + α · κ < 1
  ⇒ joint top-1 Lyapunov per micro-step: λ_1 ≤ log((1-α) + α·κ) < 0

ARCHITECTURE (one joint step):
  z = concat(h + b_in_h(x), √η · (l + b_in_l(x)))   # join with input bias
  ψ = AOL_Block(z)                                   # Lip_P(ψ) ≤ 1
  ψ_h, ψ_l_scaled = split(ψ);  ψ_l = ψ_l_scaled / √η
  Az_h = a_HH·ψ_h + a_HL·(U_HL·ψ_l)                  # gain row sum ≤ κ
  Az_l = a_LH·(U_LH·ψ_h) + a_LL·ψ_l
  h_new = (1-α)·h + α·Az_h + b_out_h(x)
  l_new = (1-α)·l + α·Az_l + b_out_l(x)

REUSED FROM HRM:
  - ACT framework (q_head, halt logic)
  - CastedEmbedding/CastedLinear (bf16-safe linears)
  - CastedSparseEmbedding (puzzle_emb)
  - 1-step grad / DEQ-style truncation
"""

from typing import Tuple, Dict
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


# =============================================================================
# Approximately 1-Lipschitz primitives
#
# Normalization (AOL) and orthogonalization (Cayley) are computed in float32,
# then cast to forward dtype (bf16 by default). The bound is *exact in fp32*
# but only approximate after cast — bf16 rounding introduces a small error
# that accumulates over n_aol_layers matmuls. Empirically the margin to the
# theoretical κ-bound is large (~5×), so this is fine in practice, but the
# guarantee is not strict. For applications where strictness matters, run
# the bounded operators in float32.
# =============================================================================

class AOLLinear(nn.Module):
    """≤1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling.

    Given W ∈ R^(out × in), let A = W^T W (symmetric PSD).
    Define D_jj = 1 / √(Σ_i |A_ij| + eps); set W̃ = W · diag(D).
    Then ||W̃ x||_2 ≤ ||x||_2 in float32 (Prach & Lampert Theorem 1).

    Bound is approximate (not exact) under bf16 due to rounding in W·diag(D)
    and the subsequent matmul. Bias is unconstrained (shift only, doesn't
    affect Lipschitz w.r.t. input).

    Args:
        in_dim: size of the last input dimension.
        out_dim: size of the last output dimension.
        bias: whether to add a learnable (zero-initialised) bias.
        cast_to: dtype the normalized weight is cast to before the matmul;
            the input passed to ``forward`` must already have this dtype.
        eps: added under the square root to avoid division by zero.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = True,
                 cast_to: torch.dtype = torch.bfloat16, eps: float = 1e-6):
        super().__init__()
        std = 1.0 / math.sqrt(in_dim)
        self.W = nn.Parameter(torch.randn(out_dim, in_dim) * std)
        self.b = nn.Parameter(torch.zeros(out_dim)) if bias else None
        self.cast_to = cast_to
        self.eps = eps

    def normalized_weight(self) -> torch.Tensor:
        """Return the AOL-rescaled weight W · diag(D), cast to ``cast_to``."""
        W32 = self.W.float()
        WTW = W32.t() @ W32                         # (in, in), symmetric
        col_abs_sum = WTW.abs().sum(dim=0)          # Σ_i |A_ij| per column j
        scale = torch.rsqrt(col_abs_sum + self.eps)
        # Broadcasting over rows scales column j of W by D_jj.
        return (W32 * scale.unsqueeze(0)).to(self.cast_to)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the rescaled linear map (plus bias, if any) to ``x``."""
        W = self.normalized_weight()
        out = F.linear(x, W)
        if self.b is not None:
            out = out + self.b.to(out.dtype)
        return out


class AOLBlock(nn.Module):
    """Stack of AOLLinear with 1-Lipschitz activation (ReLU) between layers.

    Composition of 1-Lipschitz maps is 1-Lipschitz. SiLU/GELU NOT allowed
    (max derivative > 1 would break the bound). No activation is applied
    after the last layer.

    Raises:
        ValueError: if ``n_layers`` is smaller than 1.
    """

    def __init__(self, dim: int, n_layers: int = 2,
                 cast_to: torch.dtype = torch.bfloat16):
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.layers = nn.ModuleList([
            AOLLinear(dim, dim, bias=True, cast_to=cast_to)
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every layer in order, with ReLU between (not after) layers."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = F.relu(x)
        return x


class CayleyOrthogonal(nn.Module):
    """Approximately orthogonal Q ∈ R^(d × d) via Cayley transform.

    Q = (I + S)^{-1}(I - S) where S = (A - A^T)/2 is skew-symmetric.
    Since (I+S) and (I-S) commute (both polynomials in S), the form is also
    Q = (I - S)(I + S)^{-1}. Q^T Q = I exactly in float32 — approximate after
    cast to bf16. Solve done in float32 for numerical stability.

    NOTE: torch.linalg.solve may not be fullgraph-compile friendly.
    Test before enabling torch.compile / FSDP.
    """

    def __init__(self, dim: int, cast_to: torch.dtype = torch.bfloat16):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim, dim) * (1.0 / math.sqrt(dim)))
        self.dim = dim
        self.cast_to = cast_to
        # Non-persistent: the identity is rebuilt on construction, never saved.
        self.register_buffer("I", torch.eye(dim), persistent=False)

    def forward(self) -> torch.Tensor:
        """Return Q = (I + S)^{-1}(I - S), cast to ``cast_to``."""
        A32 = self.A.float()
        S = 0.5 * (A32 - A32.t())
        I = self.I.float()
        Q = torch.linalg.solve(I + S, I - S)
        return Q.to(self.cast_to)


class BlockGain(nn.Module):
    """Block gain matrix A with row sums ≤ κ under weighted P-norm.

    H row entries (P-normalized): [a_HH,             √η · a_HL], sum = κ
    L row entries (P-normalized): [(1/√η) · a_LH,    a_LL     ], sum = κ

    Parameterized via softmax × κ ⇒ exact equality (saturation).

    NOTE(review): substituting ψ_l = ψ_l_scaled / √η into the joint step
    suggests the P-normalized off-diagonal entries are a_HL / √η and
    √η · a_LH, i.e. the opposite scaling to the one used here. Both
    conventions coincide for the default η = 1 — confirm the intended
    convention before training with η ≠ 1.

    NOTE(review): bounding row sums bounds the ∞-norm of the 2×2 gain; the
    column sums can reach 2κ, so the spectral norm is not obviously ≤ κ.
    Verify empirically with ``measure_lipschitz_constant``.
    """

    def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 3.0):
        super().__init__()
        self.kappa = kappa
        self.eta = eta
        # init_diag=3.0 → softmax([3, 0]) ≈ [0.953, 0.047] ⇒ ~5% cross-coupling at start
        # (init_diag=1.0 was too weak — gave 27% cross-coupling; init_diag=3.0 truly minimal)
        self.logits_H = nn.Parameter(torch.tensor([init_diag, 0.0]))
        self.logits_L = nn.Parameter(torch.tensor([0.0, init_diag]))

    def forward(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the scalar gains ``(a_HH, a_HL, a_LH, a_LL)`` in float32."""
        sqrt_eta = math.sqrt(self.eta)

        gH = self.kappa * F.softmax(self.logits_H.float(), dim=0)
        a_HH, a_HL_scaled = gH[0], gH[1]
        a_HL = a_HL_scaled / sqrt_eta

        gL = self.kappa * F.softmax(self.logits_L.float(), dim=0)
        a_LH_scaled, a_LL = gL[0], gL[1]
        a_LH = a_LH_scaled * sqrt_eta

        return a_HH, a_HL, a_LH, a_LL


class AOLTokenMixer(nn.Module):
    """1-Lipschitz mixing across token (seq) and channel dims via AOL.

    Pipeline for x of shape (B, seq, dim):
      1) Channel mix (AOLBlock across `dim`)
      2) ReLU
      3) Token mix (AOLBlock across `seq`, applied after transpose), then
         transposed back to (B, seq, dim)

    No activation follows the token mix, so the output can be negative.
    Composition of 1-Lipschitz maps ⇒ Lip ≤ 1.
    """

    def __init__(self, seq_len: int, dim: int, n_layers: int = 1,
                 cast_to: torch.dtype = torch.bfloat16):
        super().__init__()
        self.channel_mix = AOLBlock(dim=dim, n_layers=n_layers, cast_to=cast_to)
        self.token_mix = AOLBlock(dim=seq_len, n_layers=n_layers, cast_to=cast_to)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix channels, apply ReLU, then mix tokens; shape is preserved."""
        y = self.channel_mix(x)
        y = F.relu(y)
        y = y.transpose(-2, -1)   # (B, dim, seq)
        y = self.token_mix(y)
        y = y.transpose(-2, -1)   # (B, seq, dim)
        return y


# =============================================================================
# Carry types and config
# =============================================================================

@dataclass
class StableRecursionModel_ACTV1InnerCarry:
    """Recurrent state (h, l) carried between ACT steps."""
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class StableRecursionModel_ACTV1Carry:
    """Full ACT carry: recurrent state, step counters, halt flags, live batch."""
    inner_carry: StableRecursionModel_ACTV1InnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    current_data: Dict[str, torch.Tensor]


class StableRecursionModel_ACTV1Config(BaseModel):
    """Hyper-parameters of the SRM model and its ACT wrapper."""
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    # SRM-specific
    n_iters: int = 12       # joint micro-steps per ACT step
    n_aol_layers: int = 2   # depth of ψ AOL block
    kappa: float = 0.9
    eta: float = 1.0
    alpha: float = 1.0

    # Shared with HRM
    hidden_size: int
    halt_max_steps: int
    halt_exploration_prob: float = 0.1
    forward_dtype: str = "bfloat16"


# =============================================================================
# SRM joint step (replaces HRM's H/L transformer blocks)
# =============================================================================

class StableRecursionModel_ACTV1Block(nn.Module):
    """One SRM joint step on (h, l). Per-step Lip_P ≤ (1-α) + α·κ < 1.

    Args:
        config: model configuration (hidden size, κ, η, α, AOL depth, dtype).
        seq_full: sequence length seen by the block, including any puzzle
            embedding positions prepended to the tokens.
    """

    def __init__(self, config: StableRecursionModel_ACTV1Config, seq_full: int) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.kappa = config.kappa
        self.eta = config.eta
        self.alpha = config.alpha

        cast = getattr(torch, config.forward_dtype)
        joint_dim = 2 * config.hidden_size

        self.psi = AOLTokenMixer(seq_len=seq_full, dim=joint_dim,
                                 n_layers=config.n_aol_layers, cast_to=cast)
        self.gain = BlockGain(kappa=config.kappa, eta=config.eta)
        self.U_HL = CayleyOrthogonal(config.hidden_size, cast_to=cast)
        self.U_LH = CayleyOrthogonal(config.hidden_size, cast_to=cast)

        # Input biases — unconstrained (only affect Lip w.r.t. x, not w.r.t. z;
        # x is fixed across recursion so doesn't affect Lyapunov)
        self.bias_in_h = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
        self.bias_in_l = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
        self.bias_out_h = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
        self.bias_out_l = CastedLinear(config.hidden_size, config.hidden_size, bias=True)

    def forward(self, h: torch.Tensor, l: torch.Tensor,
                input_emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one joint update and return ``(h_new, l_new)``.

        ``h``, ``l`` and ``input_emb`` are concatenated / added elementwise,
        so they must share the same (B, seq, hidden) shape.
        """
        sqrt_eta = math.sqrt(self.eta)

        # 1. Join with input bias
        h_in = h + self.bias_in_h(input_emb)
        l_in = l + self.bias_in_l(input_emb)
        z = torch.cat([h_in, sqrt_eta * l_in], dim=-1)   # (B, seq, 2h)

        # 2. 1-Lipschitz feature map ψ
        psi = self.psi(z)
        psi_h, psi_l_scaled = psi.chunk(2, dim=-1)
        psi_l = psi_l_scaled / sqrt_eta

        # 3. Block-gain matrix A (κ-bounded row sums)
        a_HH, a_HL, a_LH, a_LL = self.gain()
        U_HL = self.U_HL()
        U_LH = self.U_LH()
        psi_l_mix = F.linear(psi_l, U_HL)   # ≡ psi_l @ U_HL.T
        psi_h_mix = F.linear(psi_h, U_LH)

        Az_h = a_HH * psi_h + a_HL * psi_l_mix
        Az_l = a_LH * psi_h_mix + a_LL * psi_l

        # 4. Damped update + output bias
        h_new = (1.0 - self.alpha) * h + self.alpha * Az_h + self.bias_out_h(input_emb)
        l_new = (1.0 - self.alpha) * l + self.alpha * Az_l + self.bias_out_l(input_emb)
        return h_new, l_new


# =============================================================================
# Inner model + ACT wrapper (matches HRM_ACTV1 interface)
# =============================================================================

class StableRecursionModel_ACTV1_Inner(nn.Module):
    """Embeddings, the tied SRM block, and the LM / Q heads."""

    def __init__(self, config: StableRecursionModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, config.forward_dtype)

        self.embed_scale = math.sqrt(config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(config.vocab_size, config.hidden_size,
                                            init_std=embed_init_std,
                                            cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(config.hidden_size, config.vocab_size, bias=False)
        self.q_head = CastedLinear(config.hidden_size, 2, bias=True)
        # Zero weights and a bias of -5 make both Q logits start at -5.
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)

        # Ceil division: number of hidden-size slots the puzzle embedding needs.
        self.puzzle_emb_len = -(config.puzzle_emb_ndim // -config.hidden_size)
        if config.puzzle_emb_ndim > 0:
            self.puzzle_emb = CastedSparseEmbedding(
                config.num_puzzle_identifiers, config.puzzle_emb_ndim,
                batch_size=config.batch_size, init_std=0, cast_to=self.forward_dtype,
            )

        seq_full = config.seq_len + self.puzzle_emb_len

        # Single tied SRM block used n_iters times per ACT step
        self.srm_block = StableRecursionModel_ACTV1Block(config, seq_full=seq_full)

        self.H_init = nn.Buffer(
            trunc_normal_init_(torch.empty(config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )
        self.L_init = nn.Buffer(
            trunc_normal_init_(torch.empty(config.hidden_size, dtype=self.forward_dtype), std=1),
            persistent=True,
        )

    def _input_embeddings(self, input_ids: torch.Tensor,
                          puzzle_ids: torch.Tensor) -> torch.Tensor:
        """Embed tokens, prepend the puzzle embedding (if any), and scale."""
        emb = self.embed_tokens(input_ids.to(torch.int32))
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_ids)
            # Zero-pad up to a whole number of hidden-size slots before reshaping.
            pad = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad))
            puzzle_embedding = puzzle_embedding.view(
                -1, self.puzzle_emb_len, self.config.hidden_size)
            emb = torch.cat((puzzle_embedding, emb), dim=-2)
        return self.embed_scale * emb

    def empty_carry(self, batch_size: int) -> StableRecursionModel_ACTV1InnerCarry:
        """Allocate an uninitialised (h, l) carry on the model's device.

        The contents are arbitrary memory; ``reset_carry`` must overwrite
        them (the wrapper starts every sequence as halted for that reason).
        """
        shape = (batch_size, self.config.seq_len + self.puzzle_emb_len,
                 self.config.hidden_size)
        # Follow the model's device so the carry never needs a manual move.
        device = self.H_init.device
        return StableRecursionModel_ACTV1InnerCarry(
            z_H=torch.empty(shape, dtype=self.forward_dtype, device=device),
            z_L=torch.empty(shape, dtype=self.forward_dtype, device=device),
        )

    def reset_carry(self, reset_flag: torch.Tensor,
                    carry: StableRecursionModel_ACTV1InnerCarry
                    ) -> StableRecursionModel_ACTV1InnerCarry:
        """Replace the state of flagged batch rows with H_init / L_init."""
        mask = reset_flag.view(-1, 1, 1)
        return StableRecursionModel_ACTV1InnerCarry(
            z_H=torch.where(mask, self.H_init, carry.z_H),
            z_L=torch.where(mask, self.L_init, carry.z_L),
        )

    def forward(self, carry: StableRecursionModel_ACTV1InnerCarry,
                batch: Dict[str, torch.Tensor]
                ) -> Tuple[StableRecursionModel_ACTV1InnerCarry, torch.Tensor,
                           Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``n_iters`` joint steps and return ``(carry, logits, (q_halt, q_continue))``.

        Reads ``batch["inputs"]`` and ``batch["puzzle_identifiers"]``.
        """
        input_emb = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # n_iters - 1 no-grad iterations + 1 grad iteration (HRM-style DEQ truncation)
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L
            for _ in range(self.config.n_iters - 1):
                z_H, z_L = self.srm_block(z_H, z_L, input_emb)

        assert not z_H.requires_grad and not z_L.requires_grad

        z_H, z_L = self.srm_block(z_H, z_L, input_emb)

        # The carry is detached so gradients never flow across ACT steps.
        new_carry = StableRecursionModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
        # Drop the prepended puzzle-embedding positions from the LM output.
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        # Q-values are read from the first sequence position only.
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class StableRecursionModel_ACTV1(nn.Module):
    """ACT wrapper — mirrors HierarchicalReasoningModel_ACTV1 1-to-1."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = StableRecursionModel_ACTV1Config(**config_dict)
        self.inner = StableRecursionModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        """The inner model's sparse puzzle embedding (only set if puzzle_emb_ndim > 0)."""
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]) -> StableRecursionModel_ACTV1Carry:
        """Build a carry in which every sequence is halted.

        Starting halted makes the first ``forward`` call reset the state and
        load ``batch`` for every row. Counters and flags are created on the
        device of ``batch["inputs"]``.
        """
        B = batch["inputs"].shape[0]
        device = batch["inputs"].device
        return StableRecursionModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(B),
            steps=torch.zeros((B,), dtype=torch.int32, device=device),
            halted=torch.ones((B,), dtype=torch.bool, device=device),
            current_data={k: torch.empty_like(v) for k, v in batch.items()},
        )

    def forward(self, carry: StableRecursionModel_ACTV1Carry,
                batch: Dict[str, torch.Tensor]
                ) -> Tuple[StableRecursionModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """Run one ACT step.

        Halted rows are reset and refilled from ``batch``; the others keep
        their previous state and data. Returns the new carry and a dict with
        ``logits``, ``q_halt_logits``, ``q_continue_logits`` and, when
        training with ``halt_max_steps > 1``, ``target_q_continue``.
        """
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
            for k, v in carry.current_data.items()
        }

        new_inner_carry, logits, (q_halt, q_continue) = self.inner(
            new_inner_carry, new_current_data)
        outputs = {"logits": logits, "q_halt_logits": q_halt, "q_continue_logits": q_continue}

        with torch.no_grad():
            new_steps = new_steps + 1
            is_last = new_steps >= self.config.halt_max_steps
            halted = is_last

            if self.training and self.config.halt_max_steps > 1:
                halted = halted | (q_halt > q_continue)

                # Exploration: with probability halt_exploration_prob, force a
                # random minimum number of steps before halting is allowed.
                min_halt = (torch.rand_like(q_halt) < self.config.halt_exploration_prob) * \
                    torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt)

                # Bootstrapped Q target from one extra (no-grad) inner step.
                next_q_halt, next_q_continue = self.inner(new_inner_carry, new_current_data)[-1]
                outputs["target_q_continue"] = torch.sigmoid(
                    torch.where(is_last, next_q_halt,
                                torch.maximum(next_q_halt, next_q_continue))
                )

        return (StableRecursionModel_ACTV1Carry(new_inner_carry, new_steps, halted,
                                                new_current_data),
                outputs)


# =============================================================================
# Empirical Lipschitz diagnostic
# =============================================================================

@torch.no_grad()
def measure_lipschitz_constant(inner: StableRecursionModel_ACTV1_Inner,
                               sample_batch: Dict[str, torch.Tensor],
                               n_probes: int = 64,
                               eps: float = 1e-3) -> Dict[str, float]:
    """Estimate Lip_P of srm_block via random perturbations.

    Returns ratio ||T(z+δ) - T(z)||_P / ||δ||_P. Should be ≤ (1-α) + α·κ.

    δ is measured as the perturbation actually fed to the block, i.e. after
    ``h + dh`` has been rounded to the forward dtype. In bf16 a large part of
    an ``eps``-sized perturbation is rounded away, so using the requested
    ``dh`` in the denominator would under-estimate the ratio.

    Args:
        inner: model whose ``srm_block`` is probed, starting from H_init/L_init.
        sample_batch: dict with ``"inputs"`` and ``"puzzle_identifiers"``.
        n_probes: number of random perturbations (each yields B ratios).
        eps: standard deviation of the requested Gaussian perturbation.

    Returns:
        Dict with ``lip_emp_mean``, ``lip_emp_max``, ``lip_emp_99p``,
        ``lip_theoretical_bound`` and ``passes_bound`` (max ≤ 1.05 × bound).

    Raises:
        ValueError: if ``n_probes`` is smaller than 1.
    """
    if n_probes < 1:
        raise ValueError(f"n_probes must be >= 1, got {n_probes}")

    cfg = inner.config
    B = sample_batch["inputs"].shape[0]
    seq_full = cfg.seq_len + inner.puzzle_emb_len
    state_shape = (B, seq_full, cfg.hidden_size)

    h = inner.H_init.unsqueeze(0).expand(*state_shape).to(inner.forward_dtype).clone()
    l = inner.L_init.unsqueeze(0).expand(*state_shape).to(inner.forward_dtype).clone()
    input_emb = inner._input_embeddings(sample_batch["inputs"],
                                        sample_batch["puzzle_identifiers"])

    h_new, l_new = inner.srm_block(h, l, input_emb)

    def _p_norm(d_h: torch.Tensor, d_l: torch.Tensor) -> torch.Tensor:
        # Per-sample weighted norm sqrt(||d_h||² + η ||d_l||²), in float32.
        sq_h = d_h.float().flatten(1).pow(2).sum(1)
        sq_l = d_l.float().flatten(1).pow(2).sum(1)
        return (sq_h + cfg.eta * sq_l).sqrt()

    ratios = []
    for _ in range(n_probes):
        h_pert = h + torch.randn_like(h) * eps
        l_pert = l + torch.randn_like(l) * eps
        h_p, l_p = inner.srm_block(h_pert, l_pert, input_emb)

        # Realized (post-rounding) input perturbation, not the requested one.
        d_in_P = _p_norm(h_pert.float() - h.float(), l_pert.float() - l.float())
        d_out_P = _p_norm(h_p.float() - h_new.float(), l_p.float() - l_new.float())
        ratios.append((d_out_P / d_in_P.clamp_min(1e-12)).cpu())

    R = torch.cat(ratios)
    bound = (1 - cfg.alpha) + cfg.alpha * cfg.kappa
    return {
        "lip_emp_mean": float(R.mean()),
        "lip_emp_max": float(R.max()),
        "lip_emp_99p": float(R.quantile(0.99)),
        "lip_theoretical_bound": float(bound),
        "passes_bound": bool(R.max() <= bound * 1.05),
    }


def _smoke_test() -> None:
    """Build a small model on CUDA, run one ACT step, and print the Lipschitz check."""
    cfg = dict(
        batch_size=4, seq_len=81, vocab_size=11, num_puzzle_identifiers=1,
        puzzle_emb_ndim=512, hidden_size=256,
        n_iters=6, n_aol_layers=2,
        kappa=0.9, eta=1.0, alpha=1.0,
        halt_max_steps=4, halt_exploration_prob=0.1,
        forward_dtype="bfloat16",
    )
    model = StableRecursionModel_ACTV1(cfg).cuda()
    print(f"params={sum(p.numel() for p in model.parameters()):,}")

    batch = {
        "inputs": torch.randint(0, 11, (4, 81), dtype=torch.int32).cuda(),
        "labels": torch.randint(0, 11, (4, 81), dtype=torch.int32).cuda(),
        "puzzle_identifiers": torch.zeros(4, dtype=torch.int32).cuda(),
    }
    # initial_carry already places every tensor on the model / batch device.
    carry = model.initial_carry(batch)
    for k in carry.current_data:
        carry.current_data[k] = batch[k]

    model.eval()
    new_carry, out = model(carry, batch)
    print(f"forward OK | logits={out['logits'].shape}")

    lip = measure_lipschitz_constant(model.inner, batch, n_probes=32)
    print(f"Lipschitz check: emp_max={lip['lip_emp_max']:.4f} bound={lip['lip_theoretical_bound']:.4f} "
          f"passes={lip['passes_bound']} emp_mean={lip['lip_emp_mean']:.4f}")


if __name__ == "__main__":
    _smoke_test()