diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-23 18:21:26 -0500 |
| commit | 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch) | |
| tree | d7c63adcd19c4f5d46c8a937e5047fece55dea62 /models | |
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching.
Credit bridge matches state bridge on linear system (~0.94 cosine).
CIFAR experiments in progress.
Diffstat (limited to 'models')
| -rw-r--r-- | models/__init__.py | 0 | ||||
| -rw-r--r-- | models/__pycache__/__init__.cpython-313.pyc | bin | 0 -> 132 bytes | |||
| -rw-r--r-- | models/__pycache__/residual_mlp.cpython-313.pyc | bin | 0 -> 4692 bytes | |||
| -rw-r--r-- | models/__pycache__/state_bridge.cpython-313.pyc | bin | 0 -> 2468 bytes | |||
| -rw-r--r-- | models/__pycache__/value_net.cpython-313.pyc | bin | 0 -> 5308 bytes | |||
| -rw-r--r-- | models/residual_mlp.py | 73 | ||||
| -rw-r--r-- | models/state_bridge.py | 35 | ||||
| -rw-r--r-- | models/value_net.py | 77 |
8 files changed, 185 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/__pycache__/__init__.cpython-313.pyc b/models/__pycache__/__init__.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..cb3f264 --- /dev/null +++ b/models/__pycache__/__init__.cpython-313.pyc diff --git a/models/__pycache__/residual_mlp.cpython-313.pyc b/models/__pycache__/residual_mlp.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..c758f50 --- /dev/null +++ b/models/__pycache__/residual_mlp.cpython-313.pyc diff --git a/models/__pycache__/state_bridge.cpython-313.pyc b/models/__pycache__/state_bridge.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..69e1071 --- /dev/null +++ b/models/__pycache__/state_bridge.cpython-313.pyc diff --git a/models/__pycache__/value_net.cpython-313.pyc b/models/__pycache__/value_net.cpython-313.pyc Binary files differnew file mode 100644 index 0000000..a6187ee --- /dev/null +++ b/models/__pycache__/value_net.cpython-313.pyc diff --git a/models/residual_mlp.py b/models/residual_mlp.py new file mode 100644 index 0000000..c16778c --- /dev/null +++ b/models/residual_mlp.py @@ -0,0 +1,73 @@ +""" +Deep Residual MLP for classification. +Architecture: Input -> Linear embedding -> L residual blocks -> LayerNorm -> Linear output head. +Each block: h_{l+1} = h_l + W2 * GELU(W1 * LN(h_l)) +""" +import torch +import torch.nn as nn + + +class ResidualBlock(nn.Module): + """Single pre-LayerNorm residual MLP block.""" + + def __init__(self, d_hidden: int): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w1 = nn.Linear(d_hidden, d_hidden) + self.w2 = nn.Linear(d_hidden, d_hidden) + # Small init for residual branch + nn.init.normal_(self.w2.weight, std=0.01) + nn.init.zeros_(self.w2.bias) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + """Returns the residual F_l(h), NOT h + F_l(h).""" + z = self.ln(h) + z = self.w1(z) + z = torch.nn.functional.gelu(z) + z = self.w2(z) + return z + + +class ResidualMLP(nn.Module): + """Deep residual MLP: embed -> L blocks -> LN -> output head.""" + + def __init__(self, input_dim: int, d_hidden: int, num_classes: int, num_blocks: int): + super().__init__() + self.embed = nn.Linear(input_dim, d_hidden) + self.blocks = nn.ModuleList([ResidualBlock(d_hidden) for _ in range(num_blocks)]) + self.out_ln = nn.LayerNorm(d_hidden) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + + def forward(self, x: torch.Tensor, return_hidden: bool = False): + """ + Args: + x: (batch, input_dim) + return_hidden: if True, also return list of hidden states [h_0, ..., h_L] + Returns: + logits: (batch, num_classes) + hiddens: list of (batch, d_hidden) if return_hidden + """ + h = self.embed(x) + hiddens = [h] if return_hidden else None + + for block in self.blocks: + f = block(h) + h = h + f + if return_hidden: + hiddens.append(h) + + logits = self.out_head(self.out_ln(h)) + + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h: torch.Tensor, start_layer: int): + """Run forward from a given layer index to output. Used for perturbation tests.""" + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f + logits = self.out_head(self.out_ln(h)) + return logits diff --git a/models/state_bridge.py b/models/state_bridge.py new file mode 100644 index 0000000..0a0e7aa --- /dev/null +++ b/models/state_bridge.py @@ -0,0 +1,35 @@ +""" +State Bridge predictor G_psi(h_l, t_l, s) -> predicted h_L. +Used by the State Bridge method. +""" +import torch +import torch.nn as nn +from .value_net import SinusoidalTimeEmbed + + +class StateBridgeNet(nn.Module): + """ + State predictor G_psi(h_l, t_l, s) -> predicted terminal state h_L. + """ + + def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32, + hidden_dim: int = 256, num_layers: int = 3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + """Returns predicted h_L as (batch, d_hidden).""" + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) diff --git a/models/value_net.py b/models/value_net.py new file mode 100644 index 0000000..3c72f75 --- /dev/null +++ b/models/value_net.py @@ -0,0 +1,77 @@ +""" +Value network V_phi(h_l, t_l, s) -> scalar. +Used by the Credit Bridge method. +Input: [LN(h_l), time_embed(t_l), s] concatenated. +""" +import torch +import torch.nn as nn +import math +import copy + + +class SinusoidalTimeEmbed(nn.Module): + """Sinusoidal positional encoding for scalar depth-time t_l = l/L.""" + + def __init__(self, embed_dim: int): + super().__init__() + self.embed_dim = embed_dim + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """t: (batch,) or (batch, 1) scalar in [0,1].""" + if t.dim() == 1: + t = t.unsqueeze(-1) # (batch, 1) + half = self.embed_dim // 2 + freqs = torch.exp( + -math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half + ) + args = t * freqs.unsqueeze(0) # (batch, half) + return torch.cat([torch.sin(args), torch.cos(args)], dim=-1) # (batch, embed_dim) + + +class ValueNet(nn.Module): + """ + Scalar value network V_phi(h_l, t_l, s). + Inputs: + h: hidden state (batch, d_hidden) + t: depth-time scalar (batch,) in [0, 1] + s: terminal modulation code (batch, s_dim) + Output: + V: scalar (batch,) + """ + + def __init__(self, d_hidden: int, s_dim: int, time_embed_dim: int = 32, + hidden_dim: int = 256, num_layers: int = 3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, h: torch.Tensor, t: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + """Returns V(h, t, s) as (batch,) scalar.""" + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp).squeeze(-1) + + +def create_ema_model(model: nn.Module) -> nn.Module: + """Create an EMA copy of a model.""" + ema = copy.deepcopy(model) + for p in ema.parameters(): + p.requires_grad_(False) + return ema + + +@torch.no_grad() +def update_ema(model: nn.Module, ema_model: nn.Module, momentum: float = 0.99): + """Update EMA model parameters.""" + for p, ep in zip(model.parameters(), ema_model.parameters()): + ep.data.mul_(momentum).add_(p.data, alpha=1 - momentum) |
