summaryrefslogtreecommitdiff
path: root/srm/models
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /srm/models
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'srm/models')
-rw-r--r--srm/models/common.py32
-rw-r--r--srm/models/hrm/hrm_act_v1.py283
-rw-r--r--srm/models/layers.py157
-rw-r--r--srm/models/losses.py101
-rw-r--r--srm/models/sparse_embedding.py132
-rw-r--r--srm/models/srm/__init__.py0
-rw-r--r--srm/models/srm/hrm_orth_v1.py376
-rw-r--r--srm/models/srm/srm_aol_v1.py494
8 files changed, 1575 insertions, 0 deletions
diff --git a/srm/models/common.py b/srm/models/common.py
new file mode 100644
index 0000000..1a04505
--- /dev/null
+++ b/srm/models/common.py
@@ -0,0 +1,32 @@
+import math
+
+import torch
+from torch import nn
+
+
+def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
+ # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
+ # This function is a PyTorch version of jax truncated normal init (default init method in flax)
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
+
+ with torch.no_grad():
+ if std == 0:
+ tensor.zero_()
+ else:
+ sqrt2 = math.sqrt(2)
+ a = math.erf(lower / sqrt2)
+ b = math.erf(upper / sqrt2)
+ z = (b - a) / 2
+
+ c = (2 * math.pi) ** -0.5
+ pdf_u = c * math.exp(-0.5 * lower ** 2)
+ pdf_l = c * math.exp(-0.5 * upper ** 2)
+ comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
+
+ tensor.uniform_(a, b)
+ tensor.erfinv_()
+ tensor.mul_(sqrt2 * comp_std)
+ tensor.clip_(lower * comp_std, upper * comp_std)
+
+ return tensor
diff --git a/srm/models/hrm/hrm_act_v1.py b/srm/models/hrm/hrm_act_v1.py
new file mode 100644
index 0000000..e91c7d1
--- /dev/null
+++ b/srm/models/hrm/hrm_act_v1.py
@@ -0,0 +1,283 @@
+from typing import Tuple, List, Dict, Optional
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from pydantic import BaseModel
+
+from models.common import trunc_normal_init_
+from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+from models.sparse_embedding import CastedSparseEmbedding
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1InnerCarry:
+ z_H: torch.Tensor
+ z_L: torch.Tensor
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1Carry:
+ inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
+
+ steps: torch.Tensor
+ halted: torch.Tensor
+
+ current_data: Dict[str, torch.Tensor]
+
+
+class HierarchicalReasoningModel_ACTV1Config(BaseModel):
+ batch_size: int
+ seq_len: int
+ puzzle_emb_ndim: int = 0
+ num_puzzle_identifiers: int
+ vocab_size: int
+
+ H_cycles: int
+ L_cycles: int
+
+ H_layers: int
+ L_layers: int
+
+ # Transformer config
+ hidden_size: int
+ expansion: float
+ num_heads: int
+ pos_encodings: str
+
+ rms_norm_eps: float = 1e-5
+ rope_theta: float = 10000.0
+
+ # Halting Q-learning config
+ halt_max_steps: int
+ halt_exploration_prob: float
+
+ forward_dtype: str = "bfloat16"
+
+
+class HierarchicalReasoningModel_ACTV1Block(nn.Module):
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+
+ self.self_attn = Attention(
+ hidden_size=config.hidden_size,
+ head_dim=config.hidden_size // config.num_heads,
+ num_heads=config.num_heads,
+ num_key_value_heads=config.num_heads,
+ causal=False
+ )
+ self.mlp = SwiGLU(
+ hidden_size=config.hidden_size,
+ expansion=config.expansion,
+ )
+ self.norm_eps = config.rms_norm_eps
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ # Post Norm
+ # Self Attention
+ hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
+ # Fully Connected
+ hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
+ def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
+ super().__init__()
+
+ self.layers = torch.nn.ModuleList(layers)
+
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+ # Input injection (add)
+ hidden_states = hidden_states + input_injection
+ # Layers
+ for layer in self.layers:
+ hidden_states = layer(hidden_states=hidden_states, **kwargs)
+
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+ self.config = config
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+ # I/O
+ self.embed_scale = math.sqrt(self.config.hidden_size)
+ embed_init_std = 1.0 / self.embed_scale
+
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
+ if self.config.puzzle_emb_ndim > 0:
+ # Zero init puzzle embeddings
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
+ batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
+
+ # LM Blocks
+ if self.config.pos_encodings == "rope":
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
+ max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+ base=self.config.rope_theta)
+ elif self.config.pos_encodings == "learned":
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ else:
+ raise NotImplementedError()
+
+ # Reasoning Layers
+ self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
+ self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
+
+ # Initial states
+ self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+ self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+
+ # Q head special init
+ # Init Q to (almost) zero for faster learning during bootstrapping
+ with torch.no_grad():
+ self.q_head.weight.zero_()
+ self.q_head.bias.fill_(-5) # type: ignore
+
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+ # Token embedding
+ embedding = self.embed_tokens(input.to(torch.int32))
+
+ # Puzzle embeddings
+ if self.config.puzzle_emb_ndim > 0:
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+ if pad_count > 0:
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
+
+ # Position embeddings
+ if self.config.pos_encodings == "learned":
+ # scale by 1/sqrt(2) to maintain forward variance
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+ # Scale
+ return self.embed_scale * embedding
+
+ def empty_carry(self, batch_size: int):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ )
+
+ def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ seq_info = dict(
+ cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+ )
+
+ # Input encoding
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ # Forward iterations
+ with torch.no_grad():
+ z_H, z_L = carry.z_H, carry.z_L
+
+ for _H_step in range(self.config.H_cycles):
+ for _L_step in range(self.config.L_cycles):
+ if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+
+ if not (_H_step == self.config.H_cycles - 1):
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ assert not z_H.requires_grad and not z_L.requires_grad
+
+ # 1-step grad
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ # LM Outputs
+ new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+
+ # Q head
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
+
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+class HierarchicalReasoningModel_ACTV1(nn.Module):
+ """ACT wrapper."""
+
+ def __init__(self, config_dict: dict):
+ super().__init__()
+ self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
+ self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
+
+ @property
+ def puzzle_emb(self):
+ return self.inner.puzzle_emb
+
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
+ batch_size = batch["inputs"].shape[0]
+
+ return HierarchicalReasoningModel_ACTV1Carry(
+ inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
+
+ steps=torch.zeros((batch_size, ), dtype=torch.int32),
+ halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
+
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+ # Update data, carry (removing halted sequences)
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+ new_steps = torch.where(carry.halted, 0, carry.steps)
+
+ new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
+
+ # Forward inner model
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+
+ outputs = {
+ "logits": logits,
+ "q_halt_logits": q_halt_logits,
+ "q_continue_logits": q_continue_logits
+ }
+
+ with torch.no_grad():
+ # Step
+ new_steps = new_steps + 1
+ is_last_step = new_steps >= self.config.halt_max_steps
+
+ halted = is_last_step
+
+ # if training, and ACT is enabled
+ if self.training and (self.config.halt_max_steps > 1):
+ # Halt signal
+ # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
+ halted = halted | (q_halt_logits > q_continue_logits)
+
+ # Exploration
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+
+ halted = halted & (new_steps >= min_halt_steps)
+
+ # Compute target Q
+ # NOTE: No replay buffer and target networks for computing target Q-value.
+ # As batch_size is large, there're many parallel envs.
+ # Similar concept as PQN https://arxiv.org/abs/2407.04811
+ next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
+
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
+
+ return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
diff --git a/srm/models/layers.py b/srm/models/layers.py
new file mode 100644
index 0000000..0394744
--- /dev/null
+++ b/srm/models/layers.py
@@ -0,0 +1,157 @@
+from typing import Tuple
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+try:
+ from flash_attn_interface import flash_attn_func # type: ignore[import]
+except ImportError:
+ # Fallback to FlashAttention 2
+ from flash_attn import flash_attn_func # type: ignore[import]
+
+from models.common import trunc_normal_init_
+
+
+CosSin = Tuple[torch.Tensor, torch.Tensor]
+
+
+def _find_multiple(a, b):
+ return (-(a // -b)) * b
+
+
+def rotate_half(x: torch.Tensor):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+ # q, k: [bs, seq_len, num_heads, head_dim]
+ # cos, sin: [seq_len, head_dim]
+ orig_dtype = q.dtype
+ q = q.to(cos.dtype)
+ k = k.to(cos.dtype)
+
+ q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
+ k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
+
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
+
+
+class CastedLinear(nn.Module):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool):
+ super().__init__()
+ # Truncated LeCun normal init
+ self.weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
+ )
+ self.bias = None
+ if bias:
+ # Zero init bias
+ self.bias = nn.Parameter(torch.zeros((out_features, )))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
+
+
+class CastedEmbedding(nn.Module):
+ def __init__(self,
+ num_embeddings: int,
+ embedding_dim: int,
+ init_std: float,
+ cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Truncated LeCun normal init
+ self.embedding_weight = nn.Parameter(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.embedding(input, self.embedding_weight.to(self.cast_to))
+
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings, base, device=None):
+ super().__init__()
+
+ # RoPE
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
+ t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
+ freqs = torch.outer(t, inv_freq)
+
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
+ self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+ def forward(self):
+ return self.cos_cached, self.sin_cached
+
+
+class Attention(nn.Module):
+ def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.head_dim = head_dim
+ self.output_size = head_dim * num_heads
+ self.num_heads = num_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.causal = causal
+
+ self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
+ self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len, _ = hidden_states.shape
+
+ # hidden_states: [bs, seq_len, num_heads, head_dim]
+ qkv = self.qkv_proj(hidden_states)
+
+ # Split head
+ qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+ query = qkv[:, :, :self.num_heads]
+ key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
+ value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
+
+ # RoPE
+ if cos_sin is not None:
+ cos, sin = cos_sin
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
+
+ # flash attn
+ attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
+ if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
+ attn_output = attn_output[0]
+
+ attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
+ return self.o_proj(attn_output)
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, hidden_size: int, expansion: float):
+ super().__init__()
+ inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
+
+ self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
+ self.down_proj = CastedLinear(inter, hidden_size, bias=False)
+
+ def forward(self, x):
+ gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
+ return self.down_proj(F.silu(gate) * up)
+
+
+def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ variance = hidden_states.square().mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
+ return hidden_states.to(input_dtype)
diff --git a/srm/models/losses.py b/srm/models/losses.py
new file mode 100644
index 0000000..b3118e7
--- /dev/null
+++ b/srm/models/losses.py
@@ -0,0 +1,101 @@
+from typing import Any, Tuple, Dict, Sequence, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+IGNORE_LABEL_ID = -100
+
+
+def s(x, epsilon=1e-30):
+ return torch.where(
+ x<0,
+ 1/(1-x+ epsilon),
+ x + 1
+ )
+
+
+def log_stablemax(x, dim=-1):
+ s_x = s(x)
+ return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
+
+
+def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
+ logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
+
+ valid_mask = labels != ignore_index
+ transformed_labels = torch.where(valid_mask, labels, 0)
+ prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
+
+ return -torch.where(valid_mask, prediction_logprobs, 0)
+
+
+def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
+ # Cast logits to f32
+ # Flatten logits
+ return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
+
+
+class ACTLossHead(nn.Module):
+ def __init__(self, model: nn.Module, loss_type: str):
+ super().__init__()
+ self.model = model
+ self.loss_fn = globals()[loss_type]
+
+ def initial_carry(self, *args, **kwargs):
+ return self.model.initial_carry(*args, **kwargs) # type: ignore
+
+ def forward(
+ self,
+ return_keys: Sequence[str],
+ # Model args
+ **model_kwargs,
+ ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
+ # Model logits
+ # B x SeqLen x D
+ new_carry, outputs = self.model(**model_kwargs)
+ labels = new_carry.current_data["labels"]
+
+ # Correctness
+ with torch.no_grad():
+ mask = labels != IGNORE_LABEL_ID
+ loss_counts = mask.sum(-1)
+ loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
+
+ is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
+ seq_is_correct = is_correct.sum(-1) == loss_counts
+
+ # Metrics (halted)
+ valid_metrics = new_carry.halted & (loss_counts > 0)
+ metrics = {
+ "count": valid_metrics.sum(),
+
+ "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
+ "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
+
+ "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
+ "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
+ }
+
+ # Losses
+ # FIXME: Assuming the batch is always full
+ lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
+ q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
+
+ metrics.update({
+ "lm_loss": lm_loss.detach(),
+ "q_halt_loss": q_halt_loss.detach(),
+ })
+
+ # Q continue (bootstrapping target loss)
+ q_continue_loss = 0
+ if "target_q_continue" in outputs:
+ q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
+
+ metrics["q_continue_loss"] = q_continue_loss.detach()
+
+ # Filter outputs for return
+ detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
+
+ return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
diff --git a/srm/models/sparse_embedding.py b/srm/models/sparse_embedding.py
new file mode 100644
index 0000000..c701524
--- /dev/null
+++ b/srm/models/sparse_embedding.py
@@ -0,0 +1,132 @@
+from typing import Union
+
+import torch
+from torch import nn
+import torch.distributed as dist
+from torch.optim.optimizer import Optimizer, ParamsT
+
+from models.common import trunc_normal_init_
+
+
+class CastedSparseEmbedding(nn.Module):
+ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
+ super().__init__()
+ self.cast_to = cast_to
+
+ # Real Weights
+ # Truncated LeCun normal init
+ self.weights = nn.Buffer(
+ trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
+ )
+
+ # Local weights and IDs
+ # Local embeddings, with gradient, not persistent
+ self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
+ # Local embedding IDs, not persistent
+ self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ if not self.training:
+ # Test mode, no gradient
+ return self.weights[inputs].to(self.cast_to)
+
+ # Training mode, fill puzzle embedding from weights
+ with torch.no_grad():
+ self.local_weights.copy_(self.weights[inputs])
+ self.local_ids.copy_(inputs)
+
+ return self.local_weights.to(self.cast_to)
+
+
+class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
+ def __init__(
+ self,
+ params: ParamsT,
+
+ world_size: int,
+ lr: Union[float, torch.Tensor] = 1e-3,
+ weight_decay: float = 1e-2,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= weight_decay:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+ defaults = dict(
+ lr=lr,
+ weight_decay=weight_decay,
+ world_size=world_size
+ )
+ super().__init__(params, defaults)
+
+ @torch.no_grad
+ def step(self, closure=None): # type: ignore
+ for group in self.param_groups:
+ # Find the sparse embedding weights
+ local_weights_grad = None
+ local_ids = None
+ weights = None
+
+ assert len(group["params"]) == 3
+ for p in group["params"]:
+ if p.requires_grad:
+ local_weights_grad = p.grad
+ elif p.ndim == 1:
+ local_ids = p
+ elif p.ndim == 2:
+ weights = p
+ else:
+ assert False
+
+ assert local_weights_grad is not None
+ assert local_ids is not None
+ assert weights is not None
+
+ # Apply SignSGD
+ # Adam ≈ SignSGD if gradient is very sparse
+ _sparse_emb_signsgd_dist(
+ local_weights_grad,
+ local_ids,
+ weights,
+
+ lr=group["lr"],
+ weight_decay=group["weight_decay"],
+ world_size=group["world_size"]
+ )
+
+
+def _sparse_emb_signsgd_dist(
+ local_weights_grad: torch.Tensor,
+ local_ids: torch.Tensor,
+ weights: torch.Tensor,
+
+ lr: float,
+ weight_decay: float,
+ world_size: int
+) -> None:
+ N, D = local_weights_grad.shape
+
+ # All-gather
+ all_weights_grad = local_weights_grad
+ all_ids = local_ids
+
+ if world_size > 1:
+ all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
+ all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
+
+ dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
+ dist.all_gather_into_tensor(all_ids, local_ids)
+
+ # Unique
+ grad_ids, inv = all_ids.unique(return_inverse=True)
+
+ grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
+ grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
+
+ # SignSGD with decoupled weight decay
+ p = weights[grad_ids]
+
+ p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
+
+ # Write updated slices back
+ weights[grad_ids] = p
diff --git a/srm/models/srm/__init__.py b/srm/models/srm/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/srm/models/srm/__init__.py
diff --git a/srm/models/srm/hrm_orth_v1.py b/srm/models/srm/hrm_orth_v1.py
new file mode 100644
index 0000000..65656df
--- /dev/null
+++ b/srm/models/srm/hrm_orth_v1.py
@@ -0,0 +1,376 @@
+"""HRM-Orth v1 — orthogonal patch of HRM per codex round 2 recommendation.
+
+CORE IDEA (codex Q6 pivot, after pure-orthogonal retract Q1):
+Keep HRM's H_level/L_level/ACT structure, just patch the inner Block:
+ - Attention → cosine-normalized attention (≈ Lipschitz-bounded)
+ - SwiGLU MLP → CayleyOrth linear + MaxMin + CayleyOrth linear
+ - rms_norm + add → weighted residual: h_new = (1-σ(w)) · h + σ(w) · f(h)
+ - "Weak orthogonality": diag(s) scaling with most s≈1, some s∈[0.90, 0.97] for compression
+
+Per codex Q5 decomp: target +5~+7pp over SRM v1 (0.39 → 0.43-0.46).
+Per codex Q3: Cayley used (we have it from srm_aol_v1); Householder would be faster but more impl.
+"""
+from typing import Tuple, List, Dict, Optional
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from pydantic import BaseModel
+
+from models.common import trunc_normal_init_
+from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
+from models.sparse_embedding import CastedSparseEmbedding
+from models.srm.srm_aol_v1 import CayleyOrthogonal
+
+
+def maxmin(x: torch.Tensor, group: int = 2) -> torch.Tensor:
+ """1-Lipschitz norm-preserving activation (Anil et al. 2019 GroupSort).
+
+ Pairs adjacent dims; outputs (min, max) per pair. Permutation a.e. → ||∇|| = 1.
+ Strictly better than ReLU under norm constraints (no rank-kill).
+ """
+ *prefix, d = x.shape
+ if d % group != 0:
+ pad = group - (d % group)
+ x = F.pad(x, (0, pad))
+ d = d + pad
+ xg = x.reshape(*prefix, d // group, group)
+ sorted_vals, _ = xg.sort(dim=-1)
+ return sorted_vals.reshape(*prefix, d)
+
+
+def cosine_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ tau: float = 8.0) -> torch.Tensor:
+ """Cosine-normalized softmax attention. Approximately Lipschitz-bounded
+ (exact bound depends on tau and value norms — see LipsFormer Qi 2023)."""
+ q = F.normalize(q, dim=-1)
+ k = F.normalize(k, dim=-1)
+ attn = (q @ k.transpose(-2, -1)) * tau
+ attn = attn.softmax(dim=-1)
+ return attn @ v
+
+
+class OrthLinear(nn.Module):
+ """Orthogonal linear layer via Cayley. Allows optional row-scaling diag(s)
+ where s_i ∈ [s_min, 1] to introduce 'weak orthogonality' (codex Q1 fix).
+
+ If s_min < 1, the operator is contractive in some directions:
+ Lip = max(s) ≤ 1, det = prod(s) ≤ 1 (weak contraction in compressing modes)
+ """
+ def __init__(self, dim: int, s_min: float = 0.85, learn_scale: bool = True,
+ init_std_scale: float = 5.0):
+ super().__init__()
+ self.Q = CayleyOrthogonal(dim)
+ # Bump init A by init_std_scale to push Cayley away from identity
+ with torch.no_grad():
+ self.Q.A.mul_(init_std_scale)
+ self.s_min = s_min
+ # diag scale: sigmoid -> [s_min, 1]
+ if learn_scale and s_min < 1.0:
+ self.log_s_raw = nn.Parameter(torch.zeros(dim)) # init sigmoid(0)=0.5 → scale=(s_min+1)/2
+ else:
+ self.register_buffer("log_s_raw", torch.zeros(dim))
+ self.learn_scale = learn_scale
+
+ def scale_diag(self) -> torch.Tensor:
+ if self.s_min >= 1.0 or not self.learn_scale:
+ return torch.ones_like(self.log_s_raw)
+ # Affine map sigmoid → [s_min, 1]
+ return self.s_min + (1.0 - self.s_min) * torch.sigmoid(self.log_s_raw)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ Q = self.Q() # (d, d) orthogonal
+ s = self.scale_diag().to(Q.dtype) # (d,) in [s_min, 1]
+ Qs = Q * s.unsqueeze(0) # rescale columns
+ return F.linear(x, Qs)
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1InnerCarry:
+ z_H: torch.Tensor
+ z_L: torch.Tensor
+
+
+@dataclass
+class HierarchicalReasoningModel_ACTV1Carry:
+ inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
+
+ steps: torch.Tensor
+ halted: torch.Tensor
+
+ current_data: Dict[str, torch.Tensor]
+
+
+class HierarchicalReasoningModel_ACTV1Config(BaseModel):
+ batch_size: int
+ seq_len: int
+ puzzle_emb_ndim: int = 0
+ num_puzzle_identifiers: int
+ vocab_size: int
+
+ H_cycles: int
+ L_cycles: int
+
+ H_layers: int
+ L_layers: int
+
+ # Transformer config
+ hidden_size: int
+ expansion: float
+ num_heads: int
+ pos_encodings: str
+
+ rms_norm_eps: float = 1e-5
+ rope_theta: float = 10000.0
+
+ # Halting Q-learning config
+ halt_max_steps: int
+ halt_exploration_prob: float
+
+ forward_dtype: str = "bfloat16"
+
+
+class HierarchicalReasoningModel_ACTV1Block(nn.Module):
+ """Orthogonal-patched HRM Block.
+
+ Replaces (attn + SwiGLU + rms_norm) with (cosine attn + Orth-MLP + weighted residual).
+ The original class name preserved so the ReasoningModule wrapper is unchanged.
+ """
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+ d = config.hidden_size
+ s_min = getattr(config, "orth_s_min", 0.85) # v2: 0.95 → 0.85 for real contraction
+ cosine_tau = getattr(config, "cosine_attn_tau", 1.0) # v2: 8 → 1 for diverse softmax at init
+ init_std_scale = getattr(config, "orth_init_std_scale", 5.0)
+
+ # Lipschitz-bounded cosine attention: orthogonal Q/K/V/O projections
+ self.q_proj = OrthLinear(d, s_min=1.0, learn_scale=False, init_std_scale=init_std_scale)
+ self.k_proj = OrthLinear(d, s_min=1.0, learn_scale=False, init_std_scale=init_std_scale)
+ self.v_proj = OrthLinear(d, s_min=s_min, learn_scale=True, init_std_scale=init_std_scale)
+ self.o_proj = OrthLinear(d, s_min=s_min, learn_scale=True, init_std_scale=init_std_scale)
+ self.cosine_tau = cosine_tau
+
+ # Orth-MLP: OrthLinear -> MaxMin -> OrthLinear (no expansion; uses original d)
+ self.mlp_in = OrthLinear(d, s_min=s_min, learn_scale=True, init_std_scale=init_std_scale)
+ self.mlp_out = OrthLinear(d, s_min=s_min, learn_scale=True, init_std_scale=init_std_scale)
+
+ # Weighted residual gates — v2: init logit=+2 → sigmoid=0.88 so block dominates
+ self.w_attn_logit = nn.Parameter(torch.full((), 2.0))
+ self.w_mlp_logit = nn.Parameter(torch.full((), 2.0))
+
+ def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
+ # Cosine attention
+ q = self.q_proj(hidden_states)
+ k = self.k_proj(hidden_states)
+ v = self.v_proj(hidden_states)
+ attn_out = self.o_proj(cosine_attention(q, k, v, tau=self.cosine_tau))
+ w_attn = torch.sigmoid(self.w_attn_logit)
+ hidden_states = (1.0 - w_attn) * hidden_states + w_attn * attn_out
+
+ # Orth-MLP with MaxMin
+ mlp_out = self.mlp_out(maxmin(self.mlp_in(hidden_states), group=2))
+ w_mlp = torch.sigmoid(self.w_mlp_logit)
+ hidden_states = (1.0 - w_mlp) * hidden_states + w_mlp * mlp_out
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
+ def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
+ super().__init__()
+
+ self.layers = torch.nn.ModuleList(layers)
+
+ def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
+ # Input injection (add)
+ hidden_states = hidden_states + input_injection
+ # Layers
+ for layer in self.layers:
+ hidden_states = layer(hidden_states=hidden_states, **kwargs)
+
+ return hidden_states
+
+
+class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
+ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
+ super().__init__()
+ self.config = config
+ self.forward_dtype = getattr(torch, self.config.forward_dtype)
+
+ # I/O
+ self.embed_scale = math.sqrt(self.config.hidden_size)
+ embed_init_std = 1.0 / self.embed_scale
+
+ self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+ self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+
+ self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
+ if self.config.puzzle_emb_ndim > 0:
+ # Zero init puzzle embeddings
+ self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
+ batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
+
+ # LM Blocks
+ if self.config.pos_encodings == "rope":
+ self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
+ max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
+ base=self.config.rope_theta)
+ elif self.config.pos_encodings == "learned":
+ self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
+ else:
+ raise NotImplementedError()
+
+ # Reasoning Layers
+ self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
+ self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
+
+ # Initial states
+ self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+ self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+
+ # Q head special init
+ # Init Q to (almost) zero for faster learning during bootstrapping
+ with torch.no_grad():
+ self.q_head.weight.zero_()
+ self.q_head.bias.fill_(-5) # type: ignore
+
+ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
+ # Token embedding
+ embedding = self.embed_tokens(input.to(torch.int32))
+
+ # Puzzle embeddings
+ if self.config.puzzle_emb_ndim > 0:
+ puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
+
+ pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+ if pad_count > 0:
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
+
+ embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
+
+ # Position embeddings
+ if self.config.pos_encodings == "learned":
+ # scale by 1/sqrt(2) to maintain forward variance
+ embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
+
+ # Scale
+ return self.embed_scale * embedding
+
+ def empty_carry(self, batch_size: int):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
+ )
+
+ def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
+ return HierarchicalReasoningModel_ACTV1InnerCarry(
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ seq_info = dict(
+ cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
+ )
+
+ # Input encoding
+ input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ # Forward iterations
+ with torch.no_grad():
+ z_H, z_L = carry.z_H, carry.z_L
+
+ for _H_step in range(self.config.H_cycles):
+ for _L_step in range(self.config.L_cycles):
+ if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+
+ if not (_H_step == self.config.H_cycles - 1):
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ assert not z_H.requires_grad and not z_L.requires_grad
+
+ # 1-step grad
+ z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
+ z_H = self.H_level(z_H, z_L, **seq_info)
+
+ # LM Outputs
+ new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+
+ # Q head
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
+
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+class HierarchicalReasoningModel_ACTV1(nn.Module):
+ """ACT wrapper."""
+
+ def __init__(self, config_dict: dict):
+ super().__init__()
+ self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
+ self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
+
+ @property
+ def puzzle_emb(self):
+ return self.inner.puzzle_emb
+
+ def initial_carry(self, batch: Dict[str, torch.Tensor]):
+ batch_size = batch["inputs"].shape[0]
+
+ return HierarchicalReasoningModel_ACTV1Carry(
+ inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
+
+ steps=torch.zeros((batch_size, ), dtype=torch.int32),
+ halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
+
+ current_data={k: torch.empty_like(v) for k, v in batch.items()}
+ )
+
+ def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+ # Update data, carry (removing halted sequences)
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+
+ new_steps = torch.where(carry.halted, 0, carry.steps)
+
+ new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
+
+ # Forward inner model
+ new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
+
+ outputs = {
+ "logits": logits,
+ "q_halt_logits": q_halt_logits,
+ "q_continue_logits": q_continue_logits
+ }
+
+ with torch.no_grad():
+ # Step
+ new_steps = new_steps + 1
+ is_last_step = new_steps >= self.config.halt_max_steps
+
+ halted = is_last_step
+
+ # if training, and ACT is enabled
+ if self.training and (self.config.halt_max_steps > 1):
+ # Halt signal
+ # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
+ halted = halted | (q_halt_logits > q_continue_logits)
+
+ # Exploration
+ min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+
+ halted = halted & (new_steps >= min_halt_steps)
+
+ # Compute target Q
+ # NOTE: No replay buffer and target networks for computing target Q-value.
+ # As batch_size is large, there're many parallel envs.
+ # Similar concept as PQN https://arxiv.org/abs/2407.04811
+ next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
+
+ outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
+
+ return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
diff --git a/srm/models/srm/srm_aol_v1.py b/srm/models/srm/srm_aol_v1.py
new file mode 100644
index 0000000..c4e2719
--- /dev/null
+++ b/srm/models/srm/srm_aol_v1.py
@@ -0,0 +1,494 @@
+"""SRM-Joint-AOL v1 — Stable Recursive Model (forked from hrm_act_v1.py).
+
+Replaces HRM's separate H_level / L_level transformer stacks with ONE joint
+operator T on state z = (h, l) that is provably contractive under weighted
+P-norm ||z||²_P = ||h||² + η||l||² with Lipschitz constant ≤ κ ∈ (0.85, 0.95).
+
+Key replacement vs HRM:
+- HierarchicalReasoningModel_ACTV1Block (attn + SwiGLU)
+ → StableRecursionModel_ACTV1Block (joint SRM step on (h, l))
+- ReasoningModule wraps n_iters joint updates instead of separate H/L cycles.
+
+Lipschitz analysis (per step, in P-norm):
+ Lip_P(T) ≤ (1-α) + α · κ < 1
+ ⇒ joint top-1 Lyapunov per micro-step: λ_1 ≤ log((1-α) + α·κ) < 0
+
+ARCHITECTURE (one joint step):
+ z = concat(h + b_in_h(x), √η · (l + b_in_l(x))) # join with input bias
+ ψ = AOL_Block(z) # Lip_P(ψ) ≤ 1
+ ψ_h, ψ_l_scaled = split(ψ); ψ_l = ψ_l_scaled / √η
+ Az_h = a_HH·ψ_h + a_HL·(U_HL·ψ_l) # gain row sum ≤ κ
+ Az_l = a_LH·(U_LH·ψ_h) + a_LL·ψ_l
+ h_new = (1-α)·h + α·Az_h + b_out_h(x)
+ l_new = (1-α)·l + α·Az_l + b_out_l(x)
+
+REUSED FROM HRM:
+- ACT framework (q_head, halt logic)
+- CastedEmbedding/CastedLinear (bf16-safe linears)
+- CastedSparseEmbedding (puzzle_emb)
+- 1-step grad / DEQ-style truncation
+"""
+from typing import Tuple, Dict
+from dataclasses import dataclass
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from pydantic import BaseModel
+
+from models.common import trunc_normal_init_
+from models.layers import CastedEmbedding, CastedLinear
+from models.sparse_embedding import CastedSparseEmbedding
+
+
+# =============================================================================
+# Approximately 1-Lipschitz primitives
+# Normalization (AOL) and orthogonalization (Cayley) are computed in float32,
+# then cast to forward dtype (bf16 by default). The bound is *exact in fp32*
+# but only approximate after cast — bf16 rounding introduces a small error
+# that accumulates over n_aol_layers matmuls. Empirically the margin to the
+# theoretical κ-bound is large (~5×), so this is fine in practice, but the
+# guarantee is not strict. For applications where strictness matters, run
+# the bounded operators in float32.
+# =============================================================================
+
+class AOLLinear(nn.Module):
+ """≤1-Lipschitz linear layer via AOL (Prach & Lampert 2022) rescaling.
+
+ Given W ∈ R^(out × in), let A = W^T W (symmetric PSD).
+ Define D_jj = 1 / √(Σ_i |A_ij| + eps); set W̃ = W · diag(D).
+ Then ||W̃ x||_2 ≤ ||x||_2 in float32 (Prach & Lampert Theorem 1).
+ Bound is approximate (not exact) under bf16 due to rounding in W·diag(D)
+ and the subsequent matmul. Bias is unconstrained (shift only, doesn't
+ affect Lipschitz w.r.t. input).
+ """
+ def __init__(self, in_dim: int, out_dim: int, bias: bool = True,
+ cast_to: torch.dtype = torch.bfloat16, eps: float = 1e-6):
+ super().__init__()
+ std = 1.0 / math.sqrt(in_dim)
+ self.W = nn.Parameter(torch.randn(out_dim, in_dim) * std)
+ self.b = nn.Parameter(torch.zeros(out_dim)) if bias else None
+ self.cast_to = cast_to
+ self.eps = eps
+
+ def normalized_weight(self) -> torch.Tensor:
+ W32 = self.W.float()
+ WTW = W32.t() @ W32
+ col_abs_sum = WTW.abs().sum(dim=0)
+ scale = torch.rsqrt(col_abs_sum + self.eps)
+ return (W32 * scale.unsqueeze(0)).to(self.cast_to)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ W = self.normalized_weight()
+ out = F.linear(x, W)
+ if self.b is not None:
+ out = out + self.b.to(out.dtype)
+ return out
+
+
+class AOLBlock(nn.Module):
+ """Stack of AOLLinear with 1-Lipschitz activation (ReLU) between layers.
+
+ Composition of 1-Lipschitz maps is 1-Lipschitz. SiLU/GELU NOT allowed
+ (max derivative > 1 would break the bound).
+ """
+ def __init__(self, dim: int, n_layers: int = 2, cast_to: torch.dtype = torch.bfloat16):
+ super().__init__()
+ assert n_layers >= 1
+ self.layers = nn.ModuleList([
+ AOLLinear(dim, dim, bias=True, cast_to=cast_to) for _ in range(n_layers)
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i < len(self.layers) - 1:
+ x = F.relu(x)
+ return x
+
+
+class CayleyOrthogonal(nn.Module):
+ """Approximately orthogonal Q ∈ R^(d × d) via Cayley transform.
+
+ Q = (I + S)^{-1}(I - S) where S = (A - A^T)/2 is skew-symmetric.
+ Since (I+S) and (I-S) commute (both polynomials in S), the form is also
+ Q = (I - S)(I + S)^{-1}. Q^T Q = I exactly in float32 — approximate
+ after cast to bf16. Solve done in float32 for numerical stability.
+ NOTE: torch.linalg.solve may not be fullgraph-compile friendly. Test
+ before enabling torch.compile / FSDP.
+ """
+ def __init__(self, dim: int, cast_to: torch.dtype = torch.bfloat16):
+ super().__init__()
+ self.A = nn.Parameter(torch.randn(dim, dim) * (1.0 / math.sqrt(dim)))
+ self.dim = dim
+ self.cast_to = cast_to
+ self.register_buffer("I", torch.eye(dim), persistent=False)
+
+ def forward(self) -> torch.Tensor:
+ A32 = self.A.float()
+ S = 0.5 * (A32 - A32.t())
+ I = self.I.float()
+ Q = torch.linalg.solve(I + S, I - S)
+ return Q.to(self.cast_to)
+
+
+class BlockGain(nn.Module):
+ """Block gain matrix A with row sums ≤ κ under weighted P-norm.
+
+ H row entries (P-normalized): [a_HH, √η · a_HL], sum = κ
+ L row entries (P-normalized): [(1/√η) · a_LH, a_LL], sum = κ
+ Parameterized via softmax × κ ⇒ exact equality (saturation).
+ """
+ def __init__(self, kappa: float = 0.9, eta: float = 1.0, init_diag: float = 3.0):
+ super().__init__()
+ self.kappa = kappa
+ self.eta = eta
+ # init_diag=3.0 → softmax([3, 0]) ≈ [0.953, 0.047] ⇒ ~5% cross-coupling at start
+ # (init_diag=1.0 was too weak — gave 27% cross-coupling; init_diag=3.0 truly minimal)
+ self.logits_H = nn.Parameter(torch.tensor([init_diag, 0.0]))
+ self.logits_L = nn.Parameter(torch.tensor([0.0, init_diag]))
+
+ def forward(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ sqrt_eta = math.sqrt(self.eta)
+ gH = self.kappa * F.softmax(self.logits_H.float(), dim=0)
+ a_HH, a_HL_scaled = gH[0], gH[1]
+ a_HL = a_HL_scaled / sqrt_eta
+ gL = self.kappa * F.softmax(self.logits_L.float(), dim=0)
+ a_LH_scaled, a_LL = gL[0], gL[1]
+ a_LH = a_LH_scaled * sqrt_eta
+ return a_HH, a_HL, a_LH, a_LL
+
+
+class AOLTokenMixer(nn.Module):
+ """1-Lipschitz mixing across token (seq) and channel dims via AOL.
+
+ Pipeline for x of shape (B, seq, dim):
+ 1) Channel mix (AOL across `dim`)
+ 2) ReLU
+ 3) Token mix (AOL across `seq`, applied after transpose)
+ 4) ReLU
+ Composition of 1-Lipschitz maps ⇒ Lip ≤ 1.
+ """
+ def __init__(self, seq_len: int, dim: int, n_layers: int = 1,
+ cast_to: torch.dtype = torch.bfloat16):
+ super().__init__()
+ self.channel_mix = AOLBlock(dim=dim, n_layers=n_layers, cast_to=cast_to)
+ self.token_mix = AOLBlock(dim=seq_len, n_layers=n_layers, cast_to=cast_to)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ y = self.channel_mix(x)
+ y = F.relu(y)
+ y = y.transpose(-2, -1) # (B, dim, seq)
+ y = self.token_mix(y)
+ y = y.transpose(-2, -1) # (B, seq, dim)
+ return y
+
+
+# =============================================================================
+# Carry types and config
+# =============================================================================
+
+@dataclass
+class StableRecursionModel_ACTV1InnerCarry:
+ z_H: torch.Tensor
+ z_L: torch.Tensor
+
+
+@dataclass
+class StableRecursionModel_ACTV1Carry:
+ inner_carry: StableRecursionModel_ACTV1InnerCarry
+ steps: torch.Tensor
+ halted: torch.Tensor
+ current_data: Dict[str, torch.Tensor]
+
+
+class StableRecursionModel_ACTV1Config(BaseModel):
+ batch_size: int
+ seq_len: int
+ puzzle_emb_ndim: int = 0
+ num_puzzle_identifiers: int
+ vocab_size: int
+
+ # SRM-specific
+ n_iters: int = 12 # joint micro-steps per ACT step
+ n_aol_layers: int = 2 # depth of ψ AOL block
+ kappa: float = 0.9
+ eta: float = 1.0
+ alpha: float = 1.0
+
+ # Shared with HRM
+ hidden_size: int
+ halt_max_steps: int
+ halt_exploration_prob: float = 0.1
+ forward_dtype: str = "bfloat16"
+
+
+# =============================================================================
+# SRM joint step (replaces HRM's H/L transformer blocks)
+# =============================================================================
+
+class StableRecursionModel_ACTV1Block(nn.Module):
+ """One SRM joint step on (h, l). Per-step Lip_P ≤ (1-α) + α·κ < 1."""
+ def __init__(self, config: StableRecursionModel_ACTV1Config, seq_full: int) -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.kappa = config.kappa
+ self.eta = config.eta
+ self.alpha = config.alpha
+ cast = getattr(torch, config.forward_dtype)
+
+ joint_dim = 2 * config.hidden_size
+ self.psi = AOLTokenMixer(seq_len=seq_full, dim=joint_dim,
+ n_layers=config.n_aol_layers, cast_to=cast)
+ self.gain = BlockGain(kappa=config.kappa, eta=config.eta)
+ self.U_HL = CayleyOrthogonal(config.hidden_size, cast_to=cast)
+ self.U_LH = CayleyOrthogonal(config.hidden_size, cast_to=cast)
+
+ # Input biases — unconstrained (only affect Lip w.r.t. x, not w.r.t. z;
+ # x is fixed across recursion so doesn't affect Lyapunov)
+ self.bias_in_h = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
+ self.bias_in_l = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
+ self.bias_out_h = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
+ self.bias_out_l = CastedLinear(config.hidden_size, config.hidden_size, bias=True)
+
+ def forward(self, h: torch.Tensor, l: torch.Tensor, input_emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ sqrt_eta = math.sqrt(self.eta)
+
+ # 1. Join with input bias
+ h_in = h + self.bias_in_h(input_emb)
+ l_in = l + self.bias_in_l(input_emb)
+ z = torch.cat([h_in, sqrt_eta * l_in], dim=-1) # (B, seq, 2h)
+
+ # 2. 1-Lipschitz feature map ψ
+ psi = self.psi(z)
+ psi_h, psi_l_scaled = psi.chunk(2, dim=-1)
+ psi_l = psi_l_scaled / sqrt_eta
+
+ # 3. Block-gain matrix A (κ-bounded row sums)
+ a_HH, a_HL, a_LH, a_LL = self.gain()
+ U_HL = self.U_HL()
+ U_LH = self.U_LH()
+ psi_l_mix = F.linear(psi_l, U_HL) # ≡ psi_l @ U_HL.T
+ psi_h_mix = F.linear(psi_h, U_LH)
+ Az_h = a_HH * psi_h + a_HL * psi_l_mix
+ Az_l = a_LH * psi_h_mix + a_LL * psi_l
+
+ # 4. Damped update + output bias
+ h_new = (1.0 - self.alpha) * h + self.alpha * Az_h + self.bias_out_h(input_emb)
+ l_new = (1.0 - self.alpha) * l + self.alpha * Az_l + self.bias_out_l(input_emb)
+ return h_new, l_new
+
+
+# =============================================================================
+# Inner model + ACT wrapper (matches HRM_ACTV1 interface)
+# =============================================================================
+
+class StableRecursionModel_ACTV1_Inner(nn.Module):
+ def __init__(self, config: StableRecursionModel_ACTV1Config) -> None:
+ super().__init__()
+ self.config = config
+ self.forward_dtype = getattr(torch, config.forward_dtype)
+
+ self.embed_scale = math.sqrt(config.hidden_size)
+ embed_init_std = 1.0 / self.embed_scale
+
+ self.embed_tokens = CastedEmbedding(config.vocab_size, config.hidden_size,
+ init_std=embed_init_std, cast_to=self.forward_dtype)
+ self.lm_head = CastedLinear(config.hidden_size, config.vocab_size, bias=False)
+ self.q_head = CastedLinear(config.hidden_size, 2, bias=True)
+ with torch.no_grad():
+ self.q_head.weight.zero_()
+ self.q_head.bias.fill_(-5)
+
+ self.puzzle_emb_len = -(config.puzzle_emb_ndim // -config.hidden_size)
+ if config.puzzle_emb_ndim > 0:
+ self.puzzle_emb = CastedSparseEmbedding(
+ config.num_puzzle_identifiers, config.puzzle_emb_ndim,
+ batch_size=config.batch_size, init_std=0, cast_to=self.forward_dtype,
+ )
+ seq_full = config.seq_len + self.puzzle_emb_len
+
+ # Single tied SRM block used n_iters times per ACT step
+ self.srm_block = StableRecursionModel_ACTV1Block(config, seq_full=seq_full)
+
+ self.H_init = nn.Buffer(
+ trunc_normal_init_(torch.empty(config.hidden_size, dtype=self.forward_dtype), std=1),
+ persistent=True,
+ )
+ self.L_init = nn.Buffer(
+ trunc_normal_init_(torch.empty(config.hidden_size, dtype=self.forward_dtype), std=1),
+ persistent=True,
+ )
+
+ def _input_embeddings(self, input_ids: torch.Tensor, puzzle_ids: torch.Tensor) -> torch.Tensor:
+ emb = self.embed_tokens(input_ids.to(torch.int32))
+ if self.config.puzzle_emb_ndim > 0:
+ puzzle_embedding = self.puzzle_emb(puzzle_ids)
+ pad = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
+ if pad > 0:
+ puzzle_embedding = F.pad(puzzle_embedding, (0, pad))
+ puzzle_embedding = puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size)
+ emb = torch.cat((puzzle_embedding, emb), dim=-2)
+ return self.embed_scale * emb
+
+ def empty_carry(self, batch_size: int) -> StableRecursionModel_ACTV1InnerCarry:
+ seq_full = self.config.seq_len + self.puzzle_emb_len
+ return StableRecursionModel_ACTV1InnerCarry(
+ z_H=torch.empty(batch_size, seq_full, self.config.hidden_size, dtype=self.forward_dtype),
+ z_L=torch.empty(batch_size, seq_full, self.config.hidden_size, dtype=self.forward_dtype),
+ )
+
+ def reset_carry(self, reset_flag: torch.Tensor,
+ carry: StableRecursionModel_ACTV1InnerCarry) -> StableRecursionModel_ACTV1InnerCarry:
+ return StableRecursionModel_ACTV1InnerCarry(
+ z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
+ z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
+ )
+
+ def forward(self, carry: StableRecursionModel_ACTV1InnerCarry,
+ batch: Dict[str, torch.Tensor]) -> Tuple[StableRecursionModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ input_emb = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
+
+ # n_iters - 1 no-grad iterations + 1 grad iteration (HRM-style DEQ truncation)
+ with torch.no_grad():
+ z_H, z_L = carry.z_H, carry.z_L
+ for _ in range(self.config.n_iters - 1):
+ z_H, z_L = self.srm_block(z_H, z_L, input_emb)
+ assert not z_H.requires_grad and not z_L.requires_grad
+
+ z_H, z_L = self.srm_block(z_H, z_L, input_emb)
+
+ new_carry = StableRecursionModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())
+ output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
+ q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
+ return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
+
+
+class StableRecursionModel_ACTV1(nn.Module):
+ """ACT wrapper — mirrors HierarchicalReasoningModel_ACTV1 1-to-1."""
+ def __init__(self, config_dict: dict):
+ super().__init__()
+ self.config = StableRecursionModel_ACTV1Config(**config_dict)
+ self.inner = StableRecursionModel_ACTV1_Inner(self.config)
+
+ @property
+ def puzzle_emb(self):
+ return self.inner.puzzle_emb
+
+ def initial_carry(self, batch: Dict[str, torch.Tensor]) -> StableRecursionModel_ACTV1Carry:
+ B = batch["inputs"].shape[0]
+ return StableRecursionModel_ACTV1Carry(
+ inner_carry=self.inner.empty_carry(B),
+ steps=torch.zeros((B,), dtype=torch.int32),
+ halted=torch.ones((B,), dtype=torch.bool),
+ current_data={k: torch.empty_like(v) for k, v in batch.items()},
+ )
+
+ def forward(self, carry: StableRecursionModel_ACTV1Carry,
+ batch: Dict[str, torch.Tensor]) -> Tuple[StableRecursionModel_ACTV1Carry, Dict[str, torch.Tensor]]:
+ new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
+ new_steps = torch.where(carry.halted, 0, carry.steps)
+ new_current_data = {
+ k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v)
+ for k, v in carry.current_data.items()
+ }
+
+ new_inner_carry, logits, (q_halt, q_continue) = self.inner(new_inner_carry, new_current_data)
+ outputs = {"logits": logits, "q_halt_logits": q_halt, "q_continue_logits": q_continue}
+
+ with torch.no_grad():
+ new_steps = new_steps + 1
+ is_last = new_steps >= self.config.halt_max_steps
+ halted = is_last
+
+ if self.training and self.config.halt_max_steps > 1:
+ halted = halted | (q_halt > q_continue)
+ min_halt = (torch.rand_like(q_halt) < self.config.halt_exploration_prob) * \
+ torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
+ halted = halted & (new_steps >= min_halt)
+
+ next_q_halt, next_q_continue = self.inner(new_inner_carry, new_current_data)[-1]
+ outputs["target_q_continue"] = torch.sigmoid(
+ torch.where(is_last, next_q_halt, torch.maximum(next_q_halt, next_q_continue))
+ )
+
+ return StableRecursionModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
+
+
+# =============================================================================
+# Empirical Lipschitz diagnostic
+# =============================================================================
+
+@torch.no_grad()
+def measure_lipschitz_constant(inner: StableRecursionModel_ACTV1_Inner,
+ sample_batch: Dict[str, torch.Tensor],
+ n_probes: int = 64, eps: float = 1e-3) -> Dict[str, float]:
+ """Estimate Lip_P of srm_block via random perturbations.
+
+ Returns ratio ||T(z+δ) - T(z)||_P / ||δ||_P. Should be ≤ (1-α) + α·κ.
+ """
+ cfg = inner.config
+ B = sample_batch["inputs"].shape[0]
+ seq_full = cfg.seq_len + inner.puzzle_emb_len
+ h = inner.H_init.unsqueeze(0).expand(B, seq_full, cfg.hidden_size).to(inner.forward_dtype).clone()
+ l = inner.L_init.unsqueeze(0).expand(B, seq_full, cfg.hidden_size).to(inner.forward_dtype).clone()
+ input_emb = inner._input_embeddings(sample_batch["inputs"], sample_batch["puzzle_identifiers"])
+ h_new, l_new = inner.srm_block(h, l, input_emb)
+
+ ratios = []
+ for _ in range(n_probes):
+ dh = torch.randn_like(h) * eps
+ dl = torch.randn_like(l) * eps
+ h_p, l_p = inner.srm_block(h + dh, l + dl, input_emb)
+ d_in_h = dh.float().flatten(1).pow(2).sum(1)
+ d_in_l = dl.float().flatten(1).pow(2).sum(1)
+ d_in_P = (d_in_h + cfg.eta * d_in_l).sqrt()
+ d_out_h = (h_p - h_new).float().flatten(1).pow(2).sum(1)
+ d_out_l = (l_p - l_new).float().flatten(1).pow(2).sum(1)
+ d_out_P = (d_out_h + cfg.eta * d_out_l).sqrt()
+ ratios.append((d_out_P / d_in_P.clamp_min(1e-12)).cpu())
+ R = torch.cat(ratios)
+ bound = (1 - cfg.alpha) + cfg.alpha * cfg.kappa
+ return {
+ "lip_emp_mean": float(R.mean()),
+ "lip_emp_max": float(R.max()),
+ "lip_emp_99p": float(R.quantile(0.99)),
+ "lip_theoretical_bound": float(bound),
+ "passes_bound": bool(R.max() <= bound * 1.05),
+ }
+
+
+if __name__ == "__main__":
+ cfg = dict(
+ batch_size=4, seq_len=81, vocab_size=11,
+ num_puzzle_identifiers=1, puzzle_emb_ndim=512,
+ hidden_size=256, n_iters=6, n_aol_layers=2,
+ kappa=0.9, eta=1.0, alpha=1.0,
+ halt_max_steps=4, halt_exploration_prob=0.1,
+ forward_dtype="bfloat16",
+ )
+ model = StableRecursionModel_ACTV1(cfg).cuda()
+ print(f"params={sum(p.numel() for p in model.parameters()):,}")
+
+ batch = {
+ "inputs": torch.randint(0, 11, (4, 81), dtype=torch.int32).cuda(),
+ "labels": torch.randint(0, 11, (4, 81), dtype=torch.int32).cuda(),
+ "puzzle_identifiers": torch.zeros(4, dtype=torch.int32).cuda(),
+ }
+ carry = model.initial_carry(batch)
+ carry.inner_carry.z_H = carry.inner_carry.z_H.cuda()
+ carry.inner_carry.z_L = carry.inner_carry.z_L.cuda()
+ carry.steps = carry.steps.cuda()
+ carry.halted = carry.halted.cuda()
+ for k in carry.current_data:
+ carry.current_data[k] = batch[k]
+
+ model.eval()
+ new_carry, out = model(carry, batch)
+ print(f"forward OK | logits={out['logits'].shape}")
+
+ lip = measure_lipschitz_constant(model.inner, batch, n_probes=32)
+ print(f"Lipschitz check: emp_max={lip['lip_emp_max']:.4f} bound={lip['lip_theoretical_bound']:.4f} "
+ f"passes={lip['passes_bound']} emp_mean={lip['lip_emp_mean']:.4f}")