summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
commit6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch)
treed7c63adcd19c4f5d46c8a937e5047fece55dea62 /models
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress.
Diffstat (limited to 'models')
-rw-r--r--models/__init__.py0
-rw-r--r--models/__pycache__/__init__.cpython-313.pycbin0 -> 132 bytes
-rw-r--r--models/__pycache__/residual_mlp.cpython-313.pycbin0 -> 4692 bytes
-rw-r--r--models/__pycache__/state_bridge.cpython-313.pycbin0 -> 2468 bytes
-rw-r--r--models/__pycache__/value_net.cpython-313.pycbin0 -> 5308 bytes
-rw-r--r--models/residual_mlp.py73
-rw-r--r--models/state_bridge.py35
-rw-r--r--models/value_net.py77
8 files changed, 185 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/__pycache__/__init__.cpython-313.pyc b/models/__pycache__/__init__.cpython-313.pyc
new file mode 100644
index 0000000..cb3f264
--- /dev/null
+++ b/models/__pycache__/__init__.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/residual_mlp.cpython-313.pyc b/models/__pycache__/residual_mlp.cpython-313.pyc
new file mode 100644
index 0000000..c758f50
--- /dev/null
+++ b/models/__pycache__/residual_mlp.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/state_bridge.cpython-313.pyc b/models/__pycache__/state_bridge.cpython-313.pyc
new file mode 100644
index 0000000..69e1071
--- /dev/null
+++ b/models/__pycache__/state_bridge.cpython-313.pyc
Binary files differ
diff --git a/models/__pycache__/value_net.cpython-313.pyc b/models/__pycache__/value_net.cpython-313.pyc
new file mode 100644
index 0000000..a6187ee
--- /dev/null
+++ b/models/__pycache__/value_net.cpython-313.pyc
Binary files differ
diff --git a/models/residual_mlp.py b/models/residual_mlp.py
new file mode 100644
index 0000000..c16778c
--- /dev/null
+++ b/models/residual_mlp.py
@@ -0,0 +1,73 @@
+"""
+Deep Residual MLP for classification.
+Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head.
+Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l))
+"""
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+ """Single pre-LayerNorm residual MLP block."""
+
+ def __init__(self, d_hidden: int):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w1 = nn.Linear(d_hidden, d_hidden)
+ self.w2 = nn.Linear(d_hidden, d_hidden)
+ # Small init for residual branch
+ nn.init.normal_(self.w2.weight, std=0.01)
+ nn.init.zeros_(self.w2.bias)
+
+ def forward(self, h: torch.Tensor) -> torch.Tensor:
+ """Returns the residual F_l(h), NOT h + F_l(h)."""
+ z = self.ln(h)
+ z = self.w1(z)
+ z = torch.nn.functional.gelu(z)
+ z = self.w2(z)
+ return z
+
+
+class ResidualMLP(nn.Module):
+ """Deep residual MLP: embed -> L blocks -> LN -> output head."""
+
+ def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int):
+ super().__init__()
+ self.embed = nn.Linear(input_dim, d_hidden)
+ self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)])
+ self.out_ln = nn.LayerNorm(d_hidden)
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def forward(self, x: torch.Tensor, return_hidden: bool = False):
+ """
+ Args:
+ x: (batch, input_dim)
+ return_hidden: if True, also return list of hidden states [h_0, ..., h_L]
+ Returns:
+ logits: (batch, num_classes)
+ hiddens: list of (batch, d_hidden) if return_hidden
+ """
+ h = self.embed(x)
+ hiddens = [h] if return_hidden else None
+
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+
+ logits = self.out_head(self.out_ln(h))
+
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h: torch.Tensor, start_layer: int):
+ """Run forward from a given layer index to output. Used for perturbation tests."""
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f
+ logits = self.out_head(self.out_ln(h))
+ return logits
diff --git a/models/state_bridge.py b/models/state_bridge.py
new file mode 100644
index 0000000..0a0e7aa
--- /dev/null
+++ b/models/state_bridge.py
@@ -0,0 +1,35 @@
+"""
+State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L.
+Used by the State Bridge method.
+"""
+import torch
+import torch.nn as nn
+from .value_net import SinusoidalTimeEmbed
+
+
+class StateBridgeNet(nn.Module):
+ """
+ State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L.
+ """
+
+ def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
+ hidden_dim: int = 256, num_layers: int = 3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+ """Returns predicted h_L as (batch, d_hidden)."""
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
diff --git a/models/value_net.py b/models/value_net.py
new file mode 100644
index 0000000..3c72f75
--- /dev/null
+++ b/models/value_net.py
@@ -0,0 +1,77 @@
+"""
+Value network V_phi(h_l, t_l, s) -> scalar.
+Used by the Credit Bridge method.
+Input: [LN(h_l), time_embed(t_l), s] concatenated.
+"""
+import torch
+import torch.nn as nn
+import math
+import copy
+
+
+class SinusoidalTimeEmbed(nn.Module):
+ """Sinusoidal positional encoding for scalar depth-time t_l = l/L."""
+
+ def __init__(self, embed_dim: int):
+ super().__init__()
+ self.embed_dim = embed_dim
+
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
+ """t: (batch,) or (batch, 1) scalar in [0,1]."""
+ if t.dim() == 1:
+ t = t.unsqueeze(-1) # (batch, 1)
+ half = self.embed_dim // 2
+ freqs = torch.exp(
+ -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half
+ )
+ args = t * freqs.unsqueeze(0) # (batch, half)
+ return torch.cat([torch.sin(args), torch.cos(args)], dim=-1) # (batch, embed_dim)
+
+
+class ValueNet(nn.Module):
+ """
+ Scalar value network V_phi(h_l, t_l, s).
+ Inputs:
+ h: hidden state (batch, d_hidden)
+ t: depth-time scalar (batch,) in [0, 1]
+ s: terminal modulation code (batch, s_dim)
+ Output:
+ V: scalar (batch,)
+ """
+
+ def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32,
+ hidden_dim: int = 256, num_layers: int = 3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, 1))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
+ """Returns V(h, t, s) as (batch,) scalar."""
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp).squeeze(-1)
+
+
+def create_ema_model(model: nn.Module) -> nn.Module:
+ """Create an EMA copy of a model."""
+ ema = copy.deepcopy(model)
+ for p in ema.parameters():
+ p.requires_grad_(False)
+ return ema
+
+
+@torch.no_grad()
+def update_ema(model: nn.Module, ema_model: nn.Module, momentum: float = 0.99):
+ """Update EMA model parameters."""
+ for p, ep in zip(model.parameters(), ema_model.parameters()):
+ ep.data.mul_(momentum).add_(p.data, alpha=1 - momentum)