From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 13 Jan 2026 23:49:05 -0600 Subject: init commit --- files/tests/test_train_smoke.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 files/tests/test_train_smoke.py (limited to 'files/tests/test_train_smoke.py') diff --git a/files/tests/test_train_smoke.py b/files/tests/test_train_smoke.py new file mode 100644 index 0000000..e04e40b --- /dev/null +++ b/files/tests/test_train_smoke.py @@ -0,0 +1,33 @@ +import torch +from files.models.snn import SimpleSNN +from files.data_io.dataset_loader import get_dataloader +import torch.nn as nn +import torch.optim as optim + + +def _one_step(model, xb, yb, lyapunov=False): + model.train() + opt = optim.Adam(model.parameters(), lr=1e-3) + ce = nn.CrossEntropyLoss() + opt.zero_grad(set_to_none=True) + logits, lyap = model(xb, compute_lyapunov=lyapunov) + loss = ce(logits, yb) + if lyapunov and lyap is not None: + loss = loss + 0.1 * (lyap - 0.0) ** 2 + loss.backward() + opt.step() + assert torch.isfinite(loss).all() + + +def test_train_step_baseline_and_lyapunov(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + B, T, D = xb.shape + C = 20 + model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C) + # baseline + _one_step(model, xb, yb, lyapunov=False) + # lyapunov-regularized + _one_step(model, xb, yb, lyapunov=True) + + -- cgit v1.2.3