author     YurenHao0426 <blackhao0426@gmail.com>   2026-01-13 23:50:59 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>   2026-01-13 23:50:59 -0600
commit     00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree       77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/benchmark_experiment.py
parent     c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent     cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/experiments/benchmark_experiment.py')
-rw-r--r--   files/experiments/benchmark_experiment.py   518
1 file changed, 518 insertions, 0 deletions
diff --git a/files/experiments/benchmark_experiment.py b/files/experiments/benchmark_experiment.py
new file mode 100644
index 0000000..fb01ff2
--- /dev/null
+++ b/files/experiments/benchmark_experiment.py
@@ -0,0 +1,518 @@
+"""
+Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.
+
+Datasets:
+- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks
+- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory
+- CIFAR-10: Rate-coded images, requires hierarchical features
+
+Usage:
+ python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8
+ python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.data_io.benchmark_datasets import get_benchmark_dataloader
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ grad_max_sv: Optional[float]
+ grad_min_sv: Optional[float]
+ grad_condition: Optional[float]
+ time_sec: float
+
+
+def compute_gradient_svs(model):
+ """Compute gradient singular value statistics."""
+ max_svs = []
+ min_svs = []
+
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ with torch.no_grad():
+ G = param.grad.detach()
+ try:
+ sv = torch.linalg.svdvals(G)
+ if len(sv) > 0:
+ max_svs.append(sv[0].item())
+ min_svs.append(sv[-1].item())
+ except Exception:
+ pass
+
+ if not max_svs:
+ return None, None, None
+
+ max_sv = max(max_svs)
+ min_sv = min(min_svs)
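+    # Condition number κ = σ_max / σ_min across the 2-D gradient matrices; the 1e-12
+    # term guards against division by zero when the smallest singular value underflows.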
+ condition = max_sv / (min_sv + 1e-12)
+
+ return max_sv, min_sv, condition
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+ hidden_dims = [hidden_dim] * depth
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=1.0,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ compute_sv_every: int = 10,
+) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ grad_max_svs = []
+ grad_min_svs = []
+ grad_conditions = []
+
+ for batch_idx, (x, y) in enumerate(loader):
+ x, y = x.to(device), y.to(device)
+
+ # Handle different input shapes
+ if x.ndim == 2:
+ x = x.unsqueeze(-1) # (B, T) -> (B, T, 1)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=False,
+ )
+
+ ce = ce_loss(logits, y)
+
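+        # Regularized objective: loss = CE + λ_reg · (λ_est − λ_target)².  The penalty
+        # pulls the estimated Lyapunov exponent toward the (slightly negative) target.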
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan'), None, None, None
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
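+        # clip_grad_norm_ rescales gradients in place, so the norm logged below is the
+        # post-clipping global norm (capped at max_norm).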
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ # Compute gradient SVs periodically
+ if batch_idx % compute_sv_every == 0:
+ max_sv, min_sv, cond = compute_gradient_svs(model)
+ if max_sv is not None:
+ grad_max_svs.append(max_sv)
+ grad_min_svs.append(min_sv)
+ grad_conditions.append(cond)
+
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+    # Cast NumPy means to built-in floats so the metrics stay JSON-serializable.
+    avg_lyap = float(np.mean(lyap_vals)) if lyap_vals else None
+    avg_grad = float(np.mean(grad_norms))
+    avg_max_sv = float(np.mean(grad_max_svs)) if grad_max_svs else None
+    avg_min_sv = float(np.mean(grad_min_svs)) if grad_min_svs else None
+    avg_cond = float(np.mean(grad_conditions)) if grad_conditions else None
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+
+ if x.ndim == 2:
+ x = x.unsqueeze(-1)
+
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ loss = ce_loss(logits, y)
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ depth: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment configuration."""
+ torch.manual_seed(seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=depth,
+ hidden_dim=hidden_dim,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ method = "Lyapunov" if use_lyapunov else "Vanilla"
+ metrics_history = []
+
+ iterator = range(1, epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ if progress:
+            lyap_str = f"λ={lyap:.2f}" if lyap is not None else ""
+ iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str})
+
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ dataset_name: str,
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """Run comparison across depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ depth=depth,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ final = metrics[-1]
+ lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str}")
+
+ return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+ """Print summary table."""
+ print("\n" + "=" * 90)
+ print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy")
+ print("=" * 90)
+ print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}")
+ print("-" * 90)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+        van_grad = f"{van.grad_norm:.2e}" if np.isfinite(van.grad_norm) else "N/A"
+        van_cond = f"{van.grad_condition:.1e}" if van.grad_condition is not None else "N/A"
+
+ print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}")
+
+ print("=" * 90)
+
+ # Gradient health analysis
+ print("\nGRADIENT HEALTH:")
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+        van_cond = van.grad_condition if van.grad_condition is not None else 0
+ if van_cond > 1e6:
+ print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})")
+ elif van_cond > 1e4:
+ print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
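+    # results.json layout: {"vanilla" | "lyapunov": {depth (str): [per-epoch metric dicts]}}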
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN")
+
+ # Dataset
+ p.add_argument("--dataset", type=str, default="smnist",
+ choices=["smnist", "psmnist", "cifar10"],
+ help="Dataset to use")
+ p.add_argument("--data_dir", type=str, default="./data")
+
+ # Model
+ p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128)
+
+ # Training
+ p.add_argument("--epochs", type=int, default=30)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3,
+ help="Lyapunov regularization weight (higher for harder tasks)")
+ p.add_argument("--lambda_target", type=float, default=-0.1,
+ help="Target Lyapunov exponent (negative for stability)")
+ p.add_argument("--lyap_eps", type=float, default=1e-4)
+
+ # Dataset-specific
+ p.add_argument("--T", type=int, default=100,
+ help="Timesteps for CIFAR-10 (sMNIST uses 784)")
+ p.add_argument("--n_repeat", type=int, default=1,
+ help="Repeat each pixel n times for sMNIST")
+
+ # Other
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--out_dir", type=str, default="runs/benchmark")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load dataset
+ print(f"\nLoading {args.dataset} dataset...")
+
+ if args.dataset == "smnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "psmnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "cifar10":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ T=args.T,
+ )
+
+ print(f"Dataset info: {info}")
+ print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ dataset_name=args.dataset,
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=info["D"],
+ num_classes=info["classes"],
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results, args.dataset)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()