summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--experiments/cnn_baseline.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index af754c0..75c3ff8 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -434,10 +434,19 @@ def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb
flat_dims = model.flat_dims # [32768, 16384, 8192, 256]
d_terminal = 256 # h3 is the terminal hidden state
- # One SB net per layer (each takes flat_dim_l as input, outputs 256)
+ # One SB net per layer: MLP from flat_dim_l + time_embed + s_dim -> 256
+ from models.value_net import SinusoidalTimeEmbed
+ class CNNStateBridge(nn.Module):
+ def __init__(self, in_dim, out_dim, s_dim, te_dim=32):
+ super().__init__()
+ self.ln = nn.LayerNorm(in_dim)
+ self.te = SinusoidalTimeEmbed(te_dim)
+ total = in_dim + te_dim + s_dim
+ self.net = nn.Sequential(nn.Linear(total, 256), nn.GELU(), nn.Linear(256, 256), nn.GELU(), nn.Linear(256, out_dim))
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.te(t), s], dim=-1))
state_preds = nn.ModuleList([
- StateBridgeNet(d_hidden=flat_dims[l], s_dim=C,
- time_embed_dim=32, hidden_dim=256, num_layers=3).to(dev)
+ CNNStateBridge(flat_dims[l], d_terminal, C).to(dev)
for l in range(L)
])