From cd4b331eef4a74e47fc46a4c1d8ca43c1f779b5a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Thu, 2 Apr 2026 16:21:21 -0500 Subject: Fix CNN state bridge: use custom CNNStateBridge for variable input dims Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cnn_baseline.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'experiments') diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index af754c0..75c3ff8 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -434,10 +434,19 @@ def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb flat_dims = model.flat_dims # [32768, 16384, 8192, 256] d_terminal = 256 # h3 is the terminal hidden state - # One SB net per layer (each takes flat_dim_l as input, outputs 256) + # One SB net per layer: MLP from flat_dim_l + time_embed + s_dim -> 256 + from models.value_net import SinusoidalTimeEmbed + class CNNStateBridge(nn.Module): + def __init__(self, in_dim, out_dim, s_dim, te_dim=32): + super().__init__() + self.ln = nn.LayerNorm(in_dim) + self.te = SinusoidalTimeEmbed(te_dim) + total = in_dim + te_dim + s_dim + self.net = nn.Sequential(nn.Linear(total, 256), nn.GELU(), nn.Linear(256, 256), nn.GELU(), nn.Linear(256, out_dim)) + def forward(self, h, t, s): + return self.net(torch.cat([self.ln(h), self.te(t), s], dim=-1)) state_preds = nn.ModuleList([ - StateBridgeNet(d_hidden=flat_dims[l], s_dim=C, - time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev) + CNNStateBridge(flat_dims[l], d_terminal, C).to(dev) for l in range(L) ]) -- cgit v1.2.3