diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 16:21:21 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 16:21:21 -0500 |
| commit | cd4b331eef4a74e47fc46a4c1d8ca43c1f779b5a (patch) | |
| tree | af53daec3227e4463e044f5a3899c34ff4a7160d /experiments | |
| parent | 9d1eaacab11510793e36fc9bba271fd7c330f6e4 (diff) | |
Fix CNN state bridge: use custom CNNStateBridge for variable input dims
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index af754c0..75c3ff8 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -434,10 +434,19 @@ def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb flat_dims = model.flat_dims # [32768, 16384, 8192, 256] d_terminal = 256 # h3 is the terminal hidden state - # One SB net per layer (each takes flat_dim_l as input, outputs 256) + # One SB net per layer: MLP from flat_dim_l + time_embed + s_dim -> 256 + from models.value_net import SinusoidalTimeEmbed + class CNNStateBridge(nn.Module): + def __init__(self, in_dim, out_dim, s_dim, te_dim=32): + super().__init__() + self.ln = nn.LayerNorm(in_dim) + self.te = SinusoidalTimeEmbed(te_dim) + total = in_dim + te_dim + s_dim + self.net = nn.Sequential(nn.Linear(total, 256), nn.GELU(), nn.Linear(256, 256), nn.GELU(), nn.Linear(256, out_dim)) + def forward(self, h, t, s): + return self.net(torch.cat([self.ln(h), self.te(t), s], dim=-1)) state_preds = nn.ModuleList([ - StateBridgeNet(d_hidden=flat_dims[l], s_dim=C, - time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev) + CNNStateBridge(flat_dims[l], d_terminal, C).to(dev) for l in range(L) ]) |
