""" Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets. Datasets: - Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks - Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory - CIFAR-10: Rate-coded images, requires hierarchical features Usage: python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8 python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10 """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional, Tuple _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from tqdm.auto import tqdm from files.models.snn_snntorch import LyapunovSNN from files.data_io.benchmark_datasets import get_benchmark_dataloader @dataclass class EpochMetrics: epoch: int train_loss: float train_acc: float val_loss: float val_acc: float lyapunov: Optional[float] grad_norm: float grad_max_sv: Optional[float] grad_min_sv: Optional[float] grad_condition: Optional[float] time_sec: float def compute_gradient_svs(model): """Compute gradient singular value statistics.""" max_svs = [] min_svs = [] for name, param in model.named_parameters(): if param.grad is not None and param.ndim == 2: with torch.no_grad(): G = param.grad.detach() try: sv = torch.linalg.svdvals(G) if len(sv) > 0: max_svs.append(sv[0].item()) min_svs.append(sv[-1].item()) except Exception: pass if not max_svs: return None, None, None max_sv = max(max_svs) min_sv = min(min_svs) condition = max_sv / (min_sv + 1e-12) return max_sv, min_sv, condition def create_model( input_dim: int, num_classes: int, depth: int, hidden_dim: int = 128, beta: float = 0.9, ) -> LyapunovSNN: """Create SNN with specified depth.""" hidden_dims = [hidden_dim] * depth return LyapunovSNN( input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes, beta=beta, threshold=1.0, ) def train_epoch( model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, ce_loss: nn.Module, device: torch.device, use_lyapunov: bool, lambda_reg: float, lambda_target: float, lyap_eps: float, compute_sv_every: int = 10, ) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]: """Train one epoch.""" model.train() total_loss = 0.0 total_correct = 0 total_samples = 0 lyap_vals = [] grad_norms = [] grad_max_svs = [] grad_min_svs = [] grad_conditions = [] for batch_idx, (x, y) in enumerate(loader): x, y = x.to(device), y.to(device) # Handle different input shapes if x.ndim == 2: x = x.unsqueeze(-1) # (B, T) -> (B, T, 1) optimizer.zero_grad() logits, lyap_est, _ = model( x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps, record_states=False, ) ce = ce_loss(logits, y) if use_lyapunov and lyap_est is not None: reg = (lyap_est - lambda_target) ** 2 loss = ce + lambda_reg * reg lyap_vals.append(lyap_est.item()) else: loss = ce if torch.isnan(loss): return float('nan'), 0.0, None, float('nan'), None, None, None loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 grad_norms.append(grad_norm) # Compute gradient SVs periodically if batch_idx % compute_sv_every == 0: max_sv, min_sv, cond = compute_gradient_svs(model) if max_sv is not None: grad_max_svs.append(max_sv) grad_min_svs.append(min_sv) grad_conditions.append(cond) optimizer.step() total_loss += loss.item() * x.size(0) preds = logits.argmax(dim=1) total_correct += (preds == y).sum().item() total_samples += x.size(0) avg_loss = total_loss / total_samples avg_acc = total_correct / total_samples avg_lyap = np.mean(lyap_vals) if lyap_vals else None avg_grad = np.mean(grad_norms) avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None avg_cond = np.mean(grad_conditions) if grad_conditions else None return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond @torch.no_grad() def evaluate( model: nn.Module, loader: DataLoader, ce_loss: nn.Module, device: torch.device, ) -> Tuple[float, float]: """Evaluate on validation set.""" model.eval() total_loss = 0.0 total_correct = 0 total_samples = 0 for x, y in loader: x, y = x.to(device), y.to(device) if x.ndim == 2: x = x.unsqueeze(-1) logits, _, _ = model(x, compute_lyapunov=False, record_states=False) loss = ce_loss(logits, y) if torch.isnan(loss): return float('nan'), 0.0 total_loss += loss.item() * x.size(0) preds = logits.argmax(dim=1) total_correct += (preds == y).sum().item() total_samples += x.size(0) return total_loss / total_samples, total_correct / total_samples def run_experiment( depth: int, use_lyapunov: bool, train_loader: DataLoader, val_loader: DataLoader, input_dim: int, num_classes: int, hidden_dim: int, epochs: int, lr: float, lambda_reg: float, lambda_target: float, lyap_eps: float, device: torch.device, seed: int, progress: bool = True, ) -> List[EpochMetrics]: """Run single experiment configuration.""" torch.manual_seed(seed) model = create_model( input_dim=input_dim, num_classes=num_classes, depth=depth, hidden_dim=hidden_dim, ).to(device) optimizer = optim.Adam(model.parameters(), lr=lr) ce_loss = nn.CrossEntropyLoss() method = "Lyapunov" if use_lyapunov else "Vanilla" metrics_history = [] iterator = range(1, epochs + 1) if progress: iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False) for epoch in iterator: t0 = time.time() train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( model, train_loader, optimizer, ce_loss, device, use_lyapunov, lambda_reg, lambda_target, lyap_eps, ) val_loss, val_acc = evaluate(model, val_loader, ce_loss, device) dt = time.time() - t0 metrics = EpochMetrics( epoch=epoch, train_loss=train_loss, train_acc=train_acc, val_loss=val_loss, val_acc=val_acc, lyapunov=lyap, grad_norm=grad_norm, grad_max_sv=grad_max_sv, grad_min_sv=grad_min_sv, grad_condition=grad_cond, time_sec=dt, ) metrics_history.append(metrics) if progress: lyap_str = f"λ={lyap:.2f}" if lyap else "" iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str}) if np.isnan(train_loss): print(f" Training diverged at epoch {epoch}") break return metrics_history def run_depth_comparison( dataset_name: str, depths: List[int], train_loader: DataLoader, val_loader: DataLoader, input_dim: int, num_classes: int, hidden_dim: int, epochs: int, lr: float, lambda_reg: float, lambda_target: float, lyap_eps: float, device: torch.device, seed: int, progress: bool = True, ) -> Dict[str, Dict[int, List[EpochMetrics]]]: """Run comparison across depths.""" results = {"vanilla": {}, "lyapunov": {}} for depth in depths: print(f"\n{'='*60}") print(f"Depth = {depth} layers") print(f"{'='*60}") for use_lyap in [False, True]: method = "lyapunov" if use_lyap else "vanilla" print(f"\n Training {method.upper()}...") metrics = run_experiment( depth=depth, use_lyapunov=use_lyap, train_loader=train_loader, val_loader=val_loader, input_dim=input_dim, num_classes=num_classes, hidden_dim=hidden_dim, epochs=epochs, lr=lr, lambda_reg=lambda_reg, lambda_target=lambda_target, lyap_eps=lyap_eps, device=device, seed=seed, progress=progress, ) results[method][depth] = metrics final = metrics[-1] lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A" print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} " f"val_acc={final.val_acc:.3f} {lyap_str}") return results def print_summary(results: Dict, dataset_name: str): """Print summary table.""" print("\n" + "=" * 90) print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy") print("=" * 90) print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}") print("-" * 90) depths = sorted(results["vanilla"].keys()) for depth in depths: van = results["vanilla"][depth][-1] lyap = results["lyapunov"][depth][-1] van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0 lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0 van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" diff = lyap_acc - van_acc diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A" van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A" print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}") print("=" * 90) # Gradient health analysis print("\nGRADIENT HEALTH:") for depth in depths: van = results["vanilla"][depth][-1] van_cond = van.grad_condition if van.grad_condition else 0 if van_cond > 1e6: print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})") elif van_cond > 1e4: print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})") def save_results(results: Dict, output_dir: str, config: Dict): """Save results to JSON.""" os.makedirs(output_dir, exist_ok=True) serializable = {} for method, depth_results in results.items(): serializable[method] = {} for depth, metrics_list in depth_results.items(): serializable[method][str(depth)] = [asdict(m) for m in metrics_list] with open(os.path.join(output_dir, "results.json"), "w") as f: json.dump(serializable, f, indent=2) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) print(f"\nResults saved to {output_dir}") def parse_args(): p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN") # Dataset p.add_argument("--dataset", type=str, default="smnist", choices=["smnist", "psmnist", "cifar10"], help="Dataset to use") p.add_argument("--data_dir", type=str, default="./data") # Model p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8], help="Network depths to test") p.add_argument("--hidden_dim", type=int, default=128) # Training p.add_argument("--epochs", type=int, default=30) p.add_argument("--batch_size", type=int, default=128) p.add_argument("--lr", type=float, default=1e-3) # Lyapunov p.add_argument("--lambda_reg", type=float, default=0.3, help="Lyapunov regularization weight (higher for harder tasks)") p.add_argument("--lambda_target", type=float, default=-0.1, help="Target Lyapunov exponent (negative for stability)") p.add_argument("--lyap_eps", type=float, default=1e-4) # Dataset-specific p.add_argument("--T", type=int, default=100, help="Timesteps for CIFAR-10 (sMNIST uses 784)") p.add_argument("--n_repeat", type=int, default=1, help="Repeat each pixel n times for sMNIST") # Other p.add_argument("--seed", type=int, default=42) p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--out_dir", type=str, default="runs/benchmark") p.add_argument("--no-progress", action="store_true") return p.parse_args() def main(): args = parse_args() device = torch.device(args.device) print("=" * 70) print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}") print("=" * 70) print(f"Depths: {args.depths}") print(f"Hidden dim: {args.hidden_dim}") print(f"Epochs: {args.epochs}") print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}") print(f"Device: {device}") print("=" * 70) # Load dataset print(f"\nLoading {args.dataset} dataset...") if args.dataset == "smnist": train_loader, val_loader, info = get_benchmark_dataloader( "smnist", batch_size=args.batch_size, root=args.data_dir, n_repeat=args.n_repeat, spike_encoding="direct", ) elif args.dataset == "psmnist": train_loader, val_loader, info = get_benchmark_dataloader( "psmnist", batch_size=args.batch_size, root=args.data_dir, n_repeat=args.n_repeat, spike_encoding="direct", ) elif args.dataset == "cifar10": train_loader, val_loader, info = get_benchmark_dataloader( "cifar10", batch_size=args.batch_size, root=args.data_dir, T=args.T, ) print(f"Dataset info: {info}") print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}") # Run experiments results = run_depth_comparison( dataset_name=args.dataset, depths=args.depths, train_loader=train_loader, val_loader=val_loader, input_dim=info["D"], num_classes=info["classes"], hidden_dim=args.hidden_dim, epochs=args.epochs, lr=args.lr, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target, lyap_eps=args.lyap_eps, device=device, seed=args.seed, progress=not args.no_progress, ) # Print summary print_summary(results, args.dataset) # Save results ts = time.strftime("%Y%m%d-%H%M%S") output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}") save_results(results, output_dir, vars(args)) if __name__ == "__main__": main()