Diffstat (limited to 'files/train_snntorch.py')
| -rw-r--r-- | files/train_snntorch.py | 345 |
1 file changed, 345 insertions, 0 deletions
diff --git a/files/train_snntorch.py b/files/train_snntorch.py
new file mode 100644
index 0000000..dbdce37
--- /dev/null
+++ b/files/train_snntorch.py
@@ -0,0 +1,345 @@
"""
Training script for snnTorch-based deep SNNs with Lyapunov regularization.

Usage:
    # Baseline (no Lyapunov)
    python files/train_snntorch.py --hidden 256 128 --epochs 10

    # With Lyapunov regularization
    python files/train_snntorch.py --hidden 256 128 --epochs 10 --lyapunov --lambda_reg 0.1

    # Recurrent model
    python files/train_snntorch.py --model recurrent --hidden 256 --epochs 10 --lyapunov
"""

import os
import sys
import json
import csv

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(_HERE)
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import time
from typing import List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from files.data_io.dataset_loader import get_dataloader
from files.models.snn_snntorch import create_snn


def _prepare_run_dir(base_dir: str) -> str:
    ts = time.strftime("%Y%m%d-%H%M%S")
    run_dir = os.path.join(base_dir, ts)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _append_metrics(csv_path: str, row: dict):
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=row.keys())
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def parse_args():
    p = argparse.ArgumentParser(
        description="Train deep SNN with snnTorch and optional Lyapunov regularization"
    )

    # Model architecture
    p.add_argument(
        "--model", type=str, default="feedforward", choices=["feedforward", "recurrent"],
        help="Model type: 'feedforward' (LIF) or 'recurrent' (RSynaptic)"
    )
    p.add_argument(
        "--hidden", type=int, nargs="+", default=[256],
        help="Hidden layer sizes (e.g., --hidden 256 128 for 2 layers)"
    )
    p.add_argument("--classes", type=int, default=20, help="Number of output classes")
    p.add_argument("--beta", type=float, default=0.9, help="Membrane decay (beta)")
    p.add_argument("--threshold", type=float, default=1.0, help="Firing threshold")
    p.add_argument("--dropout", type=float, default=0.0, help="Dropout between layers")
    p.add_argument(
        "--surrogate_slope", type=float, default=25.0,
        help="Slope for fast_sigmoid surrogate gradient"
    )

    # Recurrent-specific (only for --model recurrent)
    p.add_argument("--alpha", type=float, default=0.9, help="Synaptic current decay (recurrent only)")

    # Training
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--weight_decay", type=float, default=0.0, help="L2 regularization")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")

    # Lyapunov regularization
    p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization")
    p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov penalty weight")
    p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
    p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation magnitude")
    p.add_argument(
        "--lyap_layers", type=int, nargs="*", default=None,
        help="Which layers to measure (default: all). E.g., --lyap_layers 0 1"
    )

    # Output
    p.add_argument("--out_dir", type=str, default="runs/snntorch", help="Output directory")
    p.add_argument("--log_batches", action="store_true", help="Log per-batch metrics")
    p.add_argument("--no-progress", action="store_true", help="Disable progress bar")
    p.add_argument("--save_model", action="store_true", help="Save model checkpoint")

    return p.parse_args()


def train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: optim.Optimizer,
    device: torch.device,
    ce_loss: nn.Module,
    lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    lyap_layers: Optional[List[int]],
    progress: bool,
    run_dir: Optional[str] = None,
    epoch_idx: Optional[int] = None,
    log_batches: bool = False,
):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    lyap_vals = []

    iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader

    for bidx, (x, y) in enumerate(iterator):
        x = x.to(device)  # (B, T, D)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)

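        # Assumed model contract (implemented in files/models/snn_snntorch):
        # the forward pass returns (logits, lyap_est). When compute_lyapunov
        # is True, lyap_est is a scalar tensor estimating the largest
        # Lyapunov exponent of the selected layers' state dynamics from an
        # eps-sized perturbation; otherwise it is None.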
        logits, lyap_est = model(
            x,
            compute_lyapunov=lyapunov,
            lyap_eps=lyap_eps,
            lyap_layers=lyap_layers,
        )

        ce = ce_loss(logits, y)

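        # With regularization enabled the objective becomes
        #   loss = CE + lambda_reg * (lyap_est - lambda_target)**2,
        # a soft penalty pulling the estimated exponent toward lambda_target
        # (0.0 by default, i.e. perturbations that neither grow nor decay).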
        if lyapunov and lyap_est is not None:
            # Penalize deviation from target Lyapunov exponent
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.detach().item())
        else:
            loss = ce

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        batch_correct = (preds == y).sum().item()
        correct += batch_correct
        total += x.size(0)

        if log_batches and run_dir is not None and epoch_idx is not None:
            _append_metrics(
                os.path.join(run_dir, "metrics.csv"),
                {
                    "step": "batch",
                    "epoch": int(epoch_idx),
                    "batch": int(bidx),
                    "loss": float(loss.item()),
                    "acc": float(batch_correct / max(x.size(0), 1)),
                    "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
                    "time_sec": 0.0,
                },
            )

        if progress:
            avg_loss = running_loss / max(total, 1)
            avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
            postfix = {"loss": f"{avg_loss:.4f}", "acc": f"{correct / total:.3f}"}
            if avg_lyap is not None:
                postfix["lyap"] = f"{avg_lyap:.4f}"
            iterator.set_postfix(postfix)

    avg_loss = running_loss / max(total, 1)
    acc = correct / max(total, 1)
    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
    return avg_loss, acc, avg_lyap


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader,
    device: torch.device,
    ce_loss: nn.Module,
    progress: bool,
):
    model.eval()
    total = 0
    correct = 0
    running_loss = 0.0

    iterator = tqdm(loader, desc="eval", leave=False, dynamic_ncols=True) if progress else loader

    for x, y in iterator:
        x = x.to(device)
        y = y.to(device)

        logits, _ = model(x, compute_lyapunov=False)
        loss = ce_loss(logits, y)

        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)

    avg_loss = running_loss / max(total, 1)
    acc = correct / max(total, 1)
    return avg_loss, acc


def main():
    args = parse_args()
    device = torch.device(args.device)

    # Prepare output directory
    run_dir = _prepare_run_dir(args.out_dir)
    with open(os.path.join(run_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=2)

    # Load data
    train_loader, val_loader = get_dataloader(args.cfg)

    # Infer dimensions from data
    xb, yb = next(iter(train_loader))
    _, T, D = xb.shape
    C = args.classes

    print(f"Data: T={T}, D={D}, classes={C}")
    print(f"Model: {args.model}, hidden={args.hidden}")

    # Create model
    from snntorch import surrogate
    spike_grad = surrogate.fast_sigmoid(slope=args.surrogate_slope)

    if args.model == "feedforward":
        model = create_snn(
            model_type="feedforward",
            input_dim=D,
            hidden_dims=args.hidden,
            num_classes=C,
            beta=args.beta,
            threshold=args.threshold,
            spike_grad=spike_grad,
            dropout=args.dropout,
        )
    else:  # recurrent
        model = create_snn(
            model_type="recurrent",
            input_dim=D,
            hidden_dims=args.hidden,
            num_classes=C,
            alpha=args.alpha,
            beta=args.beta,
            threshold=args.threshold,
            spike_grad=spike_grad,
        )

    model = model.to(device)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {num_params:,}")

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    ce_loss = nn.CrossEntropyLoss()

    print(f"\nTraining on {device} | lyapunov={args.lyapunov} λ_reg={args.lambda_reg} λ_target={args.lambda_target}")
    print(f"Output: {run_dir}\n")

    best_val_acc = 0.0

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()

        tr_loss, tr_acc, tr_lyap = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            device=device,
            ce_loss=ce_loss,
            lyapunov=args.lyapunov,
            lambda_reg=args.lambda_reg,
            lambda_target=args.lambda_target,
            lyap_eps=args.lyap_eps,
            lyap_layers=args.lyap_layers,
            progress=(not args.no_progress),
            run_dir=run_dir,
            epoch_idx=epoch,
            log_batches=args.log_batches,
        )

        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            device=device,
            ce_loss=ce_loss,
            progress=(not args.no_progress),
        )

        dt = time.time() - t0
        lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""

        print(
            f"[Epoch {epoch:3d}] "
            f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f}{lyap_str} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} ({dt:.1f}s)"
        )

        _append_metrics(
            os.path.join(run_dir, "metrics.csv"),
            {
                "step": "epoch",
                "epoch": int(epoch),
                "batch": -1,
                "loss": float(tr_loss),
                "acc": float(tr_acc),
                "val_loss": float(val_loss),
                "val_acc": float(val_acc),
                "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
                "time_sec": float(dt),
            },
        )

        # Save best model
        if args.save_model and val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt"))

    print(f"\nTraining complete. Best val_acc: {best_val_acc:.3f}")
    if args.save_model:
        torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt"))
        print(f"Model saved to {run_dir}")


if __name__ == "__main__":
    main()
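
# ---------------------------------------------------------------------------
# Editor's sketch, appended for reference and NOT part of the committed model
# code: one common way to produce `lyap_est` under compute_lyapunov=True is a
# Benettin-style finite-perturbation estimate. `step` and `state` below are
# hypothetical stand-ins for the real layer dynamics in
# files/models/snn_snntorch, whose actual estimator may differ.


def _lyapunov_sketch(step, state: torch.Tensor, eps: float = 1e-4, T: int = 100) -> torch.Tensor:
    """Estimate the largest Lyapunov exponent of the map state -> step(state)."""
    ref = state
    pert = state + eps * torch.randn_like(state)  # eps-perturbed twin trajectory
    log_growth = torch.zeros((), device=state.device)
    for _ in range(T):
        ref = step(ref)
        pert = step(pert)
        d = torch.norm(pert - ref) + 1e-12  # separation after one step
        log_growth = log_growth + torch.log(d / eps)
        # Renormalize the twin back to distance eps along the current
        # separation direction so the estimate does not saturate.
        pert = ref + (pert - ref) * (eps / d)
    return log_growth / T  # mean log-growth rate per time step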
