import os
import sys
import json
import csv

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(_HERE)
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import time
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim

from files.data_io.dataset_loader import get_dataloader, SHDDataset
from files.models.snn import SimpleSNN
from tqdm.auto import tqdm


def _prepare_run_dir(base_dir: str):
    ts = time.strftime("%Y%m%d-%H%M%S")
    run_dir = os.path.join(base_dir, ts)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


def _append_metrics(csv_path: str, row: dict):
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=row.keys())
        if write_header:
            writer.writeheader()
        writer.writerow(row)


def parse_args():
    p = argparse.ArgumentParser(description="MVP training: baseline vs Lyapunov-regularized")
    p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="YAML config for dataloader")
    p.add_argument("--epochs", type=int, default=2)
    p.add_argument("--hidden", type=int, default=256)
    p.add_argument("--classes", type=int, default=20, help="Number of classes")
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regulation")
    p.add_argument("--lambda_reg", type=float, default=0.1, help="Weight for Lyapunov penalty")
    p.add_argument("--lambda_target", type=float, default=0.0, help="Target average log growth (≈0 for neutral)")
    p.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar")
    p.add_argument("--out_dir", type=str, default="runs/mvp", help="Directory to save metrics/checkpoints")
    p.add_argument("--log_batches", action="store_true", help="Also log per-batch metrics to CSV")
    # Model dynamics and recurrence controls
    p.add_argument("--spike_alpha", type=float, default=5.0, help="Surrogate spike sharpness")
    p.add_argument("--decay", type=float, default=0.95, help="Membrane decay")
    p.add_argument("--v_threshold", type=float, default=1.0, help="Firing threshold")
    p.add_argument("--rec_strength", type=float, default=0.0, help="Recurrent coupling strength on spikes")
    p.add_argument("--rec_init_scale", type=float, default=1.0, help="Gain for recurrent weight init")
    # Lyapunov measurement controls
    p.add_argument("--lyap_measure", type=str, default="v", choices=["v", "s"], help="Measure divergence on 'v' or 's'")
    p.add_argument("--lyap_eps", type=float, default=1e-3, help="Initial perturbation magnitude")
    return p.parse_args()


def train_one_epoch(
    model: SimpleSNN,
    loader,
    optimizer,
    device,
    ce_loss: nn.Module,
    lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    progress: bool,
    run_dir: str | None = None,
    epoch_idx: int | None = None,
    log_batches: bool = False,
    lyap_measure: str = "v",
    lyap_eps: float = 1e-3,
):
    model.train()
    total = 0
    correct = 0
    running_loss = 0.0
    lyap_vals = []

    iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
    for bidx, (x, y) in enumerate(iterator):
        x = x.to(device)  # (B, T, D)
        y = y.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits, lyap_est = model(
            x,
            compute_lyapunov=lyapunov,
            lyap_eps=lyap_eps,
            lyap_measure=lyap_measure,
        )

        ce = ce_loss(logits, y)
        if lyapunov and lyap_est is not None:
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.detach().item())
        else:
            loss = ce
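        # Note: with --lyapunov enabled, the total loss is CE + lambda_reg * (lyap_est - lambda_target)^2,
        # nudging the model's estimated average log growth toward the target (≈0 for neutral dynamics,
        # per the --lambda_target help text above).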
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        batch_correct = (preds == y).sum().item()
        correct += batch_correct
        total += x.size(0)

        if log_batches and run_dir is not None and epoch_idx is not None:
            _append_metrics(
                os.path.join(run_dir, "metrics.csv"),
                {
                    "step": "batch",
                    "epoch": int(epoch_idx),
                    "batch": int(bidx),
                    "loss": float(loss.item()),
                    "acc": float(batch_correct / max(x.size(0), 1)),
                    "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
                    "time_sec": float(0.0),
                },
            )

        if progress:
            avg_loss = running_loss / max(total, 1)
            avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
            postfix = {"loss": f"{avg_loss:.4f}"}
            if avg_lyap is not None:
                postfix["lyap"] = f"{avg_lyap:.4f}"
            iterator.set_postfix(postfix)

    avg_loss = running_loss / max(total, 1)
    acc = correct / max(total, 1)
    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
    return avg_loss, acc, avg_lyap


def main():
    args = parse_args()
    device = torch.device(args.device)

    # Prepare output directory and save run config
    run_dir = _prepare_run_dir(args.out_dir)
    with open(os.path.join(run_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=2)

    train_loader, val_loader = get_dataloader(args.cfg)

    # Infer input dim and classes from a sample and args
    xb, yb = next(iter(train_loader))
    _, T, D = xb.shape
    C = args.classes

    model = SimpleSNN(
        input_dim=D,
        hidden_dim=args.hidden,
        num_classes=C,
        v_threshold=args.v_threshold,
        decay=args.decay,
        spike_alpha=args.spike_alpha,
        rec_strength=args.rec_strength,
        rec_init_scale=args.rec_init_scale,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    ce_loss = nn.CrossEntropyLoss()

    print(f"Starting training on {device} | lyapunov={args.lyapunov} lambda={args.lambda_reg} target={args.lambda_target}")
    print(f"Saving run to: {run_dir}")

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc, tr_lyap = train_one_epoch(
            model,
            train_loader,
            optimizer,
            device,
            ce_loss,
            lyapunov=args.lyapunov,
            lambda_reg=args.lambda_reg,
            lambda_target=args.lambda_target,
            progress=(not args.no_progress),
            run_dir=run_dir,
            epoch_idx=epoch,
            log_batches=args.log_batches,
            lyap_measure=args.lyap_measure,
            lyap_eps=args.lyap_eps,
        )
        dt = time.time() - t0
        lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
        print(f"[Epoch {epoch}] loss={tr_loss:.4f} acc={tr_acc:.3f}{lyap_str} ({dt:.1f}s)")
        _append_metrics(
            os.path.join(run_dir, "metrics.csv"),
            {
                "step": "epoch",
                "epoch": int(epoch),
                "batch": int(-1),
                "loss": float(tr_loss),
                "acc": float(tr_acc),
                "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
                "time_sec": float(dt),
            },
        )


if __name__ == "__main__":
    main()
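
# Example usage (a minimal sketch; the script filename "train_mvp.py" is assumed, flags match parse_args()):
#
#   # Baseline run on the SHD config
#   python train_mvp.py --cfg data_io/configs/shd.yaml --epochs 2
#
#   # Lyapunov-regularized run with per-batch CSV logging
#   python train_mvp.py --lyapunov --lambda_reg 0.1 --lambda_target 0.0 --log_batches
#
# Each run writes args.json and metrics.csv into a timestamped subdirectory of --out_dir (default runs/mvp).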