""" Training script for snnTorch-based deep SNNs with Lyapunov regularization. Usage: # Baseline (no Lyapunov) python files/train_snntorch.py --hidden 256 128 --epochs 10 # With Lyapunov regularization python files/train_snntorch.py --hidden 256 128 --epochs 10 --lyapunov --lambda_reg 0.1 # Recurrent model python files/train_snntorch.py --model recurrent --hidden 256 --epochs 10 --lyapunov """ import os import sys import json import csv _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(_HERE) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import time from typing import List, Optional import torch import torch.nn as nn import torch.optim as optim from tqdm.auto import tqdm from files.data_io.dataset_loader import get_dataloader from files.models.snn_snntorch import create_snn def _prepare_run_dir(base_dir: str) -> str: ts = time.strftime("%Y%m%d-%H%M%S") run_dir = os.path.join(base_dir, ts) os.makedirs(run_dir, exist_ok=True) return run_dir def _append_metrics(csv_path: str, row: dict): write_header = not os.path.exists(csv_path) with open(csv_path, "a", newline="") as f: writer = csv.DictWriter(f, fieldnames=row.keys()) if write_header: writer.writeheader() writer.writerow(row) def parse_args(): p = argparse.ArgumentParser( description="Train deep SNN with snnTorch and optional Lyapunov regularization" ) # Model architecture p.add_argument( "--model", type=str, default="feedforward", choices=["feedforward", "recurrent"], help="Model type: 'feedforward' (LIF) or 'recurrent' (RSynaptic)" ) p.add_argument( "--hidden", type=int, nargs="+", default=[256], help="Hidden layer sizes (e.g., --hidden 256 128 for 2 layers)" ) p.add_argument("--classes", type=int, default=20, help="Number of output classes") p.add_argument("--beta", type=float, default=0.9, help="Membrane decay (beta)") p.add_argument("--threshold", type=float, default=1.0, help="Firing threshold") p.add_argument("--dropout", type=float, default=0.0, help="Dropout between layers") p.add_argument( "--surrogate_slope", type=float, default=25.0, help="Slope for fast_sigmoid surrogate gradient" ) # Recurrent-specific (only for --model recurrent) p.add_argument("--alpha", type=float, default=0.9, help="Synaptic current decay (recurrent only)") # Training p.add_argument("--epochs", type=int, default=10) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--weight_decay", type=float, default=0.0, help="L2 regularization") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config") # Lyapunov regularization p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization") p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov penalty weight") p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent") p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation magnitude") p.add_argument( "--lyap_layers", type=int, nargs="*", default=None, help="Which layers to measure (default: all). 
def train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: optim.Optimizer,
    device: torch.device,
    ce_loss: nn.Module,
    lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    lyap_layers: Optional[List[int]],
    progress: bool,
    run_dir: Optional[str] = None,
    epoch_idx: Optional[int] = None,
    log_batches: bool = False,
):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    lyap_vals = []

    iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
    for bidx, (x, y) in enumerate(iterator):
        x = x.to(device)  # (B, T, D)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits, lyap_est = model(
            x,
            compute_lyapunov=lyapunov,
            lyap_eps=lyap_eps,
            lyap_layers=lyap_layers,
        )
        ce = ce_loss(logits, y)

        if lyapunov and lyap_est is not None:
            # Penalize squared deviation from the target Lyapunov exponent
            # (the default lambda_target = 0 steers dynamics toward marginal
            # stability, i.e., the edge of chaos).
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.detach().item())
        else:
            loss = ce

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        batch_correct = (preds == y).sum().item()
        correct += batch_correct
        total += x.size(0)

        if log_batches and run_dir is not None and epoch_idx is not None:
            _append_metrics(
                os.path.join(run_dir, "metrics.csv"),
                {
                    "step": "batch",
                    "epoch": int(epoch_idx),
                    "batch": int(bidx),
                    "loss": float(loss.item()),
                    "acc": float(batch_correct / max(x.size(0), 1)),
                    "lyap": float(lyap_est.item())
                    if (lyapunov and lyap_est is not None)
                    else float("nan"),
                    "time_sec": 0.0,
                },
            )

        if progress:
            avg_loss = running_loss / max(total, 1)
            avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
            postfix = {"loss": f"{avg_loss:.4f}", "acc": f"{correct / total:.3f}"}
            if avg_lyap is not None:
                postfix["lyap"] = f"{avg_lyap:.4f}"
            iterator.set_postfix(postfix)

    avg_loss = running_loss / max(total, 1)
    acc = correct / max(total, 1)
    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
    return avg_loss, acc, avg_lyap


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader,
    device: torch.device,
    ce_loss: nn.Module,
    progress: bool,
):
    model.eval()
    total = 0
    correct = 0
    running_loss = 0.0

    iterator = tqdm(loader, desc="eval", leave=False, dynamic_ncols=True) if progress else loader
    for x, y in iterator:
        x = x.to(device)
        y = y.to(device)
        logits, _ = model(x, compute_lyapunov=False)
        loss = ce_loss(logits, y)

        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)

    avg_loss = running_loss / max(total, 1)
    acc = correct / max(total, 1)
    return avg_loss, acc
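# Both loops above assume the forward contract of the models returned by
# create_snn (defined in files/models/snn_snntorch.py):
#   model(x, compute_lyapunov=False, lyap_eps=..., lyap_layers=...)
#     -> (logits, lyap_est)
# where lyap_est is a scalar tensor when compute_lyapunov=True and None
# otherwise. This signature is inferred from the call sites in this script,
# not from the model code itself.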
def main():
    args = parse_args()
    device = torch.device(args.device)

    # Prepare output directory
    run_dir = _prepare_run_dir(args.out_dir)
    with open(os.path.join(run_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=2)

    # Load data
    train_loader, val_loader = get_dataloader(args.cfg)

    # Infer dimensions from data
    xb, _ = next(iter(train_loader))
    _, T, D = xb.shape
    C = args.classes
    print(f"Data: T={T}, D={D}, classes={C}")
    print(f"Model: {args.model}, hidden={args.hidden}")

    # Create model
    from snntorch import surrogate
    spike_grad = surrogate.fast_sigmoid(slope=args.surrogate_slope)

    if args.model == "feedforward":
        model = create_snn(
            model_type="feedforward",
            input_dim=D,
            hidden_dims=args.hidden,
            num_classes=C,
            beta=args.beta,
            threshold=args.threshold,
            spike_grad=spike_grad,
            dropout=args.dropout,
        )
    else:  # recurrent
        model = create_snn(
            model_type="recurrent",
            input_dim=D,
            hidden_dims=args.hidden,
            num_classes=C,
            alpha=args.alpha,
            beta=args.beta,
            threshold=args.threshold,
            spike_grad=spike_grad,
        )
    model = model.to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {num_params:,}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    ce_loss = nn.CrossEntropyLoss()

    print(
        f"\nTraining on {device} | lyapunov={args.lyapunov} "
        f"λ_reg={args.lambda_reg} λ_target={args.lambda_target}"
    )
    print(f"Output: {run_dir}\n")

    best_val_acc = 0.0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc, tr_lyap = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            device=device,
            ce_loss=ce_loss,
            lyapunov=args.lyapunov,
            lambda_reg=args.lambda_reg,
            lambda_target=args.lambda_target,
            lyap_eps=args.lyap_eps,
            lyap_layers=args.lyap_layers,
            progress=(not args.no_progress),
            run_dir=run_dir,
            epoch_idx=epoch,
            log_batches=args.log_batches,
        )
        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            device=device,
            ce_loss=ce_loss,
            progress=(not args.no_progress),
        )
        dt = time.time() - t0

        lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
        print(
            f"[Epoch {epoch:3d}] "
            f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f}{lyap_str} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} ({dt:.1f}s)"
        )

        _append_metrics(
            os.path.join(run_dir, "metrics.csv"),
            {
                "step": "epoch",
                "epoch": int(epoch),
                "batch": -1,
                "loss": float(tr_loss),
                "acc": float(tr_acc),
                "val_loss": float(val_loss),
                "val_acc": float(val_acc),
                "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
                "time_sec": float(dt),
            },
        )

        # Save best model
        if args.save_model and val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt"))

    print(f"\nTraining complete. Best val_acc: {best_val_acc:.3f}")
    if args.save_model:
        torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt"))
        print(f"Model saved to {run_dir}")


if __name__ == "__main__":
    main()
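# Loading a saved checkpoint later (illustrative sketch only; the create_snn
# arguments must match the training run, "..." elides the remaining kwargs,
# and <timestamp> stands for the run directory created above):
#
#   model = create_snn(model_type="feedforward", input_dim=D,
#                      hidden_dims=[256, 128], num_classes=20, ...)
#   model.load_state_dict(torch.load("runs/snntorch/<timestamp>/best_model.pt"))
#   model.eval()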