"""Smoke tests for a single training step of SimpleSNN on SHD, with and
without the Lyapunov regularizer."""

import torch
import torch.nn as nn
import torch.optim as optim

from files.data_io.dataset_loader import get_dataloader
from files.models.snn import SimpleSNN


def _one_step(model, xb, yb, lyapunov=False):
    """Run one optimizer step and check that the loss stays finite."""
    model.train()
    opt = optim.Adam(model.parameters(), lr=1e-3)
    ce = nn.CrossEntropyLoss()

    opt.zero_grad(set_to_none=True)
    logits, lyap = model(xb, compute_lyapunov=lyapunov)
    loss = ce(logits, yb)
    if lyapunov and lyap is not None:
        # Quadratic penalty pulling the estimated Lyapunov exponent toward 0.
        loss = loss + 0.1 * (lyap - 0.0) ** 2
    loss.backward()
    opt.step()

    assert torch.isfinite(loss).all()


def test_train_step_baseline_and_lyapunov():
    train_loader, _ = get_dataloader("data_io/configs/shd.yaml")
    xb, yb = next(iter(train_loader))  # xb: (batch, time, features)
    B, T, D = xb.shape
    C = 20  # SHD has 20 classes
    model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C)

    # baseline
    _one_step(model, xb, yb, lyapunov=False)
    # lyapunov-regularized
    _one_step(model, xb, yb, lyapunov=True)