diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| commit | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch) | |
| tree | 59a233959932ca0e4f12f196275e07fcf443b33f /files/tests/test_train_smoke.py | |
init commit
Diffstat (limited to 'files/tests/test_train_smoke.py')
| -rw-r--r-- | files/tests/test_train_smoke.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/files/tests/test_train_smoke.py b/files/tests/test_train_smoke.py new file mode 100644 index 0000000..e04e40b --- /dev/null +++ b/files/tests/test_train_smoke.py @@ -0,0 +1,33 @@ +import torch +from files.models.snn import SimpleSNN +from files.data_io.dataset_loader import get_dataloader +import torch.nn as nn +import torch.optim as optim + + +def _one_step(model, xb, yb, lyapunov=False): + model.train() + opt = optim.Adam(model.parameters(), lr=1e-3) + ce = nn.CrossEntropyLoss() + opt.zero_grad(set_to_none=True) + logits, lyap = model(xb, compute_lyapunov=lyapunov) + loss = ce(logits, yb) + if lyapunov and lyap is not None: + loss = loss + 0.1 * (lyap - 0.0) ** 2 + loss.backward() + opt.step() + assert torch.isfinite(loss).all() + + +def test_train_step_baseline_and_lyapunov(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + B, T, D = xb.shape + C = 20 + model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C) + # baseline + _one_step(model, xb, yb, lyapunov=False) + # lyapunov-regularized + _one_step(model, xb, yb, lyapunov=True) + + |
