Diffstat (limited to 'files/train_snntorch.py')
-rw-r--r--  files/train_snntorch.py  345
1 file changed, 345 insertions(+), 0 deletions(-)
diff --git a/files/train_snntorch.py b/files/train_snntorch.py
new file mode 100644
index 0000000..dbdce37
--- /dev/null
+++ b/files/train_snntorch.py
@@ -0,0 +1,345 @@
+"""
+Training script for snnTorch-based deep SNNs with Lyapunov regularization.
+
+Usage:
+ # Baseline (no Lyapunov)
+ python files/train_snntorch.py --hidden 256 128 --epochs 10
+
+ # With Lyapunov regularization
+ python files/train_snntorch.py --hidden 256 128 --epochs 10 --lyapunov --lambda_reg 0.1
+
+ # Recurrent model
+ python files/train_snntorch.py --model recurrent --hidden 256 --epochs 10 --lyapunov
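+
+Outputs:
+    Each run writes args.json and metrics.csv (plus optional checkpoints)
+    to a timestamped subdirectory of --out_dir (default: runs/snntorch).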
+"""
+
+import os
+import sys
+import json
+import csv
+
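+# Make the repository root importable so the script can be run directly
+# (e.g. `python files/train_snntorch.py`) without installing the package.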
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(_HERE)
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import time
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm.auto import tqdm
+
+from files.data_io.dataset_loader import get_dataloader
+from files.models.snn_snntorch import create_snn
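+
+# Expected model contract (provided by create_snn): forward(x, compute_lyapunov=...,
+# lyap_eps=..., lyap_layers=...) returns (logits, lyap_est), where lyap_est is a
+# differentiable scalar estimate of the Lyapunov exponent, or None when disabled.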
+
+
+def _prepare_run_dir(base_dir: str) -> str:
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ run_dir = os.path.join(base_dir, ts)
+ os.makedirs(run_dir, exist_ok=True)
+ return run_dir
+
+
+# Batch and epoch rows carry different keys; pin a fixed column order rather
+# than deriving it from whichever row happens to be written first.
+_METRIC_FIELDS = ["step", "epoch", "batch", "loss", "acc",
+                  "val_loss", "val_acc", "lyap", "time_sec"]
+
+
+def _append_metrics(csv_path: str, row: dict):
+    write_header = not os.path.exists(csv_path)
+    with open(csv_path, "a", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=_METRIC_FIELDS, restval="")
+        if write_header:
+            writer.writeheader()
+        writer.writerow(row)
+
+
+def parse_args():
+ p = argparse.ArgumentParser(
+ description="Train deep SNN with snnTorch and optional Lyapunov regularization"
+ )
+
+ # Model architecture
+ p.add_argument(
+ "--model", type=str, default="feedforward", choices=["feedforward", "recurrent"],
+ help="Model type: 'feedforward' (LIF) or 'recurrent' (RSynaptic)"
+ )
+ p.add_argument(
+ "--hidden", type=int, nargs="+", default=[256],
+ help="Hidden layer sizes (e.g., --hidden 256 128 for 2 layers)"
+ )
+ p.add_argument("--classes", type=int, default=20, help="Number of output classes")
+ p.add_argument("--beta", type=float, default=0.9, help="Membrane decay (beta)")
+ p.add_argument("--threshold", type=float, default=1.0, help="Firing threshold")
+ p.add_argument("--dropout", type=float, default=0.0, help="Dropout between layers")
+ p.add_argument(
+ "--surrogate_slope", type=float, default=25.0,
+ help="Slope for fast_sigmoid surrogate gradient"
+ )
+
+ # Recurrent-specific (only for --model recurrent)
+ p.add_argument("--alpha", type=float, default=0.9, help="Synaptic current decay (recurrent only)")
+
+ # Training
+ p.add_argument("--epochs", type=int, default=10)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--weight_decay", type=float, default=0.0, help="L2 regularization")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
+
+ # Lyapunov regularization
+ p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov penalty weight")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
+ p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation magnitude")
+ p.add_argument(
+ "--lyap_layers", type=int, nargs="*", default=None,
+ help="Which layers to measure (default: all). E.g., --lyap_layers 0 1"
+ )
+
+ # Output
+ p.add_argument("--out_dir", type=str, default="runs/snntorch", help="Output directory")
+ p.add_argument("--log_batches", action="store_true", help="Log per-batch metrics")
+ p.add_argument("--no-progress", action="store_true", help="Disable progress bar")
+ p.add_argument("--save_model", action="store_true", help="Save model checkpoint")
+
+ return p.parse_args()
+
+
+def train_one_epoch(
+ model: nn.Module,
+ loader,
+ optimizer: optim.Optimizer,
+ device: torch.device,
+ ce_loss: nn.Module,
+ lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ lyap_layers: Optional[List[int]],
+ progress: bool,
+ run_dir: Optional[str] = None,
+ epoch_idx: Optional[int] = None,
+ log_batches: bool = False,
+):
+ model.train()
+ total = 0
+ correct = 0
+ running_loss = 0.0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
+
+ for bidx, (x, y) in enumerate(iterator):
+ x = x.to(device) # (B, T, D)
+ y = y.to(device)
+
+ optimizer.zero_grad(set_to_none=True)
+
+ logits, lyap_est = model(
+ x,
+ compute_lyapunov=lyapunov,
+ lyap_eps=lyap_eps,
+ lyap_layers=lyap_layers,
+ )
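+        # lyap_est stays in the autograd graph so the regularization term
+        # below can backpropagate through the Lyapunov estimate.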
+
+ ce = ce_loss(logits, y)
+
+ if lyapunov and lyap_est is not None:
+ # Penalize deviation from target Lyapunov exponent
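+            # loss = CE + lambda_reg * (lyap_est - lambda_target)^2; with the
+            # default lambda_target=0 this pushes dynamics toward neutral stability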
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.detach().item())
+ else:
+ loss = ce
+
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ batch_correct = (preds == y).sum().item()
+ correct += batch_correct
+ total += x.size(0)
+
+ if log_batches and run_dir is not None and epoch_idx is not None:
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "batch",
+ "epoch": int(epoch_idx),
+ "batch": int(bidx),
+ "loss": float(loss.item()),
+ "acc": float(batch_correct / max(x.size(0), 1)),
+ "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
+ "time_sec": 0.0,
+ },
+ )
+
+ if progress:
+ avg_loss = running_loss / max(total, 1)
+ avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
+ postfix = {"loss": f"{avg_loss:.4f}", "acc": f"{correct / total:.3f}"}
+ if avg_lyap is not None:
+ postfix["lyap"] = f"{avg_lyap:.4f}"
+ iterator.set_postfix(postfix)
+
+ avg_loss = running_loss / max(total, 1)
+ acc = correct / max(total, 1)
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return avg_loss, acc, avg_lyap
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader,
+ device: torch.device,
+ ce_loss: nn.Module,
+ progress: bool,
+):
+ model.eval()
+ total = 0
+ correct = 0
+ running_loss = 0.0
+
+ iterator = tqdm(loader, desc="eval", leave=False, dynamic_ncols=True) if progress else loader
+
+ for x, y in iterator:
+ x = x.to(device)
+ y = y.to(device)
+
+ logits, _ = model(x, compute_lyapunov=False)
+ loss = ce_loss(logits, y)
+
+ running_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+
+ avg_loss = running_loss / max(total, 1)
+ acc = correct / max(total, 1)
+ return avg_loss, acc
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ # Prepare output directory
+ run_dir = _prepare_run_dir(args.out_dir)
+ with open(os.path.join(run_dir, "args.json"), "w") as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Load data
+ train_loader, val_loader = get_dataloader(args.cfg)
+
+ # Infer dimensions from data
+ xb, yb = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = args.classes
+
+ print(f"Data: T={T}, D={D}, classes={C}")
+ print(f"Model: {args.model}, hidden={args.hidden}")
+
+ # Create model
+ from snntorch import surrogate
+ spike_grad = surrogate.fast_sigmoid(slope=args.surrogate_slope)
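+    # fast_sigmoid keeps the Heaviside step on the forward pass and uses
+    # 1 / (slope * |U| + 1)**2 as the spike derivative on the backward pass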
+
+ if args.model == "feedforward":
+ model = create_snn(
+ model_type="feedforward",
+ input_dim=D,
+ hidden_dims=args.hidden,
+ num_classes=C,
+ beta=args.beta,
+ threshold=args.threshold,
+ spike_grad=spike_grad,
+ dropout=args.dropout,
+ )
+ else: # recurrent
+ model = create_snn(
+ model_type="recurrent",
+ input_dim=D,
+ hidden_dims=args.hidden,
+ num_classes=C,
+ alpha=args.alpha,
+ beta=args.beta,
+ threshold=args.threshold,
+ spike_grad=spike_grad,
+ )
+
+ model = model.to(device)
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Parameters: {num_params:,}")
+
+ optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ ce_loss = nn.CrossEntropyLoss()
+
+ print(f"\nTraining on {device} | lyapunov={args.lyapunov} λ_reg={args.lambda_reg} λ_target={args.lambda_target}")
+ print(f"Output: {run_dir}\n")
+
+ best_val_acc = 0.0
+
+ for epoch in range(1, args.epochs + 1):
+ t0 = time.time()
+
+ tr_loss, tr_acc, tr_lyap = train_one_epoch(
+ model=model,
+ loader=train_loader,
+ optimizer=optimizer,
+ device=device,
+ ce_loss=ce_loss,
+ lyapunov=args.lyapunov,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ lyap_layers=args.lyap_layers,
+ progress=(not args.no_progress),
+ run_dir=run_dir,
+ epoch_idx=epoch,
+ log_batches=args.log_batches,
+ )
+
+ val_loss, val_acc = evaluate(
+ model=model,
+ loader=val_loader,
+ device=device,
+ ce_loss=ce_loss,
+ progress=(not args.no_progress),
+ )
+
+ dt = time.time() - t0
+ lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
+
+ print(
+ f"[Epoch {epoch:3d}] "
+ f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f}{lyap_str} | "
+ f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} ({dt:.1f}s)"
+ )
+
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "epoch",
+ "epoch": int(epoch),
+ "batch": -1,
+ "loss": float(tr_loss),
+ "acc": float(tr_acc),
+ "val_loss": float(val_loss),
+ "val_acc": float(val_acc),
+ "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
+ "time_sec": float(dt),
+ },
+ )
+
+        # Track the best validation accuracy (reported at the end); only
+        # checkpoint when --save_model is set
+        if val_acc > best_val_acc:
+            best_val_acc = val_acc
+            if args.save_model:
+                torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt"))
+
+ print(f"\nTraining complete. Best val_acc: {best_val_acc:.3f}")
+ if args.save_model:
+ torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt"))
+ print(f"Model saved to {run_dir}")
+
+
+if __name__ == "__main__":
+ main()