From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Tue, 13 Jan 2026 23:49:05 -0600
Subject: init commit

---
 files/experiments/benchmark_experiment.py | 518 ++++++++++++++++++++++++++++++
 1 file changed, 518 insertions(+)
 create mode 100644 files/experiments/benchmark_experiment.py

diff --git a/files/experiments/benchmark_experiment.py b/files/experiments/benchmark_experiment.py
new file mode 100644
index 0000000..fb01ff2
--- /dev/null
+++ b/files/experiments/benchmark_experiment.py
@@ -0,0 +1,518 @@
+"""
+Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.
+
+Datasets:
+- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks
+- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory
+- CIFAR-10: Rate-coded images, requires hierarchical features
+
+Usage:
+    python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8
+    python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.data_io.benchmark_datasets import get_benchmark_dataloader
+
+
+@dataclass
+class EpochMetrics:
+    epoch: int
+    train_loss: float
+    train_acc: float
+    val_loss: float
+    val_acc: float
+    lyapunov: Optional[float]
+    grad_norm: float
+    grad_max_sv: Optional[float]
+    grad_min_sv: Optional[float]
+    grad_condition: Optional[float]
+    time_sec: float
+
+
+def compute_gradient_svs(model):
+    """Compute gradient singular value statistics."""
+    max_svs = []
+    min_svs = []
+
+    for name, param in model.named_parameters():
+        if param.grad is not None and param.ndim == 2:
+            with torch.no_grad():
+                G = param.grad.detach()
+                try:
+                    sv = torch.linalg.svdvals(G)
+                    if len(sv) > 0:
+                        max_svs.append(sv[0].item())
+                        min_svs.append(sv[-1].item())
+                except Exception:
+                    pass
+
+    if not max_svs:
+        return None, None, None
+
+    max_sv = max(max_svs)
+    min_sv = min(min_svs)
+    condition = max_sv / (min_sv + 1e-12)
+
+    return max_sv, min_sv, condition
+
+
+def create_model(
+    input_dim: int,
+    num_classes: int,
+    depth: int,
+    hidden_dim: int = 128,
+    beta: float = 0.9,
+) -> LyapunovSNN:
+    """Create SNN with specified depth."""
+    hidden_dims = [hidden_dim] * depth
+    return LyapunovSNN(
+        input_dim=input_dim,
+        hidden_dims=hidden_dims,
+        num_classes=num_classes,
+        beta=beta,
+        threshold=1.0,
+    )
+
+
+def train_epoch(
+    model: nn.Module,
+    loader: DataLoader,
+    optimizer: optim.Optimizer,
+    ce_loss: nn.Module,
+    device: torch.device,
+    use_lyapunov: bool,
+    lambda_reg: float,
+    lambda_target: float,
+    lyap_eps: float,
+    compute_sv_every: int = 10,
+) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]:
+    """Train one epoch."""
+    model.train()
+    total_loss = 0.0
+    total_correct = 0
+    total_samples = 0
+    lyap_vals = []
+    grad_norms = []
+    grad_max_svs = []
+    grad_min_svs = []
+    grad_conditions = []
+
+    for batch_idx, (x, y) in enumerate(loader):
+        x, y = x.to(device), y.to(device)
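+        # Batches are assumed to arrive either as (B, T) pixel sequences (sMNIST/psMNIST)
+        # or already as (B, T, D) tensors (e.g. rate-coded CIFAR-10); the branch below
+        # only adds a trailing feature dimension when it is missing.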
+
+        # Handle different input shapes
+        if x.ndim == 2:
+            x = x.unsqueeze(-1)  # (B, T) -> (B, T, 1)
+
+        optimizer.zero_grad()
+
+        logits, lyap_est, _ = model(
+            x,
+            compute_lyapunov=use_lyapunov,
+            lyap_eps=lyap_eps,
+            record_states=False,
+        )
+
+        ce = ce_loss(logits, y)
+
+        if use_lyapunov and lyap_est is not None:
+            reg = (lyap_est - lambda_target) ** 2
+            loss = ce + lambda_reg * reg
+            lyap_vals.append(lyap_est.item())
+        else:
+            loss = ce
+
+        if torch.isnan(loss):
+            return float('nan'), 0.0, None, float('nan'), None, None, None
+
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+        grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+        grad_norms.append(grad_norm)
+
+        # Compute gradient SVs periodically
+        if batch_idx % compute_sv_every == 0:
+            max_sv, min_sv, cond = compute_gradient_svs(model)
+            if max_sv is not None:
+                grad_max_svs.append(max_sv)
+                grad_min_svs.append(min_sv)
+                grad_conditions.append(cond)
+
+        optimizer.step()
+
+        total_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        total_correct += (preds == y).sum().item()
+        total_samples += x.size(0)
+
+    avg_loss = total_loss / total_samples
+    avg_acc = total_correct / total_samples
+    avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+    avg_grad = np.mean(grad_norms)
+    avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None
+    avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None
+    avg_cond = np.mean(grad_conditions) if grad_conditions else None
+
+    return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond
+
+
+@torch.no_grad()
+def evaluate(
+    model: nn.Module,
+    loader: DataLoader,
+    ce_loss: nn.Module,
+    device: torch.device,
+) -> Tuple[float, float]:
+    """Evaluate on validation set."""
+    model.eval()
+    total_loss = 0.0
+    total_correct = 0
+    total_samples = 0
+
+    for x, y in loader:
+        x, y = x.to(device), y.to(device)
+
+        if x.ndim == 2:
+            x = x.unsqueeze(-1)
+
+        logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+        loss = ce_loss(logits, y)
+
+        if torch.isnan(loss):
+            return float('nan'), 0.0
+
+        total_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        total_correct += (preds == y).sum().item()
+        total_samples += x.size(0)
+
+    return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+    depth: int,
+    use_lyapunov: bool,
+    train_loader: DataLoader,
+    val_loader: DataLoader,
+    input_dim: int,
+    num_classes: int,
+    hidden_dim: int,
+    epochs: int,
+    lr: float,
+    lambda_reg: float,
+    lambda_target: float,
+    lyap_eps: float,
+    device: torch.device,
+    seed: int,
+    progress: bool = True,
+) -> List[EpochMetrics]:
+    """Run single experiment configuration."""
+    torch.manual_seed(seed)
+
+    model = create_model(
+        input_dim=input_dim,
+        num_classes=num_classes,
+        depth=depth,
+        hidden_dim=hidden_dim,
+    ).to(device)
+
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    ce_loss = nn.CrossEntropyLoss()
+
+    method = "Lyapunov" if use_lyapunov else "Vanilla"
+    metrics_history = []
+
+    iterator = range(1, epochs + 1)
+    if progress:
+        iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False)
+
+    for epoch in iterator:
+        t0 = time.time()
+
+        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+            model, train_loader, optimizer, ce_loss, device,
+            use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+        )
+
+        val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+        dt = time.time() - t0
+
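+        # Per-epoch diagnostics are bundled into one record below; the grad_* singular-value
+        # fields may be None when no 2-D parameter gradients were available that epoch.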
+        metrics = EpochMetrics(
+            epoch=epoch,
+            train_loss=train_loss,
+            train_acc=train_acc,
+            val_loss=val_loss,
+            val_acc=val_acc,
+            lyapunov=lyap,
+            grad_norm=grad_norm,
+            grad_max_sv=grad_max_sv,
+            grad_min_sv=grad_min_sv,
+            grad_condition=grad_cond,
+            time_sec=dt,
+        )
+        metrics_history.append(metrics)
+
+        if progress:
+            lyap_str = f"λ={lyap:.2f}" if lyap is not None else ""
+            iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str})
+
+        if np.isnan(train_loss):
+            print(f" Training diverged at epoch {epoch}")
+            break
+
+    return metrics_history
+
+
+def run_depth_comparison(
+    dataset_name: str,
+    depths: List[int],
+    train_loader: DataLoader,
+    val_loader: DataLoader,
+    input_dim: int,
+    num_classes: int,
+    hidden_dim: int,
+    epochs: int,
+    lr: float,
+    lambda_reg: float,
+    lambda_target: float,
+    lyap_eps: float,
+    device: torch.device,
+    seed: int,
+    progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+    """Run comparison across depths."""
+    results = {"vanilla": {}, "lyapunov": {}}
+
+    for depth in depths:
+        print(f"\n{'='*60}")
+        print(f"Depth = {depth} layers")
+        print(f"{'='*60}")
+
+        for use_lyap in [False, True]:
+            method = "lyapunov" if use_lyap else "vanilla"
+            print(f"\n Training {method.upper()}...")
+
+            metrics = run_experiment(
+                depth=depth,
+                use_lyapunov=use_lyap,
+                train_loader=train_loader,
+                val_loader=val_loader,
+                input_dim=input_dim,
+                num_classes=num_classes,
+                hidden_dim=hidden_dim,
+                epochs=epochs,
+                lr=lr,
+                lambda_reg=lambda_reg,
+                lambda_target=lambda_target,
+                lyap_eps=lyap_eps,
+                device=device,
+                seed=seed,
+                progress=progress,
+            )
+
+            results[method][depth] = metrics
+
+            final = metrics[-1]
+            lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov is not None else "λ=N/A"
+            print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+                  f"val_acc={final.val_acc:.3f} {lyap_str}")
+
+    return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+    """Print summary table."""
+    print("\n" + "=" * 90)
+    print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy")
+    print("=" * 90)
+    print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}")
+    print("-" * 90)
+
+    depths = sorted(results["vanilla"].keys())
+    for depth in depths:
+        van = results["vanilla"][depth][-1]
+        lyap = results["lyapunov"][depth][-1]
+
+        van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+        lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+        van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+        lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+        diff = lyap_acc - van_acc
+        diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+        van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A"
+        van_cond = f"{van.grad_condition:.1e}" if van.grad_condition is not None else "N/A"
+
+        print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}")
+
+    print("=" * 90)
+
+    # Gradient health analysis
+    print("\nGRADIENT HEALTH:")
+    for depth in depths:
+        van = results["vanilla"][depth][-1]
+        van_cond = van.grad_condition if van.grad_condition is not None else 0.0
+        if van_cond > 1e6:
+            print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})")
+        elif van_cond > 1e4:
+            print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+    """Save results to JSON."""
+    os.makedirs(output_dir, exist_ok=True)
+
+    serializable = {}
+    for method, depth_results in results.items():
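+        # Depth keys are stringified below so the nested dict round-trips through JSON;
+        # each EpochMetrics dataclass is flattened to a plain dict with asdict().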
+        serializable[method] = {}
+        for depth, metrics_list in depth_results.items():
+            serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+    with open(os.path.join(output_dir, "results.json"), "w") as f:
+        json.dump(serializable, f, indent=2)
+
+    with open(os.path.join(output_dir, "config.json"), "w") as f:
+        json.dump(config, f, indent=2)
+
+    print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+    p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN")
+
+    # Dataset
+    p.add_argument("--dataset", type=str, default="smnist",
+                   choices=["smnist", "psmnist", "cifar10"],
+                   help="Dataset to use")
+    p.add_argument("--data_dir", type=str, default="./data")
+
+    # Model
+    p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8],
+                   help="Network depths to test")
+    p.add_argument("--hidden_dim", type=int, default=128)
+
+    # Training
+    p.add_argument("--epochs", type=int, default=30)
+    p.add_argument("--batch_size", type=int, default=128)
+    p.add_argument("--lr", type=float, default=1e-3)
+
+    # Lyapunov
+    p.add_argument("--lambda_reg", type=float, default=0.3,
+                   help="Lyapunov regularization weight (higher for harder tasks)")
+    p.add_argument("--lambda_target", type=float, default=-0.1,
+                   help="Target Lyapunov exponent (negative for stability)")
+    p.add_argument("--lyap_eps", type=float, default=1e-4)
+
+    # Dataset-specific
+    p.add_argument("--T", type=int, default=100,
+                   help="Timesteps for CIFAR-10 (sMNIST uses 784)")
+    p.add_argument("--n_repeat", type=int, default=1,
+                   help="Repeat each pixel n times for sMNIST")
+
+    # Other
+    p.add_argument("--seed", type=int, default=42)
+    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    p.add_argument("--out_dir", type=str, default="runs/benchmark")
+    p.add_argument("--no-progress", action="store_true")
+
+    return p.parse_args()
+
+
+def main():
+    args = parse_args()
+    device = torch.device(args.device)
+
+    print("=" * 70)
+    print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}")
+    print("=" * 70)
+    print(f"Depths: {args.depths}")
+    print(f"Hidden dim: {args.hidden_dim}")
+    print(f"Epochs: {args.epochs}")
+    print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+    print(f"Device: {device}")
+    print("=" * 70)
+
+    # Load dataset
+    print(f"\nLoading {args.dataset} dataset...")
+
+    if args.dataset == "smnist":
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "smnist",
+            batch_size=args.batch_size,
+            root=args.data_dir,
+            n_repeat=args.n_repeat,
+            spike_encoding="direct",
+        )
+    elif args.dataset == "psmnist":
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "psmnist",
+            batch_size=args.batch_size,
+            root=args.data_dir,
+            n_repeat=args.n_repeat,
+            spike_encoding="direct",
+        )
+    elif args.dataset == "cifar10":
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "cifar10",
+            batch_size=args.batch_size,
+            root=args.data_dir,
+            T=args.T,
+        )
+
+    print(f"Dataset info: {info}")
+    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
+
+    # Run experiments
+    results = run_depth_comparison(
+        dataset_name=args.dataset,
+        depths=args.depths,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        input_dim=info["D"],
+        num_classes=info["classes"],
+        hidden_dim=args.hidden_dim,
+        epochs=args.epochs,
+        lr=args.lr,
+        lambda_reg=args.lambda_reg,
+        lambda_target=args.lambda_target,
+        lyap_eps=args.lyap_eps,
+        device=device,
+        seed=args.seed,
+        progress=not args.no_progress,
+    )
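+
+    # results maps method ("vanilla"/"lyapunov") -> depth -> list of per-epoch EpochMetrics.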
+
+    # Print summary
+    print_summary(results, args.dataset)
+
+    # Save results
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+    save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+    main()
-- 
cgit v1.2.3