summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 16:21:21 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 16:21:21 -0500
commitcd4b331eef4a74e47fc46a4c1d8ca43c1f779b5a (patch)
treeaf53daec3227e4463e044f5a3899c34ff4a7160d /experiments
parent9d1eaacab11510793e36fc9bba271fd7c330f6e4 (diff)
Fix CNN state bridge: use custom CNNStateBridge for variable input dims
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cnn_baseline.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index af754c0..75c3ff8 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -434,10 +434,19 @@ def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb
flat_dims = model.flat_dims # [32768, 16384, 8192, 256]
d_terminal = 256 # h3 is the terminal hidden state
- # One SB net per layer (each takes flat_dim_l as input, outputs 256)
+ # One SB net per layer: MLP from flat_dim_l + time_embed + s_dim -> 256
+ from models.value_net import SinusoidalTimeEmbed
+ class CNNStateBridge(nn.Module):
+ def __init__(self, in_dim, out_dim, s_dim, te_dim=32):
+ super().__init__()
+ self.ln = nn.LayerNorm(in_dim)
+ self.te = SinusoidalTimeEmbed(te_dim)
+ total = in_dim + te_dim + s_dim
+ self.net = nn.Sequential(nn.Linear(total, 256), nn.GELU(), nn.Linear(256, 256), nn.GELU(), nn.Linear(256, out_dim))
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.te(t), s], dim=-1))
state_preds = nn.ModuleList([
- StateBridgeNet(d_hidden=flat_dims[l], s_dim=C,
- time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev)
+ CNNStateBridge(flat_dims[l], d_terminal, C).to(dev)
for l in range(L)
])