diff options
Diffstat (limited to 'trm/models')
| -rw-r--r-- | trm/models/common.py | 32 | ||||
| -rw-r--r-- | trm/models/ema.py | 40 | ||||
| -rw-r--r-- | trm/models/layers.py | 169 | ||||
| -rw-r--r-- | trm/models/losses.py | 103 | ||||
| -rw-r--r-- | trm/models/recursive_reasoning/hrm.py | 294 | ||||
| -rw-r--r-- | trm/models/recursive_reasoning/transformers_baseline.py | 342 | ||||
| -rw-r--r-- | trm/models/recursive_reasoning/trm.py | 297 | ||||
| -rw-r--r-- | trm/models/recursive_reasoning/trm_hier6.py | 323 | ||||
| -rw-r--r-- | trm/models/recursive_reasoning/trm_singlez.py | 294 | ||||
| -rw-r--r-- | trm/models/sparse_embedding.py | 132 |
10 files changed, 2026 insertions, 0 deletions
diff --git a/trm/models/common.py b/trm/models/common.py new file mode 100644 index 0000000..1a04505 --- /dev/null +++ b/trm/models/common.py @@ -0,0 +1,32 @@ +import math + +import torch +from torch import nn + + +def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0): + # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor + # This function is a PyTorch version of jax truncated normal init (default init method in flax) + # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848 + # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199 + + with torch.no_grad(): + if std == 0: + tensor.zero_() + else: + sqrt2 = math.sqrt(2) + a = math.erf(lower / sqrt2) + b = math.erf(upper / sqrt2) + z = (b - a) / 2 + + c = (2 * math.pi) ** -0.5 + pdf_u = c * math.exp(-0.5 * lower ** 2) + pdf_l = c * math.exp(-0.5 * upper ** 2) + comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2) + + tensor.uniform_(a, b) + tensor.erfinv_() + tensor.mul_(sqrt2 * comp_std) + tensor.clip_(lower * comp_std, upper * comp_std) + + return tensor diff --git a/trm/models/ema.py b/trm/models/ema.py new file mode 100644 index 0000000..2e52933 --- /dev/null +++ b/trm/models/ema.py @@ -0,0 +1,40 @@ +import copy +import torch.nn as nn + +class EMAHelper(object): + def __init__(self, mu=0.999): + self.mu = mu + self.shadow = {} + + def register(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name] = param.data.clone() + + def update(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data + + def ema(self, module): + if isinstance(module, nn.DataParallel): + module = module.module + for name, param in module.named_parameters(): + if param.requires_grad: + param.data.copy_(self.shadow[name].data) + + def ema_copy(self, module): + module_copy = copy.deepcopy(module) + self.ema(module_copy) + return module_copy + + def state_dict(self): + return self.shadow + + def load_state_dict(self, state_dict): + self.shadow = state_dict + diff --git a/trm/models/layers.py b/trm/models/layers.py new file mode 100644 index 0000000..705bcaf --- /dev/null +++ b/trm/models/layers.py @@ -0,0 +1,169 @@ +from typing import Tuple +import einops +import torch +from torch import nn +import torch.nn.functional as F + +#try: +# from flash_attn_interface import flash_attn_func # type: ignore[import] +#except ImportError: +# # Fallback to FlashAttention 2 +# from flash_attn import flash_attn_func # type: ignore[import] +from torch.nn.functional import scaled_dot_product_attention + +from models.common import trunc_normal_init_ + + +CosSin = Tuple[torch.Tensor, torch.Tensor] + + +def _find_multiple(a, b): + return (-(a // -b)) * b + + +def rotate_half(x: torch.Tensor): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + # q, k: [bs, seq_len, num_heads, head_dim] + # cos, sin: [seq_len, head_dim] + orig_dtype = q.dtype + q = q.to(cos.dtype) + k = k.to(cos.dtype) + + q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2)) + k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2)) + + return q_embed.to(orig_dtype), k_embed.to(orig_dtype) + + +class CastedLinear(nn.Module): + def __init__(self, + in_features: int, + out_features: int, + bias: bool): + super().__init__() + # Truncated LeCun normal init + self.weight = nn.Parameter( + trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5)) + ) + self.bias = None + if bias: + # Zero init bias + self.bias = nn.Parameter(torch.zeros((out_features, ))) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None) + + +class CastedEmbedding(nn.Module): + def __init__(self, + num_embeddings: int, + embedding_dim: int, + init_std: float, + cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Truncated LeCun normal init + self.embedding_weight = nn.Parameter( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.embedding(input, self.embedding_weight.to(self.cast_to)) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings, base, device=None): + super().__init__() + + # RoPE + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) + freqs = torch.outer(t, inv_freq) + + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = nn.Buffer(emb.cos(), persistent=False) + self.sin_cached = nn.Buffer(emb.sin(), persistent=False) + + def forward(self): + return self.cos_cached, self.sin_cached + + +class Attention(nn.Module): + def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False): + super().__init__() + + self.hidden_size = hidden_size + self.head_dim = head_dim + self.output_size = head_dim * num_heads + self.num_heads = num_heads + self.num_key_value_heads = num_key_value_heads + self.causal = causal + + self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False) + self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False) + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = hidden_states.shape + + # hidden_states: [bs, seq_len, num_heads, head_dim] + qkv = self.qkv_proj(hidden_states) + + # Split head + qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + query = qkv[:, :, :self.num_heads] + key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads] + value = qkv[:, :, self.num_heads + self.num_key_value_heads:] + + # RoPE + if cos_sin is not None: + cos, sin = cos_sin + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + # flash attn + query, key, value = map(lambda t: einops.rearrange(t, 'B S H D -> B H S D'), (query, key, value)) # needed for scaled_dot_product_attention but not flash_attn_func + attn_output = scaled_dot_product_attention(query=query, key=key, value=value, is_causal=self.causal) + attn_output = einops.rearrange(attn_output, 'B H S D -> B S H D') + attn_output = attn_output.reshape(batch_size, seq_len, self.output_size) # type: ignore + return self.o_proj(attn_output) + +class LinearSwish(nn.Module): + def __init__(self, hidden_size: int, reverse=False): + super().__init__() + + self.linear = CastedLinear(hidden_size, hidden_size, bias=False) + self.reverse = reverse + + def forward(self, x): + if self.reverse: + return F.silu(self.linear(x)) + else: + return self.linear(F.silu(x)) + + +class SwiGLU(nn.Module): + def __init__(self, hidden_size: int, expansion: float): + super().__init__() + inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256) + + self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False) + self.down_proj = CastedLinear(inter, hidden_size, bias=False) + + def forward(self, x): + gate, up = self.gate_up_proj(x).chunk(2, dim=-1) + return self.down_proj(F.silu(gate) * up) + +def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + variance = hidden_states.square().mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) + return hidden_states.to(input_dtype) diff --git a/trm/models/losses.py b/trm/models/losses.py new file mode 100644 index 0000000..2597cd6 --- /dev/null +++ b/trm/models/losses.py @@ -0,0 +1,103 @@ +from typing import Any, Tuple, Dict, Sequence, Optional + +import torch +import torch.nn.functional as F +from torch import nn +import math + +IGNORE_LABEL_ID = -100 + + +def s(x, epsilon=1e-30): + return torch.where( + x<0, + 1/(1-x+ epsilon), + x + 1 + ) + + +def log_stablemax(x, dim=-1): + s_x = s(x) + return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True)) + + +def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None): + logprobs = log_stablemax(logits.to(torch.float64), dim=-1) + + if valid_mask is None: + valid_mask = (labels != ignore_index) + transformed_labels = torch.where(valid_mask, labels, 0) + prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1) + + return -torch.where(valid_mask, prediction_logprobs, 0) + + +def softmax_cross_entropy(logits, labels, ignore_index: int = -100): + # Cast logits to f32 + # Flatten logits + return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape) + + +class ACTLossHead(nn.Module): + def __init__(self, model: nn.Module, loss_type: str): + super().__init__() + self.model = model + self.loss_fn = globals()[loss_type] + + def initial_carry(self, *args, **kwargs): + return self.model.initial_carry(*args, **kwargs) # type: ignore + + def forward( + self, + return_keys: Sequence[str], + # Model args + **model_kwargs, + ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]: + # Model logits + # B x SeqLen x D + new_carry, outputs = self.model(**model_kwargs) + labels = new_carry.current_data["labels"] + + with torch.no_grad(): + # Preds + outputs["preds"] = torch.argmax(outputs["logits"], dim=-1) + + # Correctness + mask = (labels != IGNORE_LABEL_ID) + loss_counts = mask.sum(-1) + loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division + + is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) + seq_is_correct = is_correct.sum(-1) == loss_counts + + # Metrics (halted) + valid_metrics = new_carry.halted & (loss_counts > 0) + metrics = { + "count": valid_metrics.sum(), + + "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), + "exact_accuracy": (valid_metrics & seq_is_correct).sum(), + + "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), + "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(), + } + + # Losses + + lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask) / loss_divisor).sum() + q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum") + metrics.update({ + "lm_loss": lm_loss.detach(), + "q_halt_loss": q_halt_loss.detach(), + }) + # Q continue (bootstrapping target loss); Alexia: This fits Q-learning, but seems totally unecessary + q_continue_loss = 0 + if "target_q_continue" in outputs: + q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum") + + metrics["q_continue_loss"] = q_continue_loss.detach() + # Filter outputs for return + detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs} + + return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all() + diff --git a/trm/models/recursive_reasoning/hrm.py b/trm/models/recursive_reasoning/hrm.py new file mode 100644 index 0000000..9a1a503 --- /dev/null +++ b/trm/models/recursive_reasoning/hrm.py @@ -0,0 +1,294 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math +import torch +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel + +from models.common import trunc_normal_init_ +from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + +@dataclass +class HierarchicalReasoningModel_ACTV1InnerCarry: + z_H: torch.Tensor + z_L: torch.Tensor + + +@dataclass +class HierarchicalReasoningModel_ACTV1Carry: + inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class HierarchicalReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + # Alexia: added + mlp_t: bool=False # use mlp on L instead of transformer + +class HierarchicalReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.config = config + if self.config.mlp_t: + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) + self.mlp_t = SwiGLU( + hidden_size=self.config.seq_len + self.puzzle_emb_len, # L + expansion=config.expansion, + ) + else: + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # B, L, D = hidden_states.shape + # Post Norm + if self.config.mlp_t: + hidden_states = hidden_states.transpose(1,2) + out = self.mlp_t(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + hidden_states = hidden_states.transpose(1,2) + else: + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + out = self.mlp(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + return hidden_states + +class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]): + super().__init__() + + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + # Input injection (add) + hidden_states = hidden_states + input_injection + # Layers + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + + return hidden_states + + +class HierarchicalReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + pass + + # Reasoning Layers + self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) + self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + with torch.no_grad(): + z_H, z_L = carry.z_H, carry.z_L + for _H_step in range(self.config.H_cycles): + for _L_step in range(self.config.L_cycles): + if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + if not (_H_step == self.config.H_cycles - 1): + z_H = self.H_level(z_H, z_L, **seq_info) + assert not z_H.requires_grad and not z_L.requires_grad + # 1-step grad + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = self.H_level(z_H, z_L, **seq_info) + + # LM Outputs + new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + + # Q head + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) + + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class HierarchicalReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict) + self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return HierarchicalReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + + halted = halted & (new_steps >= min_halt_steps) + + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1] + + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/trm/models/recursive_reasoning/transformers_baseline.py b/trm/models/recursive_reasoning/transformers_baseline.py new file mode 100644 index 0000000..7a08acc --- /dev/null +++ b/trm/models/recursive_reasoning/transformers_baseline.py @@ -0,0 +1,342 @@ +""" +HRM ACT V2: Transformer Baseline for Architecture Ablation + +This is an architecture ablation of the Hierarchical Reasoning Model (HRM). +Key changes from V1: +1. REMOVED hierarchical split (no separate H and L levels) +2. REMOVED inner cycles (no H_cycles/L_cycles loops within reasoning) +3. KEPT ACT outer loop structure intact +4. KEPT all data preprocessing, embeddings, and evaluation infrastructure + +Architecture: Single-level transformer that processes the full 30x30 grid as a +900-token sequence, with the same positional encodings and sparse embeddings as V1. + +""" + +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math + +import torch +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel + +from models.common import trunc_normal_init_ +from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + + +@dataclass +class Model_ACTV2InnerCarry: + z_H: torch.Tensor + + +@dataclass +class Model_ACTV2Carry: + inner_carry: Model_ACTV2InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class Model_ACTV2Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + + H_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + act_enabled: bool = True # If False, always run halt_max_steps (no early stopping during training) + act_inference: bool = False # If True, use adaptive computation during inference + + forward_dtype: str = "bfloat16" + + +class Model_ACTV2Block(nn.Module): + def __init__(self, config: Model_ACTV2Config) -> None: + super().__init__() + + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False, + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # Post Norm + # Self Attention + hidden_states = rms_norm( + hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), + variance_epsilon=self.norm_eps, + ) + # Fully Connected + hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) + return hidden_states + + +class Model_ACTV2ReasoningModule(nn.Module): + def __init__(self, layers: List[Model_ACTV2Block]): + super().__init__() + + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + # Input injection (add) + hidden_states = hidden_states + input_injection + # Layers + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + + return hidden_states + + +class Model_ACTV2_Inner(nn.Module): + def __init__(self, config: Model_ACTV2Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding( + self.config.vocab_size, + self.config.hidden_size, + init_std=embed_init_std, + cast_to=self.forward_dtype, + ) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding( + self.config.num_puzzle_identifiers, + self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, + init_std=0, + cast_to=self.forward_dtype, + ) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding( + dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta, + ) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding( + self.config.seq_len + self.puzzle_emb_len, + self.config.hidden_size, + init_std=embed_init_std, + cast_to=self.forward_dtype, + ) + else: + raise NotImplementedError() + + # Reasoning Layers + self.H_level = Model_ACTV2ReasoningModule( + layers=[Model_ACTV2Block(self.config) for _i in range(self.config.H_layers)] + ) + + # Initial states + self.H_init = nn.Buffer( + trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), + persistent=True, + ) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat( + (puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2 + ) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return Model_ACTV2InnerCarry( + z_H=torch.empty( + batch_size, + self.config.seq_len + self.puzzle_emb_len, + self.config.hidden_size, + dtype=self.forward_dtype, + ), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: Model_ACTV2InnerCarry): + return Model_ACTV2InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + ) + + def forward( + self, carry: Model_ACTV2InnerCarry, batch: Dict[str, torch.Tensor] + ) -> Tuple[Model_ACTV2InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # 1-step grad + z_H = self.H_level(carry.z_H, input_embeddings, **seq_info) + + # LM Outputs + new_carry = Model_ACTV2InnerCarry( + z_H=z_H.detach(), + ) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len :] + + # Q head + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) + + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class Model_ACTV2(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = Model_ACTV2Config(**config_dict) + self.inner = Model_ACTV2_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return Model_ACTV2Carry( + inner_carry=self.inner.empty_carry( + batch_size + ), # Empty is expected, it will be reseted in first pass as all sequences are halted. + steps=torch.zeros((batch_size,), dtype=torch.int32), + halted=torch.ones((batch_size,), dtype=torch.bool), # Default to halted + current_data={k: torch.empty_like(v) for k, v in batch.items()}, + ) + + def forward( + self, + carry: Model_ACTV2Carry, + batch: Dict[str, torch.Tensor], + compute_target_q: bool = False, + ) -> Tuple[Model_ACTV2Carry, Dict[str, torch.Tensor]]: + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = { + k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), batch[k], v) + for k, v in carry.current_data.items() + } + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner( + new_inner_carry, new_current_data + ) + + outputs = {"logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits} + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # Check if adaptive computation should be used + use_adaptive = (self.config.halt_max_steps > 1) and ( + (self.training and self.config.act_enabled) + or (not self.training and self.config.act_inference) + ) + + if use_adaptive: + # Halt signal based on Q-values (but always halt at max steps) + q_halt_signal = q_halt_logits > q_continue_logits + halted = halted | q_halt_signal + + # Store actual steps used for logging (only during inference) + if not self.training: + outputs["actual_steps"] = new_steps.float() + + # Exploration (only during training) + if self.training: + min_halt_steps = ( + torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob + ) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + halted = halted & (new_steps >= min_halt_steps) + + # Compute target Q (only during training) + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + if self.training and compute_target_q: + next_q_halt_logits, next_q_continue_logits = self.inner( + new_inner_carry, new_current_data + )[-1] + + outputs["target_q_continue"] = torch.sigmoid( + torch.where( + is_last_step, + next_q_halt_logits, + torch.maximum(next_q_halt_logits, next_q_continue_logits), + ) + ) + + return Model_ACTV2Carry( + new_inner_carry, new_steps, halted, new_current_data + ), outputs diff --git a/trm/models/recursive_reasoning/trm.py b/trm/models/recursive_reasoning/trm.py new file mode 100644 index 0000000..5c3e39d --- /dev/null +++ b/trm/models/recursive_reasoning/trm.py @@ -0,0 +1,297 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math +import torch +import copy +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel +import random +from models.common import trunc_normal_init_ +from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + +IGNORE_LABEL_ID = -100 + +@dataclass +class TinyRecursiveReasoningModel_ACTV1InnerCarry: + z_H: torch.Tensor + z_L: torch.Tensor + + +@dataclass +class TinyRecursiveReasoningModel_ACTV1Carry: + inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class TinyRecursiveReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int # ignored + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + # Alexia: added + mlp_t: bool = False # use mlp on L instead of transformer + puzzle_emb_len: int = 16 # if non-zero, its specified to this value + no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense + +class TinyRecursiveReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.config = config + if self.config.mlp_t: + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len + self.mlp_t = SwiGLU( + hidden_size=self.config.seq_len + self.puzzle_emb_len, # L + expansion=config.expansion, + ) + else: + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # B, L, D = hidden_states.shape + # Post Norm + if self.config.mlp_t: + hidden_states = hidden_states.transpose(1,2) + out = self.mlp_t(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + hidden_states = hidden_states.transpose(1,2) + else: + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + out = self.mlp(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + return hidden_states + +class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]): + super().__init__() + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + hidden_states = hidden_states + input_injection + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + return hidden_states + + +class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + pass + + # Reasoning Layers + self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + ) + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + it = 0 + z_H, z_L = carry.z_H, carry.z_L + # H_cycles-1 without grad + with torch.no_grad(): + for _H_step in range(self.config.H_cycles-1): + for _L_step in range(self.config.L_cycles): + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = self.L_level(z_H, z_L, **seq_info) + # 1 with grad + for _L_step in range(self.config.L_cycles): + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = self.L_level(z_H, z_L, **seq_info) + + # LM Outputs + new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class TinyRecursiveReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict) + self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return TinyRecursiveReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + + if self.config.no_ACT_continue: + halted = halted | (q_halt_logits > 0) + else: + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + halted = halted & (new_steps >= min_halt_steps) + + if not self.config.no_ACT_continue: + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + _, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data) + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/trm/models/recursive_reasoning/trm_hier6.py b/trm/models/recursive_reasoning/trm_hier6.py new file mode 100644 index 0000000..c654474 --- /dev/null +++ b/trm/models/recursive_reasoning/trm_hier6.py @@ -0,0 +1,323 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math +import torch +import copy +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel +import random +from models.common import trunc_normal_init_ +from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + +IGNORE_LABEL_ID = -100 + +@dataclass +class TinyRecursiveReasoningModel_ACTV1InnerCarry: + z_H: torch.Tensor + z_L1: torch.Tensor + z_L2: torch.Tensor + z_L3: torch.Tensor + z_L4: torch.Tensor + z_L5: torch.Tensor + z_L6: torch.Tensor + + + +@dataclass +class TinyRecursiveReasoningModel_ACTV1Carry: + inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class TinyRecursiveReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int # ignored + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + # Alexia: added + mlp_t: bool = False # use mlp on L instead of transformer + puzzle_emb_len: int = 16 # if non-zero, its specified to this value + no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense + +class TinyRecursiveReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.config = config + if self.config.mlp_t: + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len + self.mlp_t = SwiGLU( + hidden_size=self.config.seq_len + self.puzzle_emb_len, # L + expansion=config.expansion, + ) + else: + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # B, L, D = hidden_states.shape + # Post Norm + if self.config.mlp_t: + hidden_states = hidden_states.transpose(1,2) + out = self.mlp_t(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + hidden_states = hidden_states.transpose(1,2) + else: + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + out = self.mlp(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + return hidden_states + +class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]): + super().__init__() + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + hidden_states = hidden_states + input_injection + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + return hidden_states + + +class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + pass + + # Reasoning Layers + self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L1_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L2_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L3_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L4_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L5_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L6_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L1=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L2=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L3=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L4=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L5=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L6=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L1=torch.where(reset_flag.view(-1, 1, 1), self.L1_init, carry.z_L1), + z_L2=torch.where(reset_flag.view(-1, 1, 1), self.L2_init, carry.z_L2), + z_L3=torch.where(reset_flag.view(-1, 1, 1), self.L3_init, carry.z_L3), + z_L4=torch.where(reset_flag.view(-1, 1, 1), self.L4_init, carry.z_L4), + z_L5=torch.where(reset_flag.view(-1, 1, 1), self.L5_init, carry.z_L5), + z_L6=torch.where(reset_flag.view(-1, 1, 1), self.L6_init, carry.z_L6), + ) + + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + it = 0 + z_H, z_L = carry.z_H, [carry.z_L1, carry.z_L2, carry.z_L3, carry.z_L4, carry.z_L5, carry.z_L6] + # H_cycles-1 without grad + with torch.no_grad(): + for _H_step in range(self.config.H_cycles-1): + for _L_step in range(self.config.L_cycles): + z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5] + z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info) + z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5] + z_H = self.L_level(z_H, z_L_, **seq_info) + # 1 with grad + for _L_step in range(self.config.L_cycles): + z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5] + z_L[_L_step] = self.L_level(z_L_, z_H + input_embeddings, **seq_info) + z_L_ = z_L[0] + z_L[1] + z_L[2] + z_L[3] + z_L[4] + z_L[5] + z_H = self.L_level(z_H, z_L_, **seq_info) + + # LM Outputs + new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L1=z_L[0].detach(), z_L2=z_L[1].detach(), z_L3=z_L[2].detach(), z_L4=z_L[3].detach(), z_L5=z_L[4].detach(), z_L6=z_L[5].detach()) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class TinyRecursiveReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict) + self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return TinyRecursiveReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + + if self.config.no_ACT_continue: + halted = halted | (q_halt_logits > 0) + else: + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + halted = halted & (new_steps >= min_halt_steps) + + if not self.config.no_ACT_continue: + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + _, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data) + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/trm/models/recursive_reasoning/trm_singlez.py b/trm/models/recursive_reasoning/trm_singlez.py new file mode 100644 index 0000000..e5e1a7d --- /dev/null +++ b/trm/models/recursive_reasoning/trm_singlez.py @@ -0,0 +1,294 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math +import torch +import copy +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel +import random +from models.common import trunc_normal_init_ +from models.layers import rms_norm, LinearSwish, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + +IGNORE_LABEL_ID = -100 + +@dataclass +class TinyRecursiveReasoningModel_ACTV1InnerCarry: + z_L: torch.Tensor + + + +@dataclass +class TinyRecursiveReasoningModel_ACTV1Carry: + inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class TinyRecursiveReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int # ignored + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + # Alexia: added + mlp_t: bool = False # use mlp on L instead of transformer + puzzle_emb_len: int = 16 # if non-zero, its specified to this value + no_ACT_continue: bool = True # No continue ACT loss, only use the sigmoid of the halt which makes much more sense + +class TinyRecursiveReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.config = config + if self.config.mlp_t: + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len + self.mlp_t = SwiGLU( + hidden_size=self.config.seq_len + self.puzzle_emb_len, # L + expansion=config.expansion, + ) + else: + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # B, L, D = hidden_states.shape + # Post Norm + if self.config.mlp_t: + hidden_states = hidden_states.transpose(1,2) + out = self.mlp_t(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + hidden_states = hidden_states.transpose(1,2) + else: + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + out = self.mlp(hidden_states) + hidden_states = rms_norm(hidden_states + out, variance_epsilon=self.norm_eps) + return hidden_states + +class TinyRecursiveReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[TinyRecursiveReasoningModel_ACTV1Block]): + super().__init__() + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + return hidden_states + + +class TinyRecursiveReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + pass + + # Reasoning Layers + self.L_level = TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers=[TinyRecursiveReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry): + return TinyRecursiveReasoningModel_ACTV1InnerCarry( + z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + ) + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + it = 0 + z_L = carry.z_L + # H_cycles-1 without grad + with torch.no_grad(): + for _H_step in range(self.config.H_cycles-1): + for _L_step in range(self.config.L_cycles): + z_L = self.L_level(z_L + input_embeddings, **seq_info) + z_L = self.L_level(z_L, **seq_info) + # 1 with grad + for _L_step in range(self.config.L_cycles): + z_L = self.L_level(z_L + input_embeddings, **seq_info) + z_L = self.L_level(z_L, **seq_info) + z_out = z_L + + # LM Outputs + new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(z_L=z_L.detach()) # New carry no grad + output = self.lm_head(z_out)[:, self.puzzle_emb_len:] + q_logits = self.q_head(z_out[:, 0]).to(torch.float32) # Q-head; uses the first puzzle_emb position + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class TinyRecursiveReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict) + self.inner = TinyRecursiveReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return TinyRecursiveReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + + if self.config.no_ACT_continue: + halted = halted | (q_halt_logits > 0) + else: + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + halted = halted & (new_steps >= min_halt_steps) + + if not self.config.no_ACT_continue: + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + _, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data) + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/trm/models/sparse_embedding.py b/trm/models/sparse_embedding.py new file mode 100644 index 0000000..f369205 --- /dev/null +++ b/trm/models/sparse_embedding.py @@ -0,0 +1,132 @@ +from typing import Union + +import torch +from torch import nn +import torch.distributed as dist +from torch.optim.optimizer import Optimizer, ParamsT + +from models.common import trunc_normal_init_ + + +class CastedSparseEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Real Weights + # Truncated LeCun normal init + self.weights = nn.Buffer( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True + ) + + # Local weights and IDs + # Local embeddings, with gradient, not persistent + self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) + # Local embedding IDs, not persistent + self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if not self.training: + # Test mode, no gradient + return self.weights[inputs].to(self.cast_to) + + # Training mode, fill puzzle embedding from weights + with torch.no_grad(): + self.local_weights.copy_(self.weights[inputs]) + self.local_ids.copy_(inputs) + + return self.local_weights.to(self.cast_to) + + +class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): + def __init__( + self, + params: ParamsT, + + world_size: int, + lr: Union[float, torch.Tensor] = 1e-3, + weight_decay: float = 1e-2, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + world_size=world_size + ) + super().__init__(params, defaults) + + @torch.no_grad + def step(self, closure=None): # type: ignore + for group in self.param_groups: + # Find the sparse embedding weights + local_weights_grad = None + local_ids = None + weights = None + + assert len(group["params"]) == 3 + for p in group["params"]: + if p.requires_grad: + local_weights_grad = p.grad + elif p.ndim == 1: + local_ids = p + elif p.ndim == 2: + weights = p + else: + assert False + + assert local_ids is not None + assert weights is not None + + # Apply SignSGD + # Adam ≈ SignSGD if gradient is very sparse + if local_weights_grad is not None: + _sparse_emb_signsgd_dist( + local_weights_grad, + local_ids, + weights, + + lr=group["lr"], + weight_decay=group["weight_decay"], + world_size=group["world_size"] + ) + + +def _sparse_emb_signsgd_dist( + local_weights_grad: torch.Tensor, + local_ids: torch.Tensor, + weights: torch.Tensor, + + lr: float, + weight_decay: float, + world_size: int +) -> None: + N, D = local_weights_grad.shape + + # All-gather + all_weights_grad = local_weights_grad + all_ids = local_ids + + if world_size > 1: + all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) + all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) + + dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) + dist.all_gather_into_tensor(all_ids, local_ids) + + # Unique + grad_ids, inv = all_ids.unique(return_inverse=True) + + grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) + grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) + + # SignSGD with decoupled weight decay + p = weights[grad_ids] + + p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) + + # Write updated slices back + weights[grad_ids] = p |
