Diffstat (limited to 'files/train_mvp.py')
-rw-r--r--  files/train_mvp.py  202
1 file changed, 202 insertions, 0 deletions
diff --git a/files/train_mvp.py b/files/train_mvp.py
new file mode 100644
index 0000000..b89ddc6
--- /dev/null
+++ b/files/train_mvp.py
@@ -0,0 +1,202 @@
+import os
+import sys
+import json
+import csv
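+# Make the repository root importable so the "files.*" absolute imports below
+# resolve when this script is run directly (e.g. `python files/train_mvp.py`).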
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(_HERE)
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+import argparse
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from files.data_io.dataset_loader import get_dataloader, SHDDataset
+from files.models.snn import SimpleSNN
+from tqdm.auto import tqdm
+
+
+def _prepare_run_dir(base_dir: str):
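+    """Create a timestamped subdirectory under base_dir and return its path."""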
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    run_dir = os.path.join(base_dir, ts)
+    os.makedirs(run_dir, exist_ok=True)
+    return run_dir
+
+
+def _append_metrics(csv_path: str, row: dict):
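+    """Append row to the CSV at csv_path, writing a header if the file is new."""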
+    write_header = not os.path.exists(csv_path)
+    with open(csv_path, "a", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=row.keys())
+        if write_header:
+            writer.writeheader()
+        writer.writerow(row)
+
+
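+# Example invocation (all flags are defined in parse_args below; the config
+# path defaults to data_io/configs/shd.yaml):
+#   python files/train_mvp.py --epochs 2 --lyapunov --lambda_reg 0.1 --log_batches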
+def parse_args():
+    p = argparse.ArgumentParser(description="MVP training: baseline vs Lyapunov-regularized")
+    p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="YAML config for dataloader")
+    p.add_argument("--epochs", type=int, default=2)
+    p.add_argument("--hidden", type=int, default=256)
+    p.add_argument("--classes", type=int, default=20, help="Number of classes")
+    p.add_argument("--lr", type=float, default=1e-3)
+    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization")
+    p.add_argument("--lambda_reg", type=float, default=0.1, help="Weight for Lyapunov penalty")
+    p.add_argument("--lambda_target", type=float, default=0.0, help="Target average log growth (≈0 for neutral)")
+    p.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar")
+    p.add_argument("--out_dir", type=str, default="runs/mvp", help="Directory to save metrics/checkpoints")
+    p.add_argument("--log_batches", action="store_true", help="Also log per-batch metrics to CSV")
+    # Model dynamics and recurrence controls
+    p.add_argument("--spike_alpha", type=float, default=5.0, help="Surrogate spike sharpness")
+    p.add_argument("--decay", type=float, default=0.95, help="Membrane decay")
+    p.add_argument("--v_threshold", type=float, default=1.0, help="Firing threshold")
+    p.add_argument("--rec_strength", type=float, default=0.0, help="Recurrent coupling strength on spikes")
+    p.add_argument("--rec_init_scale", type=float, default=1.0, help="Gain for recurrent weight init")
+    # Lyapunov measurement controls
+    p.add_argument("--lyap_measure", type=str, default="v", choices=["v", "s"], help="Measure divergence on 'v' or 's'")
+    p.add_argument("--lyap_eps", type=float, default=1e-3, help="Initial perturbation magnitude")
+    return p.parse_args()
+
+
+def train_one_epoch(
+    model: SimpleSNN,
+    loader,
+    optimizer,
+    device,
+    ce_loss: nn.Module,
+    lyapunov: bool,
+    lambda_reg: float,
+    lambda_target: float,
+    progress: bool,
+    run_dir: Optional[str] = None,
+    epoch_idx: Optional[int] = None,
+    log_batches: bool = False,
+    lyap_measure: str = "v",
+    lyap_eps: float = 1e-3,
+):
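+    """Run one training epoch; returns (avg_loss, accuracy, avg_lyap or None)."""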
+    model.train()
+    total = 0
+    correct = 0
+    running_loss = 0.0
+    lyap_vals = []
+
+    iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
+    for bidx, (x, y) in enumerate(iterator):
+        bt0 = time.time()  # per-batch wall-clock timer for the CSV log
+        x = x.to(device)  # (B, T, D)
+        y = y.to(device)
+        optimizer.zero_grad(set_to_none=True)
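+        # Forward pass; with compute_lyapunov=True the model is expected to also
+        # return a differentiable scalar estimate of the average log divergence rate.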
+        logits, lyap_est = model(
+            x,
+            compute_lyapunov=lyapunov,
+            lyap_eps=lyap_eps,
+            lyap_measure=lyap_measure,
+        )
+        ce = ce_loss(logits, y)
+        if lyapunov and lyap_est is not None:
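+            # Quadratic penalty pulling the Lyapunov estimate toward lambda_target
+            # (≈0 targets neutral growth, per the --lambda_target help text).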
+            reg = (lyap_est - lambda_target) ** 2
+            loss = ce + lambda_reg * reg
+            lyap_vals.append(lyap_est.detach().item())
+        else:
+            loss = ce
+        loss.backward()
+        optimizer.step()
+
+        running_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        batch_correct = (preds == y).sum().item()
+        correct += batch_correct
+        total += x.size(0)
+        if log_batches and run_dir is not None and epoch_idx is not None:
+            _append_metrics(
+                os.path.join(run_dir, "metrics.csv"),
+                {
+                    "step": "batch",
+                    "epoch": int(epoch_idx),
+                    "batch": int(bidx),
+                    "loss": float(loss.item()),
+                    "acc": float(batch_correct / max(x.size(0), 1)),
+                    "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
+                    "time_sec": float(time.time() - bt0),
+                },
+            )
+        if progress:
+            avg_loss = running_loss / max(total, 1)
+            avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
+            postfix = {"loss": f"{avg_loss:.4f}"}
+            if avg_lyap is not None:
+                postfix["lyap"] = f"{avg_lyap:.4f}"
+            iterator.set_postfix(postfix)
+
+    avg_loss = running_loss / max(total, 1)
+    acc = correct / max(total, 1)
+    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+    return avg_loss, acc, avg_lyap
+
+
+def main():
+    args = parse_args()
+    device = torch.device(args.device)
+
+    # Prepare output directory and save run config
+    run_dir = _prepare_run_dir(args.out_dir)
+    with open(os.path.join(run_dir, "args.json"), "w") as f:
+        json.dump(vars(args), f, indent=2)
+
+    train_loader, _val_loader = get_dataloader(args.cfg)  # validation loop not implemented in this MVP
+
+    # Infer input dim from a data sample; class count comes from args
+    xb, _ = next(iter(train_loader))
+    _, _, D = xb.shape  # (batch, time steps, input features)
+    C = args.classes
+
+    model = SimpleSNN(
+        input_dim=D,
+        hidden_dim=args.hidden,
+        num_classes=C,
+        v_threshold=args.v_threshold,
+        decay=args.decay,
+        spike_alpha=args.spike_alpha,
+        rec_strength=args.rec_strength,
+        rec_init_scale=args.rec_init_scale,
+    ).to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    ce_loss = nn.CrossEntropyLoss()
+
+    print(f"Starting training on {device} | lyapunov={args.lyapunov} lambda={args.lambda_reg} target={args.lambda_target}")
+    print(f"Saving run to: {run_dir}")
+    for epoch in range(1, args.epochs + 1):
+        t0 = time.time()
+        tr_loss, tr_acc, tr_lyap = train_one_epoch(
+            model, train_loader, optimizer, device, ce_loss,
+            lyapunov=args.lyapunov, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target,
+            progress=(not args.no_progress),
+            run_dir=run_dir,
+            epoch_idx=epoch,
+            log_batches=args.log_batches,
+            lyap_measure=args.lyap_measure,
+            lyap_eps=args.lyap_eps,
+        )
+        dt = time.time() - t0
+        lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
+        print(f"[Epoch {epoch}] loss={tr_loss:.4f} acc={tr_acc:.3f}{lyap_str} ({dt:.1f}s)")
+        _append_metrics(
+            os.path.join(run_dir, "metrics.csv"),
+            {
+                "step": "epoch",
+                "epoch": int(epoch),
+                "batch": int(-1),
+                "loss": float(tr_loss),
+                "acc": float(tr_acc),
+                "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
+                "time_sec": float(dt),
+            },
+        )
+
+
+if __name__ == "__main__":
+    main()
+
+