Diffstat (limited to 'files/train_mvp.py')
| -rw-r--r-- | files/train_mvp.py | 202 |
1 file changed, 202 insertions, 0 deletions
diff --git a/files/train_mvp.py b/files/train_mvp.py
new file mode 100644
index 0000000..b89ddc6
--- /dev/null
+++ b/files/train_mvp.py
@@ -0,0 +1,202 @@
+import os
+import sys
+import json
+import csv
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(_HERE)
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+import argparse
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from files.data_io.dataset_loader import get_dataloader, SHDDataset
+from files.models.snn import SimpleSNN
+from tqdm.auto import tqdm
+
+
+def _prepare_run_dir(base_dir: str):
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    run_dir = os.path.join(base_dir, ts)
+    os.makedirs(run_dir, exist_ok=True)
+    return run_dir
+
+
+def _append_metrics(csv_path: str, row: dict):
+    write_header = not os.path.exists(csv_path)
+    with open(csv_path, "a", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=row.keys())
+        if write_header:
+            writer.writeheader()
+        writer.writerow(row)
+
+
+def parse_args():
+    p = argparse.ArgumentParser(description="MVP training: baseline vs Lyapunov-regularized")
+    p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="YAML config for dataloader")
+    p.add_argument("--epochs", type=int, default=2)
+    p.add_argument("--hidden", type=int, default=256)
+    p.add_argument("--classes", type=int, default=20, help="Number of classes")
+    p.add_argument("--lr", type=float, default=1e-3)
+    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization")
+    p.add_argument("--lambda_reg", type=float, default=0.1, help="Weight for Lyapunov penalty")
+    p.add_argument("--lambda_target", type=float, default=0.0, help="Target average log growth (≈0 for neutral)")
+    p.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar")
+    p.add_argument("--out_dir", type=str, default="runs/mvp", help="Directory to save metrics/checkpoints")
+    p.add_argument("--log_batches", action="store_true", help="Also log per-batch metrics to CSV")
+    # Model dynamics and recurrence controls
+    p.add_argument("--spike_alpha", type=float, default=5.0, help="Surrogate spike sharpness")
+    p.add_argument("--decay", type=float, default=0.95, help="Membrane decay")
+    p.add_argument("--v_threshold", type=float, default=1.0, help="Firing threshold")
+    p.add_argument("--rec_strength", type=float, default=0.0, help="Recurrent coupling strength on spikes")
+    p.add_argument("--rec_init_scale", type=float, default=1.0, help="Gain for recurrent weight init")
+    # Lyapunov measurement controls
+    p.add_argument("--lyap_measure", type=str, default="v", choices=["v", "s"], help="Measure divergence on 'v' or 's'")
+    p.add_argument("--lyap_eps", type=float, default=1e-3, help="Initial perturbation magnitude")
+    return p.parse_args()
+
+
+def train_one_epoch(
+    model: SimpleSNN,
+    loader,
+    optimizer,
+    device,
+    ce_loss: nn.Module,
+    lyapunov: bool,
+    lambda_reg: float,
+    lambda_target: float,
+    progress: bool,
+    run_dir: Optional[str] = None,
+    epoch_idx: Optional[int] = None,
+    log_batches: bool = False,
+    lyap_measure: str = "v",
+    lyap_eps: float = 1e-3,
+):
+    model.train()
+    total = 0
+    correct = 0
+    running_loss = 0.0
+    lyap_vals = []
+
+    iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
+    for bidx, (x, y) in enumerate(iterator):
+        x = x.to(device)  # (B, T, D)
+        y = y.to(device)
+
+        optimizer.zero_grad(set_to_none=True)
+        logits, lyap_est = model(
+            x,
+            compute_lyapunov=lyapunov,
+            lyap_eps=lyap_eps,
+            lyap_measure=lyap_measure,
+        )
+        ce = ce_loss(logits, y)
+        if lyapunov and lyap_est is not None:
+            reg = (lyap_est - lambda_target) ** 2
+            loss = ce + lambda_reg * reg
+            lyap_vals.append(lyap_est.detach().item())
+        else:
+            loss = ce
+        loss.backward()
+        optimizer.step()
+
+        running_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        batch_correct = (preds == y).sum().item()
+        correct += batch_correct
+        total += x.size(0)
+        if log_batches and run_dir is not None and epoch_idx is not None:
+            _append_metrics(
+                os.path.join(run_dir, "metrics.csv"),
+                {
+                    "step": "batch",
+                    "epoch": int(epoch_idx),
+                    "batch": int(bidx),
+                    "loss": float(loss.item()),
+                    "acc": float(batch_correct / max(x.size(0), 1)),
+                    "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
+                    "time_sec": float(0.0),
+                },
+            )
+        if progress:
+            avg_loss = running_loss / max(total, 1)
+            avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
+            postfix = {"loss": f"{avg_loss:.4f}"}
+            if avg_lyap is not None:
+                postfix["lyap"] = f"{avg_lyap:.4f}"
+            iterator.set_postfix(postfix)
+
+    avg_loss = running_loss / max(total, 1)
+    acc = correct / max(total, 1)
+    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+    return avg_loss, acc, avg_lyap
+
+
+def main():
+    args = parse_args()
+    device = torch.device(args.device)
+
+    # Prepare output directory and save run config
+    run_dir = _prepare_run_dir(args.out_dir)
+    with open(os.path.join(run_dir, "args.json"), "w") as f:
+        json.dump(vars(args), f, indent=2)
+
+    train_loader, val_loader = get_dataloader(args.cfg)
+
+    # Infer input dim and classes from a sample and args
+    xb, yb = next(iter(train_loader))
+    _, T, D = xb.shape
+    C = args.classes
+
+    model = SimpleSNN(
+        input_dim=D,
+        hidden_dim=args.hidden,
+        num_classes=C,
+        v_threshold=args.v_threshold,
+        decay=args.decay,
+        spike_alpha=args.spike_alpha,
+        rec_strength=args.rec_strength,
+        rec_init_scale=args.rec_init_scale,
+    ).to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    ce_loss = nn.CrossEntropyLoss()
+
+    print(f"Starting training on {device} | lyapunov={args.lyapunov} lambda={args.lambda_reg} target={args.lambda_target}")
+    print(f"Saving run to: {run_dir}")
+    for epoch in range(1, args.epochs + 1):
+        t0 = time.time()
+        tr_loss, tr_acc, tr_lyap = train_one_epoch(
+            model, train_loader, optimizer, device, ce_loss,
+            lyapunov=args.lyapunov, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target,
+            progress=(not args.no_progress),
+            run_dir=run_dir,
+            epoch_idx=epoch,
+            log_batches=args.log_batches,
+            lyap_measure=args.lyap_measure,
+            lyap_eps=args.lyap_eps,
+        )
+        dt = time.time() - t0
+        lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
+        print(f"[Epoch {epoch}] loss={tr_loss:.4f} acc={tr_acc:.3f}{lyap_str} ({dt:.1f}s)")
+        _append_metrics(
+            os.path.join(run_dir, "metrics.csv"),
+            {
+                "step": "epoch",
+                "epoch": int(epoch),
+                "batch": int(-1),
+                "loss": float(tr_loss),
+                "acc": float(tr_acc),
+                "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
+                "time_sec": float(dt),
+            },
+        )
+
+
+if __name__ == "__main__":
+    main()
+
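Note: SimpleSNN lives in files/models/snn.py and is not part of this diff, so the Lyapunov path is visible only from its call site (compute_lyapunov, lyap_eps, lyap_measure). For orientation, below is a minimal sketch of the standard finite-perturbation (Benettin-style) recipe such an estimate typically follows; estimate_lyapunov and toy_step are hypothetical names for illustration, not the actual SimpleSNN API, and the real model presumably keeps the estimate differentiable via its surrogate gradient (cf. --spike_alpha) so the penalty can be backpropagated.

    # Hypothetical sketch -- NOT the actual files/models/snn.py implementation.
    # Benettin-style finite-perturbation estimate of the largest Lyapunov
    # exponent of a discrete-time state update, mirroring the script's
    # --lyap_eps (initial perturbation) and --lyap_measure ('v' or 's') knobs.
    import torch

    def estimate_lyapunov(step_fn, v0, eps=1e-3, n_steps=100, measure="v"):
        """step_fn maps a state of shape (B, H) to (next_state, spikes)."""
        v_a = v0.clone()
        v_b = v0 + eps * torch.randn_like(v0)  # perturbed twin trajectory
        log_growth = v0.new_zeros(())
        for _ in range(n_steps):
            v_a, s_a = step_fn(v_a)
            v_b, s_b = step_fn(v_b)
            dv = (v_b - v_a).norm() + 1e-12        # membrane-state separation
            ds = (s_b - s_a).norm() + 1e-12        # spike-train separation
            dist = dv if measure == "v" else ds
            log_growth = log_growth + torch.log(dist / eps)
            v_b = v_a + (v_b - v_a) * (eps / dv)   # renormalize to |delta| = eps
        return log_growth / n_steps                # average log growth per step

    # Toy usage: a leaky map with a hard threshold and weak recurrence.
    W = 0.1 * torch.randn(64, 64)

    def toy_step(v):
        s = (v > 1.0).float()                      # fire where v crosses threshold
        v_next = 0.95 * v * (1.0 - s) + s @ W      # decay, reset, recurrent input
        return v_next, s

    lam = estimate_lyapunov(toy_step, torch.full((8, 64), 0.5))
    print(f"estimated lyapunov exponent: {lam.item():.4f}")  # negative => contracting

Under those assumptions, the intended A/B comparison would be run as, e.g., "python files/train_mvp.py --epochs 2" for the baseline versus "python files/train_mvp.py --epochs 2 --lyapunov --lambda_reg 0.1 --lambda_target 0.0" for the regularized variant; each run writes args.json and metrics.csv into a fresh timestamped directory under runs/mvp.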

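Each run directory thus ends up holding args.json plus a metrics.csv with the columns written above (step, epoch, batch, loss, acc, lyap, time_sec). A small sketch for comparing two finished runs, assuming pandas is installed; the run-directory paths are placeholders, not part of this diff:

    # Sketch: compare epoch-level metrics from two train_mvp.py runs.
    import pandas as pd

    def load_epochs(run_dir: str) -> pd.DataFrame:
        df = pd.read_csv(f"{run_dir}/metrics.csv")
        # Keep only epoch rows; per-batch rows exist when --log_batches was set.
        return df[df["step"] == "epoch"]

    baseline = load_epochs("runs/mvp/20250101-120000")  # plain run (lyap column is NaN)
    regular = load_epochs("runs/mvp/20250101-130000")   # --lyapunov run

    for name, df in [("baseline", baseline), ("lyapunov", regular)]:
        last = df.iloc[-1]
        print(f"{name}: acc={last['acc']:.3f} lyap={last['lyap']:.4f} loss={last['loss']:.4f}")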