diff options
Diffstat (limited to 'files/experiments')
| -rw-r--r-- | files/experiments/benchmark_experiment.py | 518 | ||||
| -rw-r--r-- | files/experiments/cifar10_conv_experiment.py | 448 | ||||
| -rw-r--r-- | files/experiments/depth_comparison.py | 542 | ||||
| -rw-r--r-- | files/experiments/depth_scaling_benchmark.py | 1035 | ||||
| -rw-r--r-- | files/experiments/hyperparameter_grid_search.py | 597 | ||||
| -rw-r--r-- | files/experiments/lyapunov_diffonly_benchmark.py | 590 | ||||
| -rw-r--r-- | files/experiments/lyapunov_speedup_benchmark.py | 638 | ||||
| -rw-r--r-- | files/experiments/plot_depth_comparison.py | 305 | ||||
| -rw-r--r-- | files/experiments/posthoc_finetune.py | 323 | ||||
| -rw-r--r-- | files/experiments/scaled_reg_grid_search.py | 301 |
10 files changed, 5297 insertions, 0 deletions
diff --git a/files/experiments/benchmark_experiment.py b/files/experiments/benchmark_experiment.py new file mode 100644 index 0000000..fb01ff2 --- /dev/null +++ b/files/experiments/benchmark_experiment.py @@ -0,0 +1,518 @@ +""" +Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets. + +Datasets: +- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks +- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory +- CIFAR-10: Rate-coded images, requires hierarchical features + +Usage: + python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8 + python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN +from files.data_io.benchmark_datasets import get_benchmark_dataloader + + +@dataclass +class EpochMetrics: + epoch: int + train_loss: float + train_acc: float + val_loss: float + val_acc: float + lyapunov: Optional[float] + grad_norm: float + grad_max_sv: Optional[float] + grad_min_sv: Optional[float] + grad_condition: Optional[float] + time_sec: float + + +def compute_gradient_svs(model): + """Compute gradient singular value statistics.""" + max_svs = [] + min_svs = [] + + for name, param in model.named_parameters(): + if param.grad is not None and param.ndim == 2: + with torch.no_grad(): + G = param.grad.detach() + try: + sv = torch.linalg.svdvals(G) + if len(sv) > 0: + max_svs.append(sv[0].item()) + min_svs.append(sv[-1].item()) + except Exception: + pass + + if not max_svs: + return None, None, None + + max_sv = max(max_svs) + min_sv = min(min_svs) + condition = max_sv / (min_sv + 1e-12) + + return max_sv, min_sv, condition + + +def create_model( + input_dim: int, + num_classes: int, + depth: int, + hidden_dim: int = 128, + beta: float = 0.9, +) -> LyapunovSNN: + """Create SNN with specified depth.""" + hidden_dims = [hidden_dim] * depth + return LyapunovSNN( + input_dim=input_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + beta=beta, + threshold=1.0, + ) + + +def train_epoch( + model: nn.Module, + loader: DataLoader, + optimizer: optim.Optimizer, + ce_loss: nn.Module, + device: torch.device, + use_lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + compute_sv_every: int = 10, +) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]: + """Train one epoch.""" + model.train() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + grad_max_svs = [] + grad_min_svs = [] + grad_conditions = [] + + for batch_idx, (x, y) in enumerate(loader): + x, y = x.to(device), y.to(device) + + # Handle different input shapes + if x.ndim == 2: + x = x.unsqueeze(-1) # (B, T) -> (B, T, 1) + + optimizer.zero_grad() + + logits, lyap_est, _ = model( + x, + compute_lyapunov=use_lyapunov, + lyap_eps=lyap_eps, + record_states=False, + ) + + ce = ce_loss(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan'), None, None, None + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + # Compute gradient SVs periodically + if batch_idx % compute_sv_every == 0: + max_sv, min_sv, cond = compute_gradient_svs(model) + if max_sv is not None: + grad_max_svs.append(max_sv) + grad_min_svs.append(min_sv) + grad_conditions.append(cond) + + optimizer.step() + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + avg_loss = total_loss / total_samples + avg_acc = total_correct / total_samples + avg_lyap = np.mean(lyap_vals) if lyap_vals else None + avg_grad = np.mean(grad_norms) + avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None + avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None + avg_cond = np.mean(grad_conditions) if grad_conditions else None + + return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + ce_loss: nn.Module, + device: torch.device, +) -> Tuple[float, float]: + """Evaluate on validation set.""" + model.eval() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + + for x, y in loader: + x, y = x.to(device), y.to(device) + + if x.ndim == 2: + x = x.unsqueeze(-1) + + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + loss = ce_loss(logits, y) + + if torch.isnan(loss): + return float('nan'), 0.0 + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + return total_loss / total_samples, total_correct / total_samples + + +def run_experiment( + depth: int, + use_lyapunov: bool, + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + device: torch.device, + seed: int, + progress: bool = True, +) -> List[EpochMetrics]: + """Run single experiment configuration.""" + torch.manual_seed(seed) + + model = create_model( + input_dim=input_dim, + num_classes=num_classes, + depth=depth, + hidden_dim=hidden_dim, + ).to(device) + + optimizer = optim.Adam(model.parameters(), lr=lr) + ce_loss = nn.CrossEntropyLoss() + + method = "Lyapunov" if use_lyapunov else "Vanilla" + metrics_history = [] + + iterator = range(1, epochs + 1) + if progress: + iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False) + + for epoch in iterator: + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( + model, train_loader, optimizer, ce_loss, device, + use_lyapunov, lambda_reg, lambda_target, lyap_eps, + ) + + val_loss, val_acc = evaluate(model, val_loader, ce_loss, device) + dt = time.time() - t0 + + metrics = EpochMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + val_loss=val_loss, + val_acc=val_acc, + lyapunov=lyap, + grad_norm=grad_norm, + grad_max_sv=grad_max_sv, + grad_min_sv=grad_min_sv, + grad_condition=grad_cond, + time_sec=dt, + ) + metrics_history.append(metrics) + + if progress: + lyap_str = f"λ={lyap:.2f}" if lyap else "" + iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str}) + + if np.isnan(train_loss): + print(f" Training diverged at epoch {epoch}") + break + + return metrics_history + + +def run_depth_comparison( + dataset_name: str, + depths: List[int], + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + device: torch.device, + seed: int, + progress: bool = True, +) -> Dict[str, Dict[int, List[EpochMetrics]]]: + """Run comparison across depths.""" + results = {"vanilla": {}, "lyapunov": {}} + + for depth in depths: + print(f"\n{'='*60}") + print(f"Depth = {depth} layers") + print(f"{'='*60}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + print(f"\n Training {method.upper()}...") + + metrics = run_experiment( + depth=depth, + use_lyapunov=use_lyap, + train_loader=train_loader, + val_loader=val_loader, + input_dim=input_dim, + num_classes=num_classes, + hidden_dim=hidden_dim, + epochs=epochs, + lr=lr, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=lyap_eps, + device=device, + seed=seed, + progress=progress, + ) + + results[method][depth] = metrics + + final = metrics[-1] + lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A" + print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} " + f"val_acc={final.val_acc:.3f} {lyap_str}") + + return results + + +def print_summary(results: Dict, dataset_name: str): + """Print summary table.""" + print("\n" + "=" * 90) + print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy") + print("=" * 90) + print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}") + print("-" * 90) + + depths = sorted(results["vanilla"].keys()) + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0 + lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0 + + van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" + lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" + + van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A" + van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A" + + print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}") + + print("=" * 90) + + # Gradient health analysis + print("\nGRADIENT HEALTH:") + for depth in depths: + van = results["vanilla"][depth][-1] + van_cond = van.grad_condition if van.grad_condition else 0 + if van_cond > 1e6: + print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})") + elif van_cond > 1e4: + print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})") + + +def save_results(results: Dict, output_dir: str, config: Dict): + """Save results to JSON.""" + os.makedirs(output_dir, exist_ok=True) + + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, metrics_list in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in metrics_list] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN") + + # Dataset + p.add_argument("--dataset", type=str, default="smnist", + choices=["smnist", "psmnist", "cifar10"], + help="Dataset to use") + p.add_argument("--data_dir", type=str, default="./data") + + # Model + p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8], + help="Network depths to test") + p.add_argument("--hidden_dim", type=int, default=128) + + # Training + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--lr", type=float, default=1e-3) + + # Lyapunov + p.add_argument("--lambda_reg", type=float, default=0.3, + help="Lyapunov regularization weight (higher for harder tasks)") + p.add_argument("--lambda_target", type=float, default=-0.1, + help="Target Lyapunov exponent (negative for stability)") + p.add_argument("--lyap_eps", type=float, default=1e-4) + + # Dataset-specific + p.add_argument("--T", type=int, default=100, + help="Timesteps for CIFAR-10 (sMNIST uses 784)") + p.add_argument("--n_repeat", type=int, default=1, + help="Repeat each pixel n times for sMNIST") + + # Other + p.add_argument("--seed", type=int, default=42) + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--out_dir", type=str, default="runs/benchmark") + p.add_argument("--no-progress", action="store_true") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 70) + print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}") + print("=" * 70) + print(f"Depths: {args.depths}") + print(f"Hidden dim: {args.hidden_dim}") + print(f"Epochs: {args.epochs}") + print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}") + print(f"Device: {device}") + print("=" * 70) + + # Load dataset + print(f"\nLoading {args.dataset} dataset...") + + if args.dataset == "smnist": + train_loader, val_loader, info = get_benchmark_dataloader( + "smnist", + batch_size=args.batch_size, + root=args.data_dir, + n_repeat=args.n_repeat, + spike_encoding="direct", + ) + elif args.dataset == "psmnist": + train_loader, val_loader, info = get_benchmark_dataloader( + "psmnist", + batch_size=args.batch_size, + root=args.data_dir, + n_repeat=args.n_repeat, + spike_encoding="direct", + ) + elif args.dataset == "cifar10": + train_loader, val_loader, info = get_benchmark_dataloader( + "cifar10", + batch_size=args.batch_size, + root=args.data_dir, + T=args.T, + ) + + print(f"Dataset info: {info}") + print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}") + + # Run experiments + results = run_depth_comparison( + dataset_name=args.dataset, + depths=args.depths, + train_loader=train_loader, + val_loader=val_loader, + input_dim=info["D"], + num_classes=info["classes"], + hidden_dim=args.hidden_dim, + epochs=args.epochs, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + lyap_eps=args.lyap_eps, + device=device, + seed=args.seed, + progress=not args.no_progress, + ) + + # Print summary + print_summary(results, args.dataset) + + # Save results + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}") + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/cifar10_conv_experiment.py b/files/experiments/cifar10_conv_experiment.py new file mode 100644 index 0000000..a582f9f --- /dev/null +++ b/files/experiments/cifar10_conv_experiment.py @@ -0,0 +1,448 @@ +""" +CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization. + +Uses proper convolutional architecture that preserves spatial structure. +Tests whether Lyapunov regularization helps train deeper Conv-SNNs. + +Architecture: + Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output + +Usage: + python files/experiments/cifar10_conv_experiment.py --model simple --T 25 + python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm.auto import tqdm + +from files.models.conv_snn import create_conv_snn + + +@dataclass +class EpochMetrics: + epoch: int + train_loss: float + train_acc: float + val_loss: float + val_acc: float + lyapunov: Optional[float] + grad_norm: float + time_sec: float + + +def get_cifar10_loaders( + data_dir: str = './data', + batch_size: int = 128, + num_workers: int = 4, +) -> Tuple[DataLoader, DataLoader]: + """ + Get CIFAR-10 dataloaders with standard normalization. + + Images normalized to [0, 1] for rate encoding. + """ + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + # Note: For rate encoding, we keep values in [0, 1] + # No normalization to negative values + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + ]) + + train_dataset = datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transform_train + ) + test_dataset = datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transform_test + ) + + train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True + ) + test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True + ) + + return train_loader, test_loader + + +def train_epoch( + model: nn.Module, + loader: DataLoader, + optimizer: optim.Optimizer, + ce_loss: nn.Module, + device: torch.device, + use_lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + progress: bool = True, +) -> Tuple[float, float, Optional[float], float]: + """Train one epoch.""" + model.train() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32) + + optimizer.zero_grad() + + logits, lyap_est, _ = model( + x, + compute_lyapunov=use_lyapunov, + lyap_eps=lyap_eps, + ) + + ce = ce_loss(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan') + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + optimizer.step() + + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + if progress: + iterator.set_postfix({ + "loss": f"{loss.item():.3f}", + "acc": f"{total_correct/total_samples:.3f}", + }) + + return ( + total_loss / total_samples, + total_correct / total_samples, + np.mean(lyap_vals) if lyap_vals else None, + np.mean(grad_norms), + ) + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + ce_loss: nn.Module, + device: torch.device, + progress: bool = True, +) -> Tuple[float, float]: + """Evaluate on test set.""" + model.eval() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + + iterator = tqdm(loader, desc="eval", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False) + + loss = ce_loss(logits, y) + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + return total_loss / total_samples, total_correct / total_samples + + +def run_experiment( + model_type: str, + channels: List[int], + T: int, + use_lyapunov: bool, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + device: torch.device, + seed: int, + progress: bool = True, +) -> List[EpochMetrics]: + """Run single experiment.""" + torch.manual_seed(seed) + + model = create_conv_snn( + model_type=model_type, + in_channels=3, + num_classes=10, + channels=channels, + T=T, + encoding='rate', + ).to(device) + + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" Model: {model_type}, params: {num_params:,}") + + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + ce_loss = nn.CrossEntropyLoss() + + metrics_history = [] + best_acc = 0.0 + + for epoch in range(1, epochs + 1): + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm = train_epoch( + model, train_loader, optimizer, ce_loss, device, + use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress + ) + + test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_acc = max(best_acc, test_acc) + + metrics = EpochMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + val_loss=test_loss, + val_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + time_sec=dt, + ) + metrics_history.append(metrics) + + lyap_str = f"λ={lyap:.3f}" if lyap else "" + print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)") + + if np.isnan(train_loss): + print(" Training diverged!") + break + + print(f" Best test accuracy: {best_acc:.3f}") + return metrics_history + + +def run_comparison( + model_type: str, + channels_configs: List[List[int]], + T: int, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool, +) -> Dict: + """Compare vanilla vs Lyapunov across different depths.""" + results = {"vanilla": {}, "lyapunov": {}} + + for channels in channels_configs: + depth = len(channels) + print(f"\n{'='*60}") + print(f"Depth = {depth} conv layers, channels = {channels}") + print(f"{'='*60}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + print(f"\n Training {method.upper()}...") + + metrics = run_experiment( + model_type=model_type, + channels=channels, + T=T, + use_lyapunov=use_lyap, + train_loader=train_loader, + test_loader=test_loader, + epochs=epochs, + lr=lr, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=1e-4, + device=device, + seed=seed, + progress=progress, + ) + + results[method][depth] = metrics + + return results + + +def print_summary(results: Dict): + """Print comparison summary.""" + print("\n" + "=" * 70) + print("SUMMARY: CIFAR-10 Conv-SNN Results") + print("=" * 70) + print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}") + print("-" * 70) + + depths = sorted(results["vanilla"].keys()) + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0 + lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0 + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" + + van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" + lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" + + print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}") + + print("=" * 70) + + +def save_results(results: Dict, output_dir: str, config: Dict): + """Save results.""" + os.makedirs(output_dir, exist_ok=True) + + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, metrics_list in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in metrics_list] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def parse_args(): + p = argparse.ArgumentParser() + + # Model + p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"]) + p.add_argument("--channels", type=int, nargs="+", default=None, + help="Channel sizes (default: test multiple depths)") + p.add_argument("--T", type=int, default=25, help="Timesteps") + + # Training + p.add_argument("--epochs", type=int, default=50) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--lr", type=float, default=1e-3) + + # Lyapunov + p.add_argument("--lambda_reg", type=float, default=0.3) + p.add_argument("--lambda_target", type=float, default=-0.1) + + # Other + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--out_dir", type=str, default="runs/cifar10_conv") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--no-progress", action="store_true") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 70) + print("CIFAR-10 Conv-SNN Experiment") + print("=" * 70) + print(f"Model: {args.model}") + print(f"Timesteps: {args.T}") + print(f"Epochs: {args.epochs}") + print(f"Device: {device}") + print("=" * 70) + + # Load data + print("\nLoading CIFAR-10...") + train_loader, test_loader = get_cifar10_loaders( + data_dir=args.data_dir, + batch_size=args.batch_size, + ) + print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") + + # Define depth configurations to test + if args.channels: + channels_configs = [args.channels] + else: + # Test increasing depths + channels_configs = [ + [64, 128], # 2 conv layers (shallow) + [64, 128, 256], # 3 conv layers + [64, 128, 256, 512], # 4 conv layers (deep) + ] + + # Run comparison + results = run_comparison( + model_type=args.model, + channels_configs=channels_configs, + T=args.T, + train_loader=train_loader, + test_loader=test_loader, + epochs=args.epochs, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + progress=not args.no_progress, + ) + + # Summary + print_summary(results) + + # Save + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, ts) + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py new file mode 100644 index 0000000..48c62d8 --- /dev/null +++ b/files/experiments/depth_comparison.py @@ -0,0 +1,542 @@ +""" +Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths. + +Hypothesis: +- Shallow networks (1-2 layers): Both methods train successfully +- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds + +Usage: + # Quick test (synthetic data) + python files/experiments/depth_comparison.py --synthetic --epochs 20 + + # Full experiment with SHD data + python files/experiments/depth_comparison.py --epochs 50 + + # Specific depths to test + python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN +from files.analysis.stability_monitor import StabilityMonitor + + +@dataclass +class ExperimentConfig: + """Configuration for a single experiment run.""" + depth: int + hidden_dim: int + use_lyapunov: bool + lambda_reg: float + lambda_target: float + lyap_eps: float + epochs: int + lr: float + batch_size: int + beta: float + threshold: float + seed: int + + +@dataclass +class EpochMetrics: + """Metrics collected per epoch.""" + epoch: int + train_loss: float + train_acc: float + val_loss: float + val_acc: float + lyapunov: Optional[float] + grad_norm: float + firing_rate: float + dead_neurons: float + time_sec: float + + +def create_synthetic_data( + n_train: int = 2000, + n_val: int = 500, + T: int = 50, + D: int = 100, + n_classes: int = 10, + seed: int = 42, +) -> Tuple[DataLoader, DataLoader]: + """Create synthetic spike data for testing.""" + torch.manual_seed(seed) + np.random.seed(seed) + + def generate_data(n_samples): + # Generate class-conditional spike patterns + x = torch.zeros(n_samples, T, D) + y = torch.randint(0, n_classes, (n_samples,)) + + for i in range(n_samples): + label = y[i].item() + # Each class has different firing rate pattern + base_rate = 0.05 + 0.02 * label + # Class-specific channels fire more + class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) + for t in range(T): + # Background activity + x[i, t] = (torch.rand(D) < base_rate).float() + # Enhanced activity for class-specific channels + for c in class_channels: + if torch.rand(1) < base_rate * 3: + x[i, t, c] = 1.0 + + return x, y + + x_train, y_train = generate_data(n_train) + x_val, y_val = generate_data(n_val) + + train_loader = DataLoader( + TensorDataset(x_train, y_train), + batch_size=64, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(x_val, y_val), + batch_size=64, + shuffle=False, + ) + + return train_loader, val_loader, T, D, n_classes + + +def create_model( + input_dim: int, + num_classes: int, + depth: int, + hidden_dim: int = 128, + beta: float = 0.9, + threshold: float = 1.0, +) -> LyapunovSNN: + """Create SNN with specified depth.""" + # Create hidden dims list based on depth + # Gradually decrease size for deeper networks to keep param count reasonable + hidden_dims = [] + current_dim = hidden_dim + for i in range(depth): + hidden_dims.append(current_dim) + # Optionally decrease dim in deeper layers + # current_dim = max(64, current_dim // 2) + + return LyapunovSNN( + input_dim=input_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + beta=beta, + threshold=threshold, + ) + + +def train_epoch( + model: nn.Module, + loader: DataLoader, + optimizer: optim.Optimizer, + ce_loss: nn.Module, + device: torch.device, + use_lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + monitor: StabilityMonitor, +) -> Tuple[float, float, float, float, float, float]: + """Train for one epoch, return metrics.""" + model.train() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + firing_rates = [] + dead_fracs = [] + + for x, y in loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, recordings = model( + x, + compute_lyapunov=use_lyapunov, + lyap_eps=lyap_eps, + record_states=True, + ) + + ce = ce_loss(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + # Check for NaN + if torch.isnan(loss): + return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0 + + loss.backward() + + # Gradient clipping for stability comparison fairness + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + optimizer.step() + + # Collect metrics + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + # Stability metrics + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + if recordings is not None: + spikes = recordings['spikes'] + fr = spikes.mean().item() + dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item() + firing_rates.append(fr) + dead_fracs.append(dead) + + avg_loss = total_loss / total_samples + avg_acc = total_correct / total_samples + avg_lyap = np.mean(lyap_vals) if lyap_vals else None + avg_grad = np.mean(grad_norms) + avg_fr = np.mean(firing_rates) if firing_rates else 0.0 + avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0 + + return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + ce_loss: nn.Module, + device: torch.device, +) -> Tuple[float, float]: + """Evaluate model on validation set.""" + model.eval() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + + for x, y in loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + + loss = ce_loss(logits, y) + if torch.isnan(loss): + return float('nan'), 0.0 + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + return total_loss / total_samples, total_correct / total_samples + + +def run_single_experiment( + config: ExperimentConfig, + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + device: torch.device, + progress: bool = True, +) -> List[EpochMetrics]: + """Run a single experiment with given configuration.""" + torch.manual_seed(config.seed) + + model = create_model( + input_dim=input_dim, + num_classes=num_classes, + depth=config.depth, + hidden_dim=config.hidden_dim, + beta=config.beta, + threshold=config.threshold, + ).to(device) + + optimizer = optim.Adam(model.parameters(), lr=config.lr) + ce_loss = nn.CrossEntropyLoss() + monitor = StabilityMonitor() + + metrics_history = [] + method = "Lyapunov" if config.use_lyapunov else "Vanilla" + + iterator = range(1, config.epochs + 1) + if progress: + iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False) + + for epoch in iterator: + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch( + model=model, + loader=train_loader, + optimizer=optimizer, + ce_loss=ce_loss, + device=device, + use_lyapunov=config.use_lyapunov, + lambda_reg=config.lambda_reg, + lambda_target=config.lambda_target, + lyap_eps=config.lyap_eps, + monitor=monitor, + ) + + val_loss, val_acc = evaluate(model, val_loader, ce_loss, device) + + dt = time.time() - t0 + + metrics = EpochMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + val_loss=val_loss, + val_acc=val_acc, + lyapunov=lyap, + grad_norm=grad_norm, + firing_rate=fr, + dead_neurons=dead, + time_sec=dt, + ) + metrics_history.append(metrics) + + # Early stopping if training diverged + if np.isnan(train_loss): + print(f" Training diverged at epoch {epoch}") + break + + return metrics_history + + +def run_depth_comparison( + depths: List[int], + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + device: torch.device, + epochs: int = 30, + hidden_dim: int = 128, + lr: float = 1e-3, + lambda_reg: float = 0.1, + lambda_target: float = 0.0, + lyap_eps: float = 1e-4, + beta: float = 0.9, + seed: int = 42, + progress: bool = True, +) -> Dict[str, Dict[int, List[EpochMetrics]]]: + """ + Run comparison experiments across depths. + + Returns: + Dictionary with structure: + { + "vanilla": {1: [metrics...], 2: [metrics...], ...}, + "lyapunov": {1: [metrics...], 2: [metrics...], ...} + } + """ + results = {"vanilla": {}, "lyapunov": {}} + + for depth in depths: + print(f"\n{'='*50}") + print(f"Depth = {depth} layers") + print(f"{'='*50}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + print(f"\n Training {method.upper()}...") + + config = ExperimentConfig( + depth=depth, + hidden_dim=hidden_dim, + use_lyapunov=use_lyap, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=lyap_eps, + epochs=epochs, + lr=lr, + batch_size=64, + beta=beta, + threshold=1.0, + seed=seed, + ) + + metrics = run_single_experiment( + config=config, + train_loader=train_loader, + val_loader=val_loader, + input_dim=input_dim, + num_classes=num_classes, + device=device, + progress=progress, + ) + + results[method][depth] = metrics + + # Print final metrics + final = metrics[-1] + lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A" + print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} " + f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}") + + return results + + +def save_results(results: Dict, output_dir: str, config: dict): + """Save experiment results to JSON.""" + os.makedirs(output_dir, exist_ok=True) + + # Convert metrics to dicts + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, metrics_list in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in metrics_list] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]): + """Print summary comparison table.""" + print("\n" + "=" * 70) + print("SUMMARY: Final Validation Accuracy by Depth") + print("=" * 70) + print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}") + print("-" * 70) + + depths = sorted(results["vanilla"].keys()) + for depth in depths: + van_metrics = results["vanilla"][depth] + lyap_metrics = results["lyapunov"][depth] + + van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0 + lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0 + + van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED" + lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED" + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" + + print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}") + + print("=" * 70) + + # Gradient analysis + print("\nGradient Norm Analysis (final epoch):") + print("-" * 70) + print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}") + print("-" * 70) + for depth in depths: + van_grad = results["vanilla"][depth][-1].grad_norm + lyap_grad = results["lyapunov"][depth][-1].grad_norm + print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths") + p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6], + help="Network depths to test") + p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer") + p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment") + p.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight") + p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent") + p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov") + p.add_argument("--beta", type=float, default=0.9, help="Membrane decay") + p.add_argument("--seed", type=int, default=42, help="Random seed") + p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing") + p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config") + p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--no-progress", action="store_true", help="Disable progress bars") + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 70) + print("Experiment: Vanilla vs Lyapunov-Regularized SNN") + print("=" * 70) + print(f"Depths: {args.depths}") + print(f"Hidden dim: {args.hidden_dim}") + print(f"Epochs: {args.epochs}") + print(f"Lambda_reg: {args.lambda_reg}") + print(f"Device: {device}") + + # Load data + if args.synthetic: + print("\nUsing SYNTHETIC data for quick testing") + train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed) + else: + print(f"\nLoading data from {args.cfg}") + from files.data_io.dataset_loader import get_dataloader + train_loader, val_loader = get_dataloader(args.cfg) + xb, _ = next(iter(train_loader)) + _, T, D = xb.shape + C = 20 # SHD has 20 classes + + print(f"Data: T={T}, D={D}, classes={C}") + + # Run experiments + results = run_depth_comparison( + depths=args.depths, + train_loader=train_loader, + val_loader=val_loader, + input_dim=D, + num_classes=C, + device=device, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + lyap_eps=args.lyap_eps, + beta=args.beta, + seed=args.seed, + progress=not args.no_progress, + ) + + # Print summary + print_summary(results) + + # Save results + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, ts) + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py new file mode 100644 index 0000000..efab140 --- /dev/null +++ b/files/experiments/depth_scaling_benchmark.py @@ -0,0 +1,1035 @@ +""" +Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs. + +Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve. + +Key hypothesis (from literature): +- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet) +- Deep SNNs without regularization fail to train (gradient issues) +- Deep SNNs WITH Lyapunov regularization achieve higher accuracy + +Reference results: +- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI) +- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS) +- Spikformer on ImageNet: ~74.8% top-1 (arXiv) + +Usage: + python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm.auto import tqdm + +import snntorch as snn +from snntorch import surrogate + + +# ============================================================================= +# VGG-style Spiking Network (scalable depth) +# ============================================================================= + +class SpikingVGGBlock(nn.Module): + """Conv-BN-LIF block for VGG-style architecture.""" + + def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None): + super().__init__() + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_ch) + self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + + def forward(self, x, mem): + h = self.bn(self.conv(x)) + spk, mem = self.lif(h, mem) + return spk, mem + + +class SpikingVGG(nn.Module): + """ + Scalable VGG-style Spiking Neural Network. + + Architecture follows VGG pattern: + - Multiple conv blocks between pooling layers + - Depth controlled by num_blocks_per_stage + + Args: + in_channels: Input channels (3 for RGB) + num_classes: Output classes + base_channels: Starting channel count (doubled each stage) + num_stages: Number of pooling stages (3-4 typical) + blocks_per_stage: Conv blocks per stage (controls depth) + T: Number of timesteps + beta: LIF membrane decay + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + base_channels: int = 64, + num_stages: int = 3, + blocks_per_stage: int = 2, + T: int = 4, + beta: float = 0.9, + threshold: float = 1.0, + dropout: float = 0.25, + stable_init: bool = False, + ): + super().__init__() + + self.T = T + self.num_stages = num_stages + self.blocks_per_stage = blocks_per_stage + self.total_conv_layers = num_stages * blocks_per_stage + self.stable_init = stable_init + + spike_grad = surrogate.fast_sigmoid(slope=25) + + # Build stages + self.stages = nn.ModuleList() + self.pools = nn.ModuleList() + + ch_in = in_channels + ch_out = base_channels + + for stage in range(num_stages): + stage_blocks = nn.ModuleList() + for b in range(blocks_per_stage): + block_in = ch_in if b == 0 else ch_out + stage_blocks.append( + SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad) + ) + self.stages.append(stage_blocks) + self.pools.append(nn.AvgPool2d(2)) + ch_in = ch_out + ch_out = min(ch_out * 2, 512) # Cap at 512 + + # Calculate spatial size after pooling + # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages) + final_spatial = 32 // (2 ** num_stages) + final_channels = min(base_channels * (2 ** (num_stages - 1)), 512) + fc_input = final_channels * final_spatial * final_spatial + + # Classifier + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(fc_input, num_classes) + + if stable_init: + self._init_weights_stable() + else: + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_weights_stable(self): + """ + Stability-aware initialization for SNNs. + + Uses smaller weight magnitudes to produce less chaotic initial dynamics. + The key insight: Lyapunov exponent depends on weight magnitudes. + Smaller weights → smaller gradients → more stable dynamics. + + Strategy: + - Use orthogonal init (preserves gradient magnitude across layers) + - Scale down by factor of 0.5 to reduce initial chaos + - This should produce λ closer to 0 from the start + """ + scale_factor = 0.5 # Reduce weight magnitudes + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # Orthogonal init for conv (reshape to 2D, init, reshape back) + weight_shape = m.weight.shape + fan_out = weight_shape[0] * weight_shape[2] * weight_shape[3] + fan_in = weight_shape[1] * weight_shape[2] * weight_shape[3] + + # Use smaller gain for stability + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + with torch.no_grad(): + m.weight.mul_(scale_factor) + + elif isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight, gain=scale_factor) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_mems(self, batch_size, device, dtype, P=1): + """Initialize membrane potentials for all LIF layers. + + Args: + batch_size: Batch size B + device: torch device + dtype: torch dtype + P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov) + + Returns: + List of membrane tensors with shape (P, B, C, H, W) + """ + mems = [] + H, W = 32, 32 + ch = 64 + + for stage in range(self.num_stages): + for _ in range(self.blocks_per_stage): + mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) + H, W = H // 2, W // 2 + ch = min(ch * 2, 512) + + return mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + """ + Forward pass with optimized Lyapunov computation (Approach A: trajectory batching). + + When compute_lyapunov=True, both original and perturbed trajectories are + processed together by batching them along a new dimension P=2. This avoids + redundant computation, especially for the first conv layer where inputs are identical. + + Args: + x: (B, C, H, W) static image (will be repeated for T steps) + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude + + Returns: + logits, lyap_est, recordings + """ + B = x.size(0) + device, dtype = x.device, x.dtype + + # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed) + P = 2 if compute_lyapunov else 1 + + # Initialize membrane potentials with shape (P, B, C, H, W) + mems = self._init_mems(B, device, dtype, P=P) + + # Initialize perturbed trajectory + if compute_lyapunov: + for i in range(len(mems)): + mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = None + + # Time loop - repeat static input + for t in range(self.T): + mem_idx = 0 + new_mems = [] + is_first_block = True + + # Process through stages + for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): + for block in stage_blocks: + if is_first_block: + # First block: input x is identical for both trajectories + # Compute conv+bn ONCE, then expand to (P, B, C, H, W) + h_conv = block.bn(block.conv(x)) # (B, C, H, W) + h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy + + # LIF with batched membrane states + # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) + + # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W) + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + is_first_block = False + else: + # Subsequent blocks: inputs differ between trajectories + # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + + # Full block forward (conv+bn+lif) + h_conv = block.bn(block.conv(h_flat)) + spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) + + # Reshape back + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + + mem_idx += 1 + + # Pool: apply to batched tensor + h_flat = h.reshape(P * B, *h.shape[2:]) + h_pooled = pool(h_flat) + h = h_pooled.view(P, B, *h_pooled.shape[1:]) + + mems = new_mems + + # Accumulate final spikes from ORIGINAL trajectory only (index 0) + h_orig = h[0].view(B, -1) # (B, C*H*W) + if spike_sum is None: + spike_sum = h_orig + else: + spike_sum = spike_sum + h_orig + + # Lyapunov divergence and renormalization (Option 1: global delta + global renorm) + # This is the textbook Benettin-style Lyapunov exponent estimator where + # the perturbation is treated as one vector in the concatenated state space. + if compute_lyapunov: + # Compute GLOBAL divergence across all layers + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W) + delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # GLOBAL renormalization: same scale factor for all layers + # This ensures ||perturbation||_global = eps after renorm + scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting + + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] + # Update perturbed trajectory: scale the diff to have global norm = eps + mems[i] = torch.stack([ + new_mems[i][0], + new_mems[i][0] + diff * scale + ], dim=0) + + # Readout + out = self.dropout(spike_sum) + logits = self.fc(out) + + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est, None + + @property + def depth(self): + return self.total_conv_layers + + +# ============================================================================= +# Dataset Loading +# ============================================================================= + +def get_dataset( + name: str, + data_dir: str = './data', + batch_size: int = 128, + num_workers: int = 4, +) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]: + """ + Get train/test loaders for various datasets. + + Returns: + train_loader, test_loader, num_classes, input_shape + """ + + if name == 'mnist': + transform = transforms.Compose([ + transforms.Resize(32), # Resize to 32x32 for consistency + transforms.ToTensor(), + ]) + train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'fashion_mnist': + transform = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + ]) + train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'cifar10': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test) + num_classes = 10 + input_shape = (3, 32, 32) + + elif name == 'cifar100': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test) + num_classes = 100 + input_shape = (3, 32, 32) + + else: + raise ValueError(f"Unknown dataset: {name}") + + train_loader = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True + ) + test_loader = DataLoader( + test_ds, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True + ) + + return train_loader, test_loader, num_classes, input_shape + + +# ============================================================================= +# Training +# ============================================================================= + +@dataclass +class TrainingMetrics: + epoch: int + train_loss: float + train_acc: float + test_loss: float + test_acc: float + lyapunov: Optional[float] + grad_norm: float + grad_max_sv: Optional[float] # Max singular value of gradients + grad_min_sv: Optional[float] # Min singular value of gradients + grad_condition: Optional[float] # Condition number + lr: float + time_sec: float + + +def compute_gradient_svs(model): + """Compute gradient singular value statistics for all weight matrices.""" + max_svs = [] + min_svs = [] + + for name, param in model.named_parameters(): + if param.grad is not None and param.ndim == 2: + with torch.no_grad(): + G = param.grad.detach() + try: + sv = torch.linalg.svdvals(G) + if len(sv) > 0: + max_svs.append(sv[0].item()) + min_svs.append(sv[-1].item()) + except Exception: + pass + + if not max_svs: + return None, None, None + + max_sv = max(max_svs) + min_sv = min(min_svs) + condition = max_sv / (min_sv + 1e-12) + + return max_sv, min_sv, condition + + +def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float, + lyap_threshold: float = 2.0) -> torch.Tensor: + """ + Compute Lyapunov regularization loss with different penalty types. + + Args: + lyap_est: Estimated Lyapunov exponent (scalar tensor) + reg_type: Type of regularization: + - "squared": (λ - target)² - original, aggressive + - "hinge": max(0, λ - threshold)² - only penalize chaos + - "asymmetric": strong penalty for chaos, weak for collapse + - "extreme": only penalize when λ > lyap_threshold (configurable) + - "adaptive_linear": penalty scales linearly with excess over threshold + - "adaptive_exp": penalty grows exponentially for severe chaos + - "adaptive_sigmoid": smooth sigmoid transition around threshold + lambda_target: Target value (used for squared reg_type) + lyap_threshold: Threshold for adaptive/extreme reg_types (default 2.0) + + Returns: + Regularization loss (scalar tensor) + """ + if reg_type == "squared": + # Original: penalize any deviation from target + return (lyap_est - lambda_target) ** 2 + + elif reg_type == "hinge": + # Only penalize when λ > threshold (too chaotic) + # threshold = 0 means: only penalize positive Lyapunov (chaos) + threshold = 0.0 + excess = torch.relu(lyap_est - threshold) + return excess ** 2 + + elif reg_type == "asymmetric": + # Strong penalty for chaos (λ > 0), weak penalty for collapse (λ < -1) + # This allows the network to be stable without being dead + chaos_penalty = torch.relu(lyap_est) ** 2 # Penalize λ > 0 + collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2 # Weakly penalize λ < -1 + return chaos_penalty + collapse_penalty + + elif reg_type == "extreme": + # Only penalize when λ > threshold (VERY chaotic) + # This allows moderate chaos while preventing extreme instability + # Threshold is now configurable via lyap_threshold argument + excess = torch.relu(lyap_est - lyap_threshold) + return excess ** 2 + + elif reg_type == "adaptive_linear": + # Penalty scales linearly with how far above threshold we are + # loss = excess * excess² = excess³ + # This naturally makes the penalty weaker for small excesses + # and much stronger for large excesses + excess = torch.relu(lyap_est - lyap_threshold) + return excess ** 3 # Cubic scaling: gentle near threshold, strong when chaotic + + elif reg_type == "adaptive_exp": + # Exponential penalty for severe chaos + # loss = (exp(excess) - 1) * excess² for excess > 0 + # This gives very weak penalty near threshold, explosive penalty for chaos + excess = torch.relu(lyap_est - lyap_threshold) + # Use exp(excess) - 1 to get 0 when excess=0, exponential growth after + exp_scale = torch.exp(excess) - 1.0 + return exp_scale * excess # exp(excess) * excess - excess + + elif reg_type == "adaptive_sigmoid": + # Smooth sigmoid transition around threshold + # The "sharpness" of transition is controlled by a temperature parameter + # weight(λ) = sigmoid((λ - threshold) / T) where T controls smoothness + # Using T=0.5 for moderately sharp transition + temperature = 0.5 + weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature) + # Penalize deviation from target, weighted by how far past threshold + deviation = lyap_est - lambda_target + return weight * (deviation ** 2) + + # ========================================================================= + # SCALED MULTIPLIER REGULARIZATION + # loss = (λ_reg × g(relu(λ))) × relu(λ) + # └─────────────────┘ └──────┘ + # scaled multiplier penalty toward target=0 + # + # The multiplier itself scales with λ, making it mild when λ is small + # and aggressive when λ is large. + # ========================================================================= + + elif reg_type == "mult_linear": + # Multiplier scales linearly: g(x) = x + # loss = (λ_reg × relu(λ)) × relu(λ) = λ_reg × relu(λ)² + # λ=0.5 → 0.25, λ=1.0 → 1.0, λ=2.0 → 4.0, λ=3.0 → 9.0 + pos_lyap = torch.relu(lyap_est) + return pos_lyap * pos_lyap # relu(λ)² + + elif reg_type == "mult_squared": + # Multiplier scales quadratically: g(x) = x² + # loss = (λ_reg × relu(λ)²) × relu(λ) = λ_reg × relu(λ)³ + # λ=0.5 → 0.125, λ=1.0 → 1.0, λ=2.0 → 8.0, λ=3.0 → 27.0 + pos_lyap = torch.relu(lyap_est) + return pos_lyap * pos_lyap * pos_lyap # relu(λ)³ + + elif reg_type == "mult_log": + # Multiplier scales logarithmically: g(x) = log(1+x) + # loss = (λ_reg × log(1+relu(λ))) × relu(λ) + # λ=0.5 → 0.20, λ=1.0 → 0.69, λ=2.0 → 2.20, λ=3.0 → 4.16 + pos_lyap = torch.relu(lyap_est) + return torch.log1p(pos_lyap) * pos_lyap # log(1+λ) × λ + + else: + raise ValueError(f"Unknown reg_type: {reg_type}") + + +def train_epoch( + model, loader, optimizer, criterion, device, + use_lyapunov, lambda_reg, lambda_target, lyap_eps, + progress=True, compute_sv_every=10, + reg_type="squared", current_lambda_reg=None, + lyap_threshold=2.0 +): + """ + Train one epoch. + + Args: + current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg. + reg_type: "squared", "hinge", "asymmetric", or "extreme" + lyap_threshold: Threshold for extreme reg_type + """ + model.train() + total_loss = 0.0 + correct = 0 + total = 0 + lyap_vals = [] + grad_norms = [] + grad_max_svs = [] + grad_min_svs = [] + grad_conditions = [] + + # Use warmup value if provided + effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for batch_idx, (x, y) in enumerate(iterator): + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps) + + loss = criterion(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold) + loss = loss + effective_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan'), None, None, None + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5 + grad_norms.append(grad_norm) + + # Compute gradient SVs periodically (expensive) + if batch_idx % compute_sv_every == 0: + max_sv, min_sv, cond = compute_gradient_svs(model) + if max_sv is not None: + grad_max_svs.append(max_sv) + grad_min_svs.append(min_sv) + grad_conditions.append(cond) + + optimizer.step() + + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + return ( + total_loss / total, + correct / total, + np.mean(lyap_vals) if lyap_vals else None, + np.mean(grad_norms), + np.mean(grad_max_svs) if grad_max_svs else None, + np.mean(grad_min_svs) if grad_min_svs else None, + np.mean(grad_conditions) if grad_conditions else None, + ) + + +@torch.no_grad() +def evaluate(model, loader, criterion, device, progress=True): + model.eval() + total_loss = 0.0 + correct = 0 + total = 0 + + iterator = tqdm(loader, desc="eval", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False) + + loss = criterion(logits, y) + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + return total_loss / total, correct / total + + +def run_single_config( + dataset_name: str, + depth_config: Tuple[int, int], # (num_stages, blocks_per_stage) + use_lyapunov: bool, + train_loader: DataLoader, + test_loader: DataLoader, + num_classes: int, + in_channels: int, + T: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool = True, + reg_type: str = "squared", + warmup_epochs: int = 0, + stable_init: bool = False, + lyap_threshold: float = 2.0, +) -> List[TrainingMetrics]: + """Run training for a single configuration.""" + torch.manual_seed(seed) + + num_stages, blocks_per_stage = depth_config + total_depth = num_stages * blocks_per_stage + + model = SpikingVGG( + in_channels=in_channels, + num_classes=num_classes, + base_channels=64, + num_stages=num_stages, + blocks_per_stage=blocks_per_stage, + T=T, + stable_init=stable_init, + ).to(device) + + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + method = "Lyapunov" if use_lyapunov else "Vanilla" + print(f" {method}: depth={total_depth}, params={num_params:,}") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + criterion = nn.CrossEntropyLoss() + + history = [] + best_acc = 0.0 + + for epoch in range(1, epochs + 1): + t0 = time.time() + + # Warmup: gradually increase lambda_reg + if warmup_epochs > 0 and epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / warmup_epochs) + else: + current_lambda_reg = lambda_reg + + train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( + model, train_loader, optimizer, criterion, device, + use_lyapunov, lambda_reg, lambda_target, 1e-4, progress, + reg_type=reg_type, current_lambda_reg=current_lambda_reg, + lyap_threshold=lyap_threshold + ) + + test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_acc = max(best_acc, test_acc) + + metrics = TrainingMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + test_loss=test_loss, + test_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + grad_max_sv=grad_max_sv, + grad_min_sv=grad_min_sv, + grad_condition=grad_cond, + lr=scheduler.get_last_lr()[0], + time_sec=dt, + ) + history.append(metrics) + + if epoch % 10 == 0 or epoch == epochs: + lyap_str = f"λ={lyap:.3f}" if lyap else "" + sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv else "" + print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}") + + if np.isnan(train_loss): + print(f" DIVERGED at epoch {epoch}") + break + + print(f" Best test acc: {best_acc:.3f}") + return history + + +def run_depth_scaling_experiment( + dataset_name: str, + depth_configs: List[Tuple[int, int]], + train_loader: DataLoader, + test_loader: DataLoader, + num_classes: int, + in_channels: int, + T: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool, + reg_type: str = "squared", + warmup_epochs: int = 0, + stable_init: bool = False, + lyap_threshold: float = 2.0, +) -> Dict: + """Run full depth scaling experiment.""" + + results = {"vanilla": {}, "lyapunov": {}} + + print(f"Regularization type: {reg_type}") + print(f"Warmup epochs: {warmup_epochs}") + print(f"Stable init: {stable_init}") + print(f"Lyapunov threshold: {lyap_threshold}") + + for depth_config in depth_configs: + num_stages, blocks_per_stage = depth_config + total_depth = num_stages * blocks_per_stage + + print(f"\n{'='*60}") + print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)") + print(f"{'='*60}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + + history = run_single_config( + dataset_name=dataset_name, + depth_config=depth_config, + use_lyapunov=use_lyap, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=T, + epochs=epochs, + lr=lr, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + device=device, + seed=seed, + progress=progress, + reg_type=reg_type, + warmup_epochs=warmup_epochs, + stable_init=stable_init, + lyap_threshold=lyap_threshold, + ) + + results[method][total_depth] = history + + return results + + +def print_summary(results: Dict, dataset_name: str): + """Print final summary table.""" + print("\n" + "=" * 100) + print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}") + print("=" * 100) + print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}") + print("-" * 100) + + depths = sorted(results["vanilla"].keys()) + + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_acc = van.test_acc if not np.isnan(van.train_loss) else 0.0 + lyap_acc = lyap.test_acc if not np.isnan(lyap.train_loss) else 0.0 + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}" + + van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" + lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" + lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov else "N/A" + + van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A" + lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm else "N/A" + van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A" + + print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}") + + print("=" * 100) + + # Gradient health analysis + print("\nGRADIENT HEALTH ANALYSIS:") + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_cond = van.grad_condition if van.grad_condition else 0 + lyap_cond = lyap.grad_condition if lyap.grad_condition else 0 + + status = "" + if van_cond > 1e6: + status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)" + elif van_cond > 1e4: + status = "~ Vanilla has moderately ill-conditioned gradients" + + if status: + print(f" Depth {depth}: {status}") + + print("") + + # Analysis + print("\nKEY OBSERVATIONS:") + shallow = min(depths) + deep = max(depths) + + van_shallow = results["vanilla"][shallow][-1].test_acc + van_deep = results["vanilla"][deep][-1].test_acc + lyap_shallow = results["lyapunov"][shallow][-1].test_acc + lyap_deep = results["lyapunov"][deep][-1].test_acc + + van_gain = van_deep - van_shallow + lyap_gain = lyap_deep - lyap_shallow + + print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change") + print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change") + + if lyap_gain > van_gain + 0.02: + print(f" ✓ Lyapunov regularization enables better depth scaling!") + elif lyap_gain > van_gain: + print(f" ~ Lyapunov shows slight improvement in depth scaling") + else: + print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range") + + +def save_results(results: Dict, output_dir: str, config: Dict): + os.makedirs(output_dir, exist_ok=True) + + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, history in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in history] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs") + + p.add_argument("--dataset", type=str, default="cifar100", + choices=["mnist", "fashion_mnist", "cifar10", "cifar100"]) + p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16], + help="Total conv layer depths to test") + p.add_argument("--T", type=int, default=4, help="Timesteps") + p.add_argument("--epochs", type=int, default=100) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--lambda_reg", type=float, default=0.3) + p.add_argument("--lambda_target", type=float, default=-0.1) + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--out_dir", type=str, default="runs/depth_scaling") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--no-progress", action="store_true") + p.add_argument("--reg_type", type=str, default="squared", + choices=["squared", "hinge", "asymmetric", "extreme"], + help="Lyapunov regularization type") + p.add_argument("--warmup_epochs", type=int, default=0, + help="Epochs to warmup lambda_reg (0 = no warmup)") + p.add_argument("--stable_init", action="store_true", + help="Use stability-aware weight initialization") + p.add_argument("--lyap_threshold", type=float, default=2.0, + help="Threshold for extreme reg_type (only penalize λ > threshold)") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 80) + print("DEPTH SCALING BENCHMARK") + print("=" * 80) + print(f"Dataset: {args.dataset}") + print(f"Depths: {args.depths}") + print(f"Timesteps: {args.T}") + print(f"Epochs: {args.epochs}") + print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}") + print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}") + print(f"Device: {device}") + print("=" * 80) + + # Load data + print(f"\nLoading {args.dataset}...") + train_loader, test_loader, num_classes, input_shape = get_dataset( + args.dataset, args.data_dir, args.batch_size + ) + in_channels = input_shape[0] + print(f"Classes: {num_classes}, Input: {input_shape}") + print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") + + # Convert depths to (num_stages, blocks_per_stage) configs + # We use 4 stages (3 for smaller nets), adjust blocks_per_stage + depth_configs = [] + for d in args.depths: + if d <= 4: + depth_configs.append((d, 1)) # d stages, 1 block each + elif d <= 8: + depth_configs.append((4, d // 4)) # 4 stages + else: + depth_configs.append((4, d // 4)) # 4 stages, more blocks + + print(f"\nDepth configurations: {[(d, f'{s}×{b}') for d, (s, b) in zip(args.depths, depth_configs)]}") + + # Run experiment + results = run_depth_scaling_experiment( + dataset_name=args.dataset, + depth_configs=depth_configs, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=args.T, + epochs=args.epochs, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + progress=not args.no_progress, + reg_type=args.reg_type, + warmup_epochs=args.warmup_epochs, + stable_init=args.stable_init, + lyap_threshold=args.lyap_threshold, + ) + + # Summary + print_summary(results, args.dataset) + + # Save + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}") + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py new file mode 100644 index 0000000..011387f --- /dev/null +++ b/files/experiments/hyperparameter_grid_search.py @@ -0,0 +1,597 @@ +""" +Hyperparameter Grid Search for Lyapunov-Regularized SNNs. + +Goal: Find optimal (lambda_reg, lambda_target) for each network depth + and derive an adaptive curve for automatic hyperparameter selection. + +Usage: + python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Tuple +from itertools import product + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN + + +@dataclass +class GridSearchResult: + """Result from a single grid search configuration.""" + depth: int + lambda_reg: float + lambda_target: float + final_train_acc: float + final_val_acc: float + final_lyapunov: float + final_grad_norm: float + converged: bool # Did training succeed (not NaN)? + epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never) + + +def create_synthetic_data( + n_train: int = 2000, + n_val: int = 500, + T: int = 50, + D: int = 100, + n_classes: int = 10, + seed: int = 42, + batch_size: int = 128, +) -> Tuple[DataLoader, DataLoader, int, int, int]: + """Create synthetic spike data.""" + torch.manual_seed(seed) + np.random.seed(seed) + + def generate_data(n_samples): + x = torch.zeros(n_samples, T, D) + y = torch.randint(0, n_classes, (n_samples,)) + for i in range(n_samples): + label = y[i].item() + base_rate = 0.05 + 0.02 * label + class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) + for t in range(T): + x[i, t] = (torch.rand(D) < base_rate).float() + for c in class_channels: + if torch.rand(1) < base_rate * 3: + x[i, t, c] = 1.0 + return x, y + + x_train, y_train = generate_data(n_train) + x_val, y_val = generate_data(n_val) + + train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) + val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False) + + return train_loader, val_loader, T, D, n_classes + + +def train_and_evaluate( + depth: int, + lambda_reg: float, + lambda_target: float, + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + device: torch.device, + seed: int = 42, + warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early +) -> GridSearchResult: + """Train a single configuration and return results.""" + torch.manual_seed(seed) + + # Create model + hidden_dims = [hidden_dim] * depth + model = LyapunovSNN( + input_dim=input_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + beta=0.9, + threshold=1.0, + ).to(device) + + optimizer = optim.Adam(model.parameters(), lr=lr) + ce_loss = nn.CrossEntropyLoss() + + best_val_acc = 0.0 + epochs_to_90 = -1 + final_lyap = 0.0 + final_grad = 0.0 + converged = True + + for epoch in range(1, epochs + 1): + # Warmup: gradually increase lambda_reg + if epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / warmup_epochs) + else: + current_lambda_reg = lambda_reg + + # Training + model.train() + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + + for x, y in train_loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False) + + ce = ce_loss(logits, y) + if lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + current_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + converged = False + break + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + optimizer.step() + + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + if not converged: + break + + train_acc = total_correct / total_samples + final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0 + final_grad = np.mean(grad_norms) if grad_norms else 0.0 + + # Track epochs to 90% accuracy + if epochs_to_90 < 0 and train_acc >= 0.9: + epochs_to_90 = epoch + + # Validation + model.eval() + val_correct = 0 + val_total = 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + preds = logits.argmax(dim=1) + val_correct += (preds == y).sum().item() + val_total += x.size(0) + val_acc = val_correct / val_total + best_val_acc = max(best_val_acc, val_acc) + + return GridSearchResult( + depth=depth, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + final_train_acc=train_acc if converged else 0.0, + final_val_acc=best_val_acc if converged else 0.0, + final_lyapunov=final_lyap, + final_grad_norm=final_grad, + converged=converged, + epochs_to_90pct=epochs_to_90, + ) + + +def run_grid_search( + depths: List[int], + lambda_regs: List[float], + lambda_targets: List[float], + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + device: torch.device, + seed: int = 42, + progress: bool = True, +) -> List[GridSearchResult]: + """Run full grid search.""" + results = [] + + # Total configurations + configs = list(product(depths, lambda_regs, lambda_targets)) + total = len(configs) + + iterator = tqdm(configs, desc="Grid Search", disable=not progress) + + for depth, lambda_reg, lambda_target in iterator: + if progress: + iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target}) + + result = train_and_evaluate( + depth=depth, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + train_loader=train_loader, + val_loader=val_loader, + input_dim=input_dim, + num_classes=num_classes, + hidden_dim=hidden_dim, + epochs=epochs, + lr=lr, + device=device, + seed=seed, + ) + results.append(result) + + if progress: + iterator.set_postfix({ + "d": depth, + "λr": lambda_reg, + "λt": lambda_target, + "acc": f"{result.final_val_acc:.2f}" + }) + + return results + + +def analyze_results(results: List[GridSearchResult]) -> Dict: + """Analyze grid search results and find optimal hyperparameters per depth.""" + + # Group by depth + by_depth = {} + for r in results: + if r.depth not in by_depth: + by_depth[r.depth] = [] + by_depth[r.depth].append(r) + + analysis = { + "optimal_per_depth": {}, + "all_results": [asdict(r) for r in results], + } + + print("\n" + "=" * 80) + print("GRID SEARCH ANALYSIS") + print("=" * 80) + + # Find optimal for each depth + print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}") + print("-" * 80) + + optimal_lambda_regs = [] + optimal_lambda_targets = [] + depths_list = [] + + for depth in sorted(by_depth.keys()): + depth_results = by_depth[depth] + # Find best by validation accuracy + best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0) + + analysis["optimal_per_depth"][depth] = { + "lambda_reg": best.lambda_reg, + "lambda_target": best.lambda_target, + "val_acc": best.final_val_acc, + "lyapunov": best.final_lyapunov, + "epochs_to_90": best.epochs_to_90pct, + } + + print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} " + f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}") + + if best.final_val_acc > 0.5: # Only use successful runs for curve fitting + depths_list.append(depth) + optimal_lambda_regs.append(best.lambda_reg) + optimal_lambda_targets.append(best.lambda_target) + + # Fit adaptive curves + print("\n" + "=" * 80) + print("ADAPTIVE HYPERPARAMETER CURVES") + print("=" * 80) + + if len(depths_list) >= 3: + # Fit polynomial curves + depths_arr = np.array(depths_list) + lambda_regs_arr = np.array(optimal_lambda_regs) + lambda_targets_arr = np.array(optimal_lambda_targets) + + # Fit lambda_reg vs depth (expect increasing with depth) + try: + reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=min(2, len(depths_arr) - 1)) + reg_poly = np.poly1d(reg_coeffs) + print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}" + if len(reg_coeffs) == 3 else f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}") + analysis["lambda_reg_curve"] = reg_coeffs.tolist() + except Exception as e: + print(f"Could not fit λ_reg curve: {e}") + + # Fit lambda_target vs depth (expect decreasing / more negative with depth) + try: + target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=min(2, len(depths_arr) - 1)) + target_poly = np.poly1d(target_coeffs) + print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}" + if len(target_coeffs) == 3 else f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}") + analysis["lambda_target_curve"] = target_coeffs.tolist() + except Exception as e: + print(f"Could not fit λ_target curve: {e}") + + # Print recommendations + print("\n" + "-" * 80) + print("RECOMMENDED HYPERPARAMETERS BY DEPTH:") + print("-" * 80) + for d in [2, 4, 6, 8, 10, 12, 14, 16]: + rec_reg = max(0.01, reg_poly(d)) + rec_target = min(0.0, target_poly(d)) + print(f" Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}") + + else: + print("Not enough successful runs to fit curves") + + return analysis + + +def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict): + """Save grid search results.""" + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f: + json.dump(analysis, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def plot_grid_search(results: List[GridSearchResult], output_dir: str): + """Generate visualization of grid search results.""" + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not available, skipping plots") + return + + # Group by depth + by_depth = {} + for r in results: + if r.depth not in by_depth: + by_depth[r.depth] = [] + by_depth[r.depth].append(r) + + depths = sorted(by_depth.keys()) + + # Get unique lambda values + lambda_regs = sorted(set(r.lambda_reg for r in results)) + lambda_targets = sorted(set(r.lambda_target for r in results)) + + # Create heatmaps for each depth + n_depths = len(depths) + fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10)) + axes = axes.flatten() + + for idx, depth in enumerate(depths): + ax = axes[idx] + depth_results = by_depth[depth] + + # Create accuracy matrix + acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs))) + for r in depth_results: + i = lambda_targets.index(r.lambda_target) + j = lambda_regs.index(r.lambda_reg) + acc_matrix[i, j] = r.final_val_acc + + im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto') + ax.set_xticks(range(len(lambda_regs))) + ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45) + ax.set_yticks(range(len(lambda_targets))) + ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets]) + ax.set_xlabel("λ_reg") + ax.set_ylabel("λ_target") + ax.set_title(f"Depth {depth}") + + # Mark best + best = max(depth_results, key=lambda r: r.final_val_acc) + bi = lambda_targets.index(best.lambda_target) + bj = lambda_regs.index(best.lambda_reg) + ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2) + + # Add colorbar + plt.colorbar(im, ax=ax, label='Val Acc') + + # Hide unused subplots + for idx in range(len(depths), len(axes)): + axes[idx].axis('off') + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight') + plt.close() + + # Plot optimal hyperparameters vs depth + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + optimal_regs = [] + optimal_targets = [] + optimal_accs = [] + for depth in depths: + best = max(by_depth[depth], key=lambda r: r.final_val_acc) + optimal_regs.append(best.lambda_reg) + optimal_targets.append(best.lambda_target) + optimal_accs.append(best.final_val_acc) + + axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8) + axes[0].set_xlabel("Network Depth") + axes[0].set_ylabel("Optimal λ_reg") + axes[0].set_title("Optimal Regularization Strength vs Depth") + axes[0].grid(True, alpha=0.3) + + axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange') + axes[1].set_xlabel("Network Depth") + axes[1].set_ylabel("Optimal λ_target") + axes[1].set_title("Optimal Target Lyapunov vs Depth") + axes[1].grid(True, alpha=0.3) + + axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green') + axes[2].set_xlabel("Network Depth") + axes[2].set_ylabel("Best Validation Accuracy") + axes[2].set_title("Best Achievable Accuracy vs Depth") + axes[2].set_ylim(0, 1.05) + axes[2].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight') + plt.close() + + print(f"Plots saved to {output_dir}") + + +def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'): + """Get CIFAR-10 dataloaders with rate encoding for SNN.""" + from torchvision import datasets, transforms + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform) + val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform) + + # Rate encoding: convert images to spike sequences + class RateEncodedDataset(torch.utils.data.Dataset): + def __init__(self, dataset, T): + self.dataset = dataset + self.T = T + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + img, label = self.dataset[idx] + # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D) + flat = img.view(-1) # (3072,) + # Rate encoding: probability of spike = pixel intensity + spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float() + return spikes, label + + train_encoded = RateEncodedDataset(train_ds, T) + val_encoded = RateEncodedDataset(val_ds, T) + + train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_loader, val_loader, T, 3072, 10 # T, D, num_classes + + +def parse_args(): + p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN") + + # Grid search parameters + p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10], + help="Network depths to test") + p.add_argument("--lambda_regs", type=float, nargs="+", + default=[0.01, 0.05, 0.1, 0.2, 0.3], + help="Lambda_reg values to test") + p.add_argument("--lambda_targets", type=float, nargs="+", + default=[0.0, -0.05, -0.1, -0.2], + help="Lambda_target values to test") + + # Model parameters + p.add_argument("--hidden_dim", type=int, default=256) + p.add_argument("--epochs", type=int, default=15) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--seed", type=int, default=42) + + # Data + p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)") + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding") + + # Output + p.add_argument("--out_dir", type=str, default="runs/grid_search") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--no-progress", action="store_true") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 80) + print("HYPERPARAMETER GRID SEARCH") + print("=" * 80) + print(f"Depths: {args.depths}") + print(f"λ_reg values: {args.lambda_regs}") + print(f"λ_target values: {args.lambda_targets}") + print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}") + print(f"Epochs per config: {args.epochs}") + print(f"Device: {device}") + print("=" * 80) + + # Load data + if args.synthetic: + print("\nUsing synthetic data") + train_loader, val_loader, T, D, C = create_synthetic_data( + seed=args.seed, batch_size=args.batch_size + ) + else: + print("\nUsing CIFAR-10 with rate encoding") + train_loader, val_loader, T, D, C = get_cifar10_loaders( + batch_size=args.batch_size, + T=args.T, + data_dir=args.data_dir + ) + + print(f"Data: T={T}, D={D}, classes={C}\n") + + # Run grid search + results = run_grid_search( + depths=args.depths, + lambda_regs=args.lambda_regs, + lambda_targets=args.lambda_targets, + train_loader=train_loader, + val_loader=val_loader, + input_dim=D, + num_classes=C, + hidden_dim=args.hidden_dim, + epochs=args.epochs, + lr=args.lr, + device=device, + seed=args.seed, + progress=not args.no_progress, + ) + + # Analyze results + analysis = analyze_results(results) + + # Save results + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, ts) + save_results(results, analysis, output_dir, vars(args)) + + # Generate plots + plot_grid_search(results, output_dir) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py new file mode 100644 index 0000000..05dbcd2 --- /dev/null +++ b/files/experiments/lyapunov_diffonly_benchmark.py @@ -0,0 +1,590 @@ +""" +Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation. + +Optimization B: Instead of storing two full membrane trajectories: + mems[i][0] = base trajectory + mems[i][1] = perturbed trajectory + +Store only: + base_mems[i] = base trajectory + delta_mems[i] = perturbation (perturbed - base) + +Benefits: + - ~2x less memory for membrane states + - Fewer memory reads/writes during renormalization + - Better cache utilization +""" + +import os +import sys +import time +import torch +import torch.nn as nn +from typing import Tuple, Optional, List + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import snntorch as snn +from snntorch import surrogate + + +class SpikingVGGBlock(nn.Module): + """Conv-BN-LIF block.""" + + def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None): + super().__init__() + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_ch) + self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + + def forward(self, x, mem): + h = self.bn(self.conv(x)) + spk, mem = self.lif(h, mem) + return spk, mem + + +class SpikingVGG_Original(nn.Module): + """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W).""" + + def __init__(self, in_channels=3, num_classes=100, base_channels=64, + num_stages=3, blocks_per_stage=2, T=4, beta=0.9): + super().__init__() + self.T = T + self.num_stages = num_stages + self.blocks_per_stage = blocks_per_stage + + # Build stages + self.stages = nn.ModuleList() + self.pools = nn.ModuleList() + + in_ch = in_channels + out_ch = base_channels + current_size = 32 # CIFAR + + for stage in range(num_stages): + stage_blocks = nn.ModuleList() + for _ in range(blocks_per_stage): + stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta)) + in_ch = out_ch + self.stages.append(stage_blocks) + self.pools.append(nn.AvgPool2d(2)) + current_size //= 2 + if stage < num_stages - 1: + out_ch = min(out_ch * 2, 512) + + self.fc = nn.Linear(in_ch * current_size * current_size, num_classes) + self._channel_sizes = self._compute_channel_sizes(base_channels) + + def _compute_channel_sizes(self, base): + sizes = [] + ch = base + for stage in range(self.num_stages): + for _ in range(self.blocks_per_stage): + sizes.append(ch) + if stage < self.num_stages - 1: + ch = min(ch * 2, 512) + return sizes + + def _init_mems(self, batch_size, device, dtype, P=1): + mems = [] + H, W = 32, 32 + for stage in range(self.num_stages): + for block_idx in range(self.blocks_per_stage): + layer_idx = stage * self.blocks_per_stage + block_idx + ch = self._channel_sizes[layer_idx] + mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) + H, W = H // 2, W // 2 + return mems + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + P = 2 if compute_lyapunov else 1 + + mems = self._init_mems(B, device, dtype, P=P) + + if compute_lyapunov: + for i in range(len(mems)): + mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = None + + for t in range(self.T): + mem_idx = 0 + new_mems = [] + is_first_block = True + + for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): + for block in stage_blocks: + if is_first_block: + h_conv = block.bn(block.conv(x)) + h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + h = spk + new_mems.append(mem_new) + is_first_block = False + else: + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + h_conv = block.bn(block.conv(h_flat)) + spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + h = spk + new_mems.append(mem_new) + mem_idx += 1 + + h_flat = h.reshape(P * B, *h.shape[2:]) + h_pooled = pool(h_flat) + h = h_pooled.view(P, B, *h_pooled.shape[1:]) + + mems = new_mems + + h_orig = h[0].view(B, -1) + if spike_sum is None: + spike_sum = h_orig + else: + spike_sum = spike_sum + h_orig + + if compute_lyapunov: + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] + delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + scale = (lyap_eps / delta).view(B, 1, 1, 1) + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] + mems[i] = torch.stack([ + new_mems[i][0], + new_mems[i][0] + diff * scale + ], dim=0) + + logits = self.fc(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +class SpikingVGG_DiffOnly(nn.Module): + """ + Optimized implementation: stores base + diff instead of 2 full trajectories. + + Memory layout: + base_mems[i]: (B, C, H, W) - base trajectory membrane + delta_mems[i]: (B, C, H, W) - perturbation vector + + Perturbed trajectory is materialized as (base + delta) only when needed. + """ + + def __init__(self, in_channels=3, num_classes=100, base_channels=64, + num_stages=3, blocks_per_stage=2, T=4, beta=0.9): + super().__init__() + self.T = T + self.num_stages = num_stages + self.blocks_per_stage = blocks_per_stage + + self.stages = nn.ModuleList() + self.pools = nn.ModuleList() + + in_ch = in_channels + out_ch = base_channels + current_size = 32 + + for stage in range(num_stages): + stage_blocks = nn.ModuleList() + for _ in range(blocks_per_stage): + stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta)) + in_ch = out_ch + self.stages.append(stage_blocks) + self.pools.append(nn.AvgPool2d(2)) + current_size //= 2 + if stage < num_stages - 1: + out_ch = min(out_ch * 2, 512) + + self.fc = nn.Linear(in_ch * current_size * current_size, num_classes) + self._channel_sizes = self._compute_channel_sizes(base_channels) + + def _compute_channel_sizes(self, base): + sizes = [] + ch = base + for stage in range(self.num_stages): + for _ in range(self.blocks_per_stage): + sizes.append(ch) + if stage < self.num_stages - 1: + ch = min(ch * 2, 512) + return sizes + + def _init_mems(self, batch_size, device, dtype): + """Initialize base membrane states (B, C, H, W).""" + base_mems = [] + H, W = 32, 32 + for stage in range(self.num_stages): + for block_idx in range(self.blocks_per_stage): + layer_idx = stage * self.blocks_per_stage + block_idx + ch = self._channel_sizes[layer_idx] + base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype)) + H, W = H // 2, W // 2 + return base_mems + + def _init_deltas(self, base_mems, lyap_eps): + """Initialize perturbation vectors δ with ||δ||_global = eps.""" + delta_mems = [] + for base in base_mems: + delta_mems.append(lyap_eps * torch.randn_like(base)) + return delta_mems + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + + # Initialize base membrane states + base_mems = self._init_mems(B, device, dtype) + + # Initialize perturbations if computing Lyapunov + if compute_lyapunov: + delta_mems = self._init_deltas(base_mems, lyap_eps) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + else: + delta_mems = None + + spike_sum = None + + for t in range(self.T): + mem_idx = 0 + new_base_mems = [] + new_delta_mems = [] if compute_lyapunov else None + + # Track spikes for base and perturbed (if computing Lyapunov) + h_base = None + h_delta = None # Will store (h_perturbed - h_base) + is_first_block = True + + for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): + for block in stage_blocks: + if is_first_block: + # First block: input x is same for both trajectories + h_conv = block.bn(block.conv(x)) # (B, C, H, W) + + # Base trajectory + spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx]) + new_base_mems.append(mem_base_new) + h_base = spk_base + + if compute_lyapunov: + # Perturbed trajectory: mem = base + delta + mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx] + spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed) + # Store delta for new membrane + new_delta_mems.append(mem_perturbed_new - mem_base_new) + # Store spike difference for propagation + h_delta = spk_perturbed - spk_base + + is_first_block = False + else: + # Subsequent blocks: inputs differ + # Base trajectory + h_conv_base = block.bn(block.conv(h_base)) + spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx]) + new_base_mems.append(mem_base_new) + + if compute_lyapunov: + # Perturbed trajectory: h_perturbed = h_base + h_delta + h_perturbed = h_base + h_delta + h_conv_perturbed = block.bn(block.conv(h_perturbed)) + mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx] + spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed) + new_delta_mems.append(mem_perturbed_new - mem_base_new) + h_delta = spk_perturbed - spk_base + + h_base = spk_base + + mem_idx += 1 + + # Pooling + h_base = pool(h_base) + if compute_lyapunov: + # Pool both and compute new delta + h_perturbed = h_base + pool(h_delta) # Note: pool(base+delta) ≠ pool(base) + pool(delta) in general + # But for AvgPool, it's linear so this is fine + h_delta = h_perturbed - h_base # This simplifies to pool(h_delta) for AvgPool + h_delta = pool(h_delta) # Actually just pool the delta directly (AvgPool is linear) + + # Update membrane states + base_mems = new_base_mems + + # Accumulate spikes from base trajectory + h_flat = h_base.view(B, -1) + if spike_sum is None: + spike_sum = h_flat + else: + spike_sum = spike_sum + h_flat + + # Lyapunov: compute global divergence and renormalize + if compute_lyapunov: + # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||² + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for delta in new_delta_mems: + delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3)) + + delta_norm = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12) + + # Renormalize: scale all deltas so ||δ||_global = eps + scale = (lyap_eps / delta_norm).view(B, 1, 1, 1) + delta_mems = [delta * scale for delta in new_delta_mems] + + logits = self.fc(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + +def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20): + """Benchmark forward pass time.""" + device = x.device + + # Warmup + for _ in range(num_warmup): + with torch.no_grad(): + _ = model(x, compute_lyapunov=compute_lyapunov) + + torch.cuda.synchronize() + + # Timed runs + times = [] + for _ in range(num_runs): + torch.cuda.synchronize() + start = time.perf_counter() + + logits, lyap = model(x, compute_lyapunov=compute_lyapunov) + + torch.cuda.synchronize() + end = time.perf_counter() + times.append(end - start) + + return times, lyap + + +def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov, + lambda_reg=0.3, num_warmup=5, num_runs=20): + """Benchmark forward + backward pass time.""" + device = x.device + + # Warmup + for _ in range(num_warmup): + model.zero_grad() + logits, lyap = model(x, compute_lyapunov=compute_lyapunov) + loss = criterion(logits, y) + if compute_lyapunov and lyap is not None: + loss = loss + lambda_reg * (lyap ** 2) + loss.backward() + + torch.cuda.synchronize() + + # Timed runs + times = [] + for _ in range(num_runs): + model.zero_grad() + torch.cuda.synchronize() + start = time.perf_counter() + + logits, lyap = model(x, compute_lyapunov=compute_lyapunov) + loss = criterion(logits, y) + if compute_lyapunov and lyap is not None: + loss = loss + lambda_reg * (lyap ** 2) + loss.backward() + + torch.cuda.synchronize() + end = time.perf_counter() + times.append(end - start) + + return times + + +def measure_memory(model, x, compute_lyapunov): + """Measure peak GPU memory during forward pass.""" + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + with torch.no_grad(): + _ = model(x, compute_lyapunov=compute_lyapunov) + + torch.cuda.synchronize() + peak_mem = torch.cuda.max_memory_allocated() / 1024**2 # MB + return peak_mem + + +def run_benchmark(): + print("=" * 70) + print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage") + print("=" * 70) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + if device.type == "cuda": + print(f"GPU: {torch.cuda.get_device_name()}") + + # Test configurations + configs = [ + {"depth": 4, "blocks_per_stage": 1, "batch_size": 64}, + {"depth": 8, "blocks_per_stage": 2, "batch_size": 64}, + {"depth": 12, "blocks_per_stage": 3, "batch_size": 32}, + ] + + print("\n" + "=" * 70) + + for cfg in configs: + depth = cfg["depth"] + blocks = cfg["blocks_per_stage"] + batch_size = cfg["batch_size"] + + print(f"\n{'='*70}") + print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}") + print(f"{'='*70}") + + # Create models + model_orig = SpikingVGG_Original( + blocks_per_stage=blocks, T=4 + ).to(device) + + model_diff = SpikingVGG_DiffOnly( + blocks_per_stage=blocks, T=4 + ).to(device) + + # Copy weights from original to diff-only + model_diff.load_state_dict(model_orig.state_dict()) + + print(f"Parameters: {count_parameters(model_orig):,}") + + # Create input + x = torch.randn(batch_size, 3, 32, 32, device=device) + y = torch.randint(0, 100, (batch_size,), device=device) + criterion = nn.CrossEntropyLoss() + + # ============================================================ + # Test 1: Verify outputs match + # ============================================================ + print("\n--- Output Verification ---") + model_orig.eval() + model_diff.eval() + + torch.manual_seed(42) + with torch.no_grad(): + logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4) + + torch.manual_seed(42) + with torch.no_grad(): + logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4) + + logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5) + lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1 # Allow some difference due to different implementations + + print(f"Logits match: {logits_match}") + print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}") + print(f"Lyapunov close (within 0.1): {lyap_close}") + + # ============================================================ + # Test 2: Forward-only speed (no grad) + # ============================================================ + print("\n--- Forward Speed (no_grad) ---") + model_orig.eval() + model_diff.eval() + + # Without Lyapunov + times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False) + times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False) + + mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000 + mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000 + + print(f" Without Lyapunov:") + print(f" Original: {mean_orig:.2f} ms") + print(f" Diff-only: {mean_diff:.2f} ms") + + # With Lyapunov + times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True) + times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True) + + mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000 + mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000 + speedup = mean_orig_ly / mean_diff_ly + + print(f" With Lyapunov:") + print(f" Original: {mean_orig_ly:.2f} ms") + print(f" Diff-only: {mean_diff_ly:.2f} ms") + print(f" Speedup: {speedup:.2f}x") + + # ============================================================ + # Test 3: Forward + Backward speed (training mode) + # ============================================================ + print("\n--- Forward+Backward Speed (training) ---") + model_orig.train() + model_diff.train() + + times_orig_train = benchmark_forward_backward( + model_orig, x, y, criterion, compute_lyapunov=True + ) + times_diff_train = benchmark_forward_backward( + model_diff, x, y, criterion, compute_lyapunov=True + ) + + mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000 + mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000 + speedup_train = mean_orig_train / mean_diff_train + + print(f" With Lyapunov + backward:") + print(f" Original: {mean_orig_train:.2f} ms") + print(f" Diff-only: {mean_diff_train:.2f} ms") + print(f" Speedup: {speedup_train:.2f}x") + + # ============================================================ + # Test 4: Memory usage + # ============================================================ + if device.type == "cuda": + print("\n--- Peak GPU Memory ---") + + mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False) + mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False) + + mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True) + mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True) + + print(f" Without Lyapunov:") + print(f" Original: {mem_orig_noly:.1f} MB") + print(f" Diff-only: {mem_diff_noly:.1f} MB") + print(f" With Lyapunov:") + print(f" Original: {mem_orig_ly:.1f} MB") + print(f" Diff-only: {mem_diff_ly:.1f} MB") + print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)") + + # Cleanup + del model_orig, model_diff, x, y + torch.cuda.empty_cache() + + print("\n" + "=" * 70) + print("BENCHMARK COMPLETE") + print("=" * 70) + + +if __name__ == "__main__": + run_benchmark() diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py new file mode 100644 index 0000000..117009b --- /dev/null +++ b/files/experiments/lyapunov_speedup_benchmark.py @@ -0,0 +1,638 @@ +""" +Lyapunov Computation Speedup Benchmark + +Tests different optimization approaches for computing Lyapunov exponents +during SNN training. All approaches should produce equivalent results +(within numerical precision) but with different performance characteristics. + +Approaches tested: +- Baseline: Current sequential implementation +- Approach A: Trajectory-as-batch (P=2), share first Linear +- Approach B: Global-norm divergence + single-scale renorm +- Approach C: torch.compile the time loop +- Combined: A + B + C together +""" + +import os +import sys +import time +from typing import Tuple, Optional, List +from dataclasses import dataclass + +import torch +import torch.nn as nn +import snntorch as snn +from snntorch import surrogate + +# Ensure we can import from project +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + + +# ============================================================================= +# Baseline Implementation (Current) +# ============================================================================= + +class BaselineSNN(nn.Module): + """Current implementation: sequential perturbed trajectory.""" + + def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9): + super().__init__() + self.T = T + self.hidden_dims = hidden_dims + spike_grad = surrogate.fast_sigmoid(slope=25) + + # Simple feedforward for benchmarking (not full VGG) + self.linears = nn.ModuleList() + self.lifs = nn.ModuleList() + + dims = [in_channels * 32 * 32] + hidden_dims # Flattened input + for i in range(len(hidden_dims)): + self.linears.append(nn.Linear(dims[i], dims[i+1])) + self.lifs.append(snn.Leaky(beta=beta, threshold=1.0, + spike_grad=spike_grad, init_hidden=False)) + + self.readout = nn.Linear(hidden_dims[-1], 10) + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + x = x.view(B, -1) # Flatten + + # Init membrane potentials + mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims] + + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + for t in range(self.T): + # Original trajectory + h = x + new_mems = [] + for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): + h = lin(h) + spk, mem = lif(h, mems[i]) + new_mems.append(mem) + h = spk + mems = new_mems + spike_sum = spike_sum + h + + if compute_lyapunov: + # Perturbed trajectory (SEPARATE PASS - this is slow) + h_p = x + new_mems_p = [] + for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): + h_p = lin(h_p) + spk_p, mem_p = lif(h_p, mems_p[i]) + new_mems_p.append(mem_p) + h_p = spk_p + + # Divergence (per-layer norms, then sum) + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(self.hidden_dims)): + diff = new_mems_p[i] - new_mems[i] + delta_sq += (diff ** 2).sum(dim=1) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # Renormalize (per-layer - SLOW) + for i in range(len(self.hidden_dims)): + diff = new_mems_p[i] - new_mems[i] + norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12 + new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm + + mems_p = new_mems_p + + logits = self.readout(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +# ============================================================================= +# Approach A: Trajectory-as-batch (P=2), share first Linear +# ============================================================================= + +class ApproachA_SNN(nn.Module): + """Batch both trajectories together, share Linear_1.""" + + def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9): + super().__init__() + self.T = T + self.hidden_dims = hidden_dims + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.linears = nn.ModuleList() + self.lifs = nn.ModuleList() + + dims = [in_channels * 32 * 32] + hidden_dims + for i in range(len(hidden_dims)): + self.linears.append(nn.Linear(dims[i], dims[i+1])) + self.lifs.append(snn.Leaky(beta=beta, threshold=1.0, + spike_grad=spike_grad, init_hidden=False)) + + self.readout = nn.Linear(hidden_dims[-1], 10) + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + x = x.view(B, -1) + + P = 2 if compute_lyapunov else 1 + + # State layout: (P, B, H) where P=2 for [original, perturbed] + mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims] + + if compute_lyapunov: + # Initialize perturbed state + for i in range(len(self.hidden_dims)): + mems[i][1] = mems[i][0] + lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + for t in range(self.T): + # Layer 1: compute Linear ONCE, expand to (P, B, H1) + h1 = self.linears[0](x) # (B, H1) - computed ONCE + + if compute_lyapunov: + h = h1.unsqueeze(0).expand(P, -1, -1) # (P, B, H1) - zero-copy view + else: + h = h1.unsqueeze(0) # (1, B, H1) + + # LIF layer 1 + spk, mems[0] = self.lifs[0](h, mems[0]) + h = spk + + # Layers 2+: different inputs for each trajectory + for i in range(1, len(self.hidden_dims)): + # Reshape to (P*B, H) for batched Linear + h_flat = h.reshape(P * B, -1) + h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i]) + spk, mems[i] = self.lifs[i](h_lin, mems[i]) + h = spk + + # Accumulate spikes from original trajectory only + spike_sum = spike_sum + h[0] + + if compute_lyapunov: + # Global divergence across all layers + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(self.hidden_dims)): + diff = mems[i][1] - mems[i][0] # (B, H_i) + delta_sq = delta_sq + diff.square().sum(dim=-1) + + delta = (delta_sq + 1e-12).sqrt() + lyap_accum = lyap_accum + (delta / lyap_eps).log() + + # Renormalize with global scale (per-layer still, but simpler) + for i in range(len(self.hidden_dims)): + diff = mems[i][1] - mems[i][0] + norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12 + mems[i][1] = mems[i][0] + lyap_eps * diff / norm + + logits = self.readout(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +# ============================================================================= +# Approach B: Global-norm divergence + single-scale renorm +# ============================================================================= + +class ApproachB_SNN(nn.Module): + """Global norm for divergence, single scale factor for renorm.""" + + def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9): + super().__init__() + self.T = T + self.hidden_dims = hidden_dims + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.linears = nn.ModuleList() + self.lifs = nn.ModuleList() + + dims = [in_channels * 32 * 32] + hidden_dims + for i in range(len(hidden_dims)): + self.linears.append(nn.Linear(dims[i], dims[i+1])) + self.lifs.append(snn.Leaky(beta=beta, threshold=1.0, + spike_grad=spike_grad, init_hidden=False)) + + self.readout = nn.Linear(hidden_dims[-1], 10) + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + x = x.view(B, -1) + + mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims] + + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + for t in range(self.T): + # Original trajectory + h = x + new_mems = [] + for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): + h = lin(h) + spk, mem = lif(h, mems[i]) + new_mems.append(mem) + h = spk + mems = new_mems + spike_sum = spike_sum + h + + if compute_lyapunov: + # Perturbed trajectory + h_p = x + new_mems_p = [] + for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): + h_p = lin(h_p) + spk_p, mem_p = lif(h_p, mems_p[i]) + new_mems_p.append(mem_p) + h_p = spk_p + + # GLOBAL divergence (one delta per batch element) + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(self.hidden_dims)): + diff = new_mems_p[i] - new_mems[i] + delta_sq = delta_sq + diff.square().sum(dim=-1) + + delta = (delta_sq + 1e-12).sqrt() + lyap_accum = lyap_accum + (delta / lyap_eps).log() + + # SINGLE SCALE renormalization (key optimization) + scale = (lyap_eps / delta).unsqueeze(-1) # (B, 1) + for i in range(len(self.hidden_dims)): + diff = new_mems_p[i] - new_mems[i] + new_mems_p[i] = new_mems[i] + diff * scale + + mems_p = new_mems_p + + logits = self.readout(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +# ============================================================================= +# Approach A+B Combined: Batched trajectories + global renorm +# ============================================================================= + +class ApproachAB_SNN(nn.Module): + """Combined: trajectory-as-batch + global-norm renorm.""" + + def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9): + super().__init__() + self.T = T + self.hidden_dims = hidden_dims + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.linears = nn.ModuleList() + self.lifs = nn.ModuleList() + + dims = [in_channels * 32 * 32] + hidden_dims + for i in range(len(hidden_dims)): + self.linears.append(nn.Linear(dims[i], dims[i+1])) + self.lifs.append(snn.Leaky(beta=beta, threshold=1.0, + spike_grad=spike_grad, init_hidden=False)) + + self.readout = nn.Linear(hidden_dims[-1], 10) + + def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): + B = x.size(0) + device, dtype = x.device, x.dtype + x = x.view(B, -1) + + P = 2 if compute_lyapunov else 1 + + # State: (P, B, H) + mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims] + + if compute_lyapunov: + for i in range(len(self.hidden_dims)): + mems[i][1] = lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + for t in range(self.T): + # Layer 1: Linear computed ONCE + h1 = self.linears[0](x) + h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0) + + spk, mems[0] = self.lifs[0](h, mems[0]) + h = spk + + # Layers 2+ + for i in range(1, len(self.hidden_dims)): + h_flat = h.reshape(P * B, -1) + h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i]) + spk, mems[i] = self.lifs[i](h_lin, mems[i]) + h = spk + + spike_sum = spike_sum + h[0] + + if compute_lyapunov: + # Global divergence + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(self.hidden_dims)): + diff = mems[i][1] - mems[i][0] + delta_sq = delta_sq + diff.square().sum(dim=-1) + + delta = (delta_sq + 1e-12).sqrt() + lyap_accum = lyap_accum + (delta / lyap_eps).log() + + # Global scale renorm + scale = (lyap_eps / delta).unsqueeze(-1) + for i in range(len(self.hidden_dims)): + diff = mems[i][1] - mems[i][0] + mems[i][1] = mems[i][0] + diff * scale + + logits = self.readout(spike_sum) + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est + + +# ============================================================================= +# Approach C: torch.compile wrapper +# ============================================================================= + +def make_compiled_model(model_class, *args, **kwargs): + """Create a model and compile its forward pass.""" + model = model_class(*args, **kwargs) + # Compile the forward method + model.forward = torch.compile(model.forward, mode="reduce-overhead") + return model + + +# ============================================================================= +# Benchmarking +# ============================================================================= + +@dataclass +class BenchmarkResult: + name: str + forward_time_ms: float + backward_time_ms: float + total_time_ms: float + lyap_value: float + memory_mb: float + + def __str__(self): + return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | " + f"Bwd: {self.backward_time_ms:7.2f}ms | " + f"Total: {self.total_time_ms:7.2f}ms | " + f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB") + + +def benchmark_model( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + name: str, + warmup_iters: int = 5, + bench_iters: int = 20, +) -> BenchmarkResult: + """Benchmark a single model configuration.""" + + device = x.device + criterion = nn.CrossEntropyLoss() + + # Warmup + for _ in range(warmup_iters): + logits, lyap = model(x, compute_lyapunov=True) + loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0) + loss.backward() + model.zero_grad() + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + fwd_times = [] + bwd_times = [] + lyap_vals = [] + + for _ in range(bench_iters): + # Forward + torch.cuda.synchronize() + t0 = time.perf_counter() + + logits, lyap = model(x, compute_lyapunov=True) + loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0) + + torch.cuda.synchronize() + t1 = time.perf_counter() + + # Backward + loss.backward() + + torch.cuda.synchronize() + t2 = time.perf_counter() + + fwd_times.append((t1 - t0) * 1000) + bwd_times.append((t2 - t1) * 1000) + if lyap is not None: + lyap_vals.append(lyap.item()) + + model.zero_grad() + + peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 + + return BenchmarkResult( + name=name, + forward_time_ms=sum(fwd_times) / len(fwd_times), + backward_time_ms=sum(bwd_times) / len(bwd_times), + total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times), + lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0, + memory_mb=peak_mem, + ) + + +def run_benchmarks( + batch_size: int = 64, + T: int = 4, + hidden_dims: List[int] = [64, 128, 256], + device: str = "cuda", +): + """Run all benchmarks and compare.""" + + print("=" * 80) + print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK") + print("=" * 80) + print(f"Batch size: {batch_size}") + print(f"Timesteps: {T}") + print(f"Hidden dims: {hidden_dims}") + print(f"Device: {device}") + print("=" * 80) + + # Create dummy data + x = torch.randn(batch_size, 3, 32, 32, device=device) + y = torch.randint(0, 10, (batch_size,), device=device) + + results = [] + + # 1. Baseline + print("\n[1/6] Benchmarking Baseline...") + model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device) + results.append(benchmark_model(model, x, y, "Baseline")) + del model + torch.cuda.empty_cache() + + # 2. Approach A (batched trajectories) + print("[2/6] Benchmarking Approach A (batched)...") + model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device) + results.append(benchmark_model(model, x, y, "A: Batched trajectories")) + del model + torch.cuda.empty_cache() + + # 3. Approach B (global renorm) + print("[3/6] Benchmarking Approach B (global renorm)...") + model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device) + results.append(benchmark_model(model, x, y, "B: Global renorm")) + del model + torch.cuda.empty_cache() + + # 4. Approach A+B combined + print("[4/6] Benchmarking Approach A+B (combined)...") + model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device) + results.append(benchmark_model(model, x, y, "A+B: Combined")) + del model + torch.cuda.empty_cache() + + # 5. Approach C (torch.compile on baseline) + print("[5/6] Benchmarking Approach C (compiled baseline)...") + try: + model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device) + model.forward = torch.compile(model.forward, mode="reduce-overhead") + results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10)) + del model + torch.cuda.empty_cache() + except Exception as e: + print(f" torch.compile failed: {e}") + results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0)) + + # 6. A+B+C (all combined) + print("[6/6] Benchmarking A+B+C (all optimizations)...") + try: + model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device) + model.forward = torch.compile(model.forward, mode="reduce-overhead") + results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10)) + del model + torch.cuda.empty_cache() + except Exception as e: + print(f" torch.compile failed: {e}") + results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0)) + + # Print results + print("\n" + "=" * 80) + print("RESULTS") + print("=" * 80) + + baseline_time = results[0].total_time_ms + + for r in results: + print(r) + + print("\n" + "-" * 80) + print("SPEEDUP vs BASELINE:") + print("-" * 80) + + for r in results[1:]: + if r.total_time_ms > 0: + speedup = baseline_time / r.total_time_ms + print(f" {r.name:<25}: {speedup:.2f}x") + + # Verify Lyapunov values are consistent + print("\n" + "-" * 80) + print("LYAPUNOV VALUE CONSISTENCY CHECK:") + print("-" * 80) + + base_lyap = results[0].lyap_value + for r in results[1:]: + if r.lyap_value != 0: + diff = abs(r.lyap_value - base_lyap) + status = "✓" if diff < 0.1 else "✗" + print(f" {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}") + + return results + + +def run_scaling_test(device: str = "cuda"): + """Test how approaches scale with batch size and timesteps.""" + + print("\n" + "=" * 80) + print("SCALING TESTS") + print("=" * 80) + + configs = [ + {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]}, + {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]}, + {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]}, + {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]}, + {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]}, + {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]}, # Larger model + ] + + print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}") + print("-" * 80) + + for cfg in configs: + x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device) + y = torch.randint(0, 10, (cfg["batch_size"],), device=device) + + # Baseline + model_base = BaselineSNN(**cfg).to(device) + r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10) + del model_base + + # A+B + model_ab = ApproachAB_SNN(**cfg).to(device) + r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10) + del model_ab + + torch.cuda.empty_cache() + + speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0 + + cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}" + print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--T", type=int, default=4) + parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256]) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--scaling", action="store_true", help="Run scaling tests") + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA not available, using CPU (results will not be representative)") + args.device = "cpu" + + # Main benchmark + results = run_benchmarks( + batch_size=args.batch_size, + T=args.T, + hidden_dims=args.hidden_dims, + device=args.device, + ) + + # Scaling tests + if args.scaling: + run_scaling_test(args.device) diff --git a/files/experiments/plot_depth_comparison.py b/files/experiments/plot_depth_comparison.py new file mode 100644 index 0000000..2222b7b --- /dev/null +++ b/files/experiments/plot_depth_comparison.py @@ -0,0 +1,305 @@ +""" +Visualization for depth comparison experiments. + +Usage: + python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP +""" + +import os +import sys +import json +import argparse +from typing import Dict, List + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D + + +def load_results(results_dir: str) -> Dict: + """Load results from JSON file.""" + with open(os.path.join(results_dir, "results.json"), "r") as f: + return json.load(f) + + +def load_config(results_dir: str) -> Dict: + """Load config from JSON file.""" + config_path = os.path.join(results_dir, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + return {} + + +def plot_training_curves(results: Dict, output_path: str): + """ + Plot training curves for each depth. + + Creates a figure with subplots for each depth showing: + - Training loss + - Validation accuracy + - Lyapunov exponent (if available) + - Gradient norm + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + n_depths = len(depths) + + fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths)) + if n_depths == 1: + axes = axes.reshape(1, -1) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"} + + for i, depth in enumerate(depths): + for method in ["vanilla", "lyapunov"]: + metrics = results[method][str(depth)] + epochs = [m["epoch"] for m in metrics] + + # Training Loss + train_loss = [m["train_loss"] for m in metrics] + axes[i, 0].plot(epochs, train_loss, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 0].set_ylabel("Train Loss") + axes[i, 0].set_title(f"Depth={depth}: Training Loss") + axes[i, 0].set_yscale("log") + axes[i, 0].grid(True, alpha=0.3) + + # Validation Accuracy + val_acc = [m["val_acc"] for m in metrics] + axes[i, 1].plot(epochs, val_acc, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 1].set_ylabel("Val Accuracy") + axes[i, 1].set_title(f"Depth={depth}: Validation Accuracy") + axes[i, 1].set_ylim(0, 1) + axes[i, 1].grid(True, alpha=0.3) + + # Lyapunov Exponent + lyap = [m["lyapunov"] for m in metrics if m["lyapunov"] is not None] + lyap_epochs = [m["epoch"] for m in metrics if m["lyapunov"] is not None] + if lyap: + axes[i, 2].plot(lyap_epochs, lyap, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 2].axhline(y=0, color='gray', linestyle='--', alpha=0.5) + axes[i, 2].set_ylabel("Lyapunov λ") + axes[i, 2].set_title(f"Depth={depth}: Lyapunov Exponent") + axes[i, 2].grid(True, alpha=0.3) + + # Gradient Norm + grad_norm = [m["grad_norm"] for m in metrics] + axes[i, 3].plot(epochs, grad_norm, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 3].set_ylabel("Gradient Norm") + axes[i, 3].set_title(f"Depth={depth}: Gradient Norm") + axes[i, 3].set_yscale("log") + axes[i, 3].grid(True, alpha=0.3) + + # Add legend to first row + if i == 0: + for ax in axes[i]: + ax.legend(loc="upper right") + + # Set x-labels on bottom row + for ax in axes[-1]: + ax.set_xlabel("Epoch") + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved training curves to {output_path}") + + +def plot_depth_summary(results: Dict, output_path: str): + """ + Plot summary comparing methods across depths. + + Creates a figure showing: + - Final validation accuracy vs depth + - Final gradient norm vs depth + - Final Lyapunov exponent vs depth + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + + fig, axes = plt.subplots(1, 3, figsize=(14, 4)) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + markers = {"vanilla": "o", "lyapunov": "s"} + + # Collect final metrics + van_acc = [] + lyap_acc = [] + van_grad = [] + lyap_grad = [] + lyap_lambda = [] + + for depth in depths: + van_metrics = results["vanilla"][str(depth)][-1] + lyap_metrics = results["lyapunov"][str(depth)][-1] + + van_acc.append(van_metrics["val_acc"] if not np.isnan(van_metrics["val_acc"]) else 0) + lyap_acc.append(lyap_metrics["val_acc"] if not np.isnan(lyap_metrics["val_acc"]) else 0) + + van_grad.append(van_metrics["grad_norm"] if not np.isnan(van_metrics["grad_norm"]) else 0) + lyap_grad.append(lyap_metrics["grad_norm"] if not np.isnan(lyap_metrics["grad_norm"]) else 0) + + if lyap_metrics["lyapunov"] is not None: + lyap_lambda.append(lyap_metrics["lyapunov"]) + else: + lyap_lambda.append(0) + + # Plot 1: Validation Accuracy vs Depth + ax = axes[0] + ax.plot(depths, van_acc, 'o-', color=colors["vanilla"], + label="Vanilla", linewidth=2, markersize=8) + ax.plot(depths, lyap_acc, 's-', color=colors["lyapunov"], + label="Lyapunov", linewidth=2, markersize=8) + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Validation Accuracy") + ax.set_title("Accuracy vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05) + + # Plot 2: Gradient Norm vs Depth + ax = axes[1] + ax.plot(depths, van_grad, 'o-', color=colors["vanilla"], + label="Vanilla", linewidth=2, markersize=8) + ax.plot(depths, lyap_grad, 's-', color=colors["lyapunov"], + label="Lyapunov", linewidth=2, markersize=8) + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Gradient Norm") + ax.set_title("Gradient Stability vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_yscale("log") + + # Plot 3: Lyapunov Exponent vs Depth + ax = axes[2] + ax.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"], + linewidth=2, markersize=8) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)") + ax.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region") + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Lyapunov Exponent") + ax.set_title("Lyapunov Exponent vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved depth summary to {output_path}") + + +def plot_stability_comparison(results: Dict, output_path: str): + """ + Plot stability metrics comparison. + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + + # Collect metrics over training + for depth in depths: + van_metrics = results["vanilla"][str(depth)] + lyap_metrics = results["lyapunov"][str(depth)] + + van_epochs = [m["epoch"] for m in van_metrics] + lyap_epochs = [m["epoch"] for m in lyap_metrics] + + # Firing rate + van_fr = [m["firing_rate"] for m in van_metrics] + lyap_fr = [m["firing_rate"] for m in lyap_metrics] + axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"], + alpha=0.3 + 0.1 * depths.index(depth)) + axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"], + alpha=0.3 + 0.1 * depths.index(depth)) + + # Dead neurons + van_dead = [m["dead_neurons"] for m in van_metrics] + lyap_dead = [m["dead_neurons"] for m in lyap_metrics] + axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"], + alpha=0.3 + 0.1 * depths.index(depth)) + axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"], + alpha=0.3 + 0.1 * depths.index(depth)) + + axes[0, 0].set_xlabel("Epoch") + axes[0, 0].set_ylabel("Firing Rate") + axes[0, 0].set_title("Firing Rate Over Training") + axes[0, 0].grid(True, alpha=0.3) + + axes[0, 1].set_xlabel("Epoch") + axes[0, 1].set_ylabel("Dead Neuron Fraction") + axes[0, 1].set_title("Dead Neurons Over Training") + axes[0, 1].grid(True, alpha=0.3) + + # Final metrics bar chart + van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths] + lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths] + + x = np.arange(len(depths)) + width = 0.35 + + axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"]) + axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"]) + axes[1, 0].set_xlabel("Network Depth") + axes[1, 0].set_ylabel("Final Validation Accuracy") + axes[1, 0].set_title("Final Accuracy Comparison") + axes[1, 0].set_xticks(x) + axes[1, 0].set_xticklabels(depths) + axes[1, 0].legend() + axes[1, 0].grid(True, alpha=0.3, axis='y') + + # Improvement percentage + improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)] + colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements] + + axes[1, 1].bar(x, improvements, color=colors_bar) + axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5) + axes[1, 1].set_xlabel("Network Depth") + axes[1, 1].set_ylabel("Accuracy Improvement") + axes[1, 1].set_title("Lyapunov Improvement over Vanilla") + axes[1, 1].set_xticks(x) + axes[1, 1].set_xticklabels(depths) + axes[1, 1].grid(True, alpha=0.3, axis='y') + + # Add legend for line plots + custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2), + Line2D([0], [0], color=colors["lyapunov"], lw=2)] + axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov']) + axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov']) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved stability comparison to {output_path}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--results_dir", type=str, required=True, + help="Directory containing results.json") + parser.add_argument("--output_dir", type=str, default=None, + help="Output directory for plots (default: same as results_dir)") + args = parser.parse_args() + + output_dir = args.output_dir or args.results_dir + + print(f"Loading results from {args.results_dir}") + results = load_results(args.results_dir) + config = load_config(args.results_dir) + + print(f"Config: {config}") + + # Generate plots + plot_training_curves(results, os.path.join(output_dir, "training_curves.png")) + plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png")) + plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png")) + + print(f"\nAll plots saved to {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/files/experiments/posthoc_finetune.py b/files/experiments/posthoc_finetune.py new file mode 100644 index 0000000..3f3bf6c --- /dev/null +++ b/files/experiments/posthoc_finetune.py @@ -0,0 +1,323 @@ +""" +Post-hoc Lyapunov Fine-tuning Experiment + +Strategy: +1. Train network with vanilla (no Lyapunov) for N epochs +2. Then fine-tune with Lyapunov regularization for M epochs + +This allows the network to learn task-relevant features first, +then stabilize dynamics without starting from chaotic initialization. +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from files.experiments.depth_scaling_benchmark import ( + SpikingVGG, + get_dataset, + train_epoch, + evaluate, + TrainingMetrics, + compute_lyap_reg_loss, +) + + +def run_posthoc_experiment( + dataset_name: str, + depth_config: Tuple[int, int], + train_loader: DataLoader, + test_loader: DataLoader, + num_classes: int, + in_channels: int, + T: int, + pretrain_epochs: int, + finetune_epochs: int, + lr: float, + finetune_lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + reg_type: str = "extreme", + lyap_threshold: float = 2.0, + progress: bool = True, +) -> Dict: + """Run post-hoc fine-tuning experiment.""" + torch.manual_seed(seed) + + num_stages, blocks_per_stage = depth_config + total_depth = num_stages * blocks_per_stage + + print(f"\n{'='*60}") + print(f"POST-HOC FINE-TUNING: Depth = {total_depth}") + print(f"Pretrain: {pretrain_epochs} epochs (vanilla)") + print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})") + print(f"{'='*60}") + + model = SpikingVGG( + in_channels=in_channels, + num_classes=num_classes, + base_channels=64, + num_stages=num_stages, + blocks_per_stage=blocks_per_stage, + T=T, + ).to(device) + + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Parameters: {num_params:,}") + + criterion = nn.CrossEntropyLoss() + + # Phase 1: Vanilla pre-training + print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---") + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs) + + pretrain_history = [] + best_pretrain_acc = 0.0 + + for epoch in range(1, pretrain_epochs + 1): + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( + model, train_loader, optimizer, criterion, device, + use_lyapunov=False, # No Lyapunov during pre-training + lambda_reg=0, lambda_target=0, lyap_eps=1e-4, + progress=progress, + ) + + test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_pretrain_acc = max(best_pretrain_acc, test_acc) + + metrics = TrainingMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + test_loss=test_loss, + test_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + grad_max_sv=grad_max_sv, + grad_min_sv=grad_min_sv, + grad_condition=grad_cond, + lr=scheduler.get_last_lr()[0], + time_sec=dt, + ) + pretrain_history.append(metrics) + + if epoch % 10 == 0 or epoch == pretrain_epochs: + print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}") + + print(f" Best pretrain acc: {best_pretrain_acc:.3f}") + + # Phase 2: Lyapunov fine-tuning + print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---") + print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}") + + # Reset optimizer with lower learning rate for fine-tuning + optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs) + + finetune_history = [] + best_finetune_acc = 0.0 + + for epoch in range(1, finetune_epochs + 1): + t0 = time.time() + + # Warmup lambda_reg over first 10 epochs of fine-tuning + warmup_epochs = 10 + if epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / warmup_epochs) + else: + current_lambda_reg = lambda_reg + + train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( + model, train_loader, optimizer, criterion, device, + use_lyapunov=True, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=1e-4, + progress=progress, + reg_type=reg_type, + current_lambda_reg=current_lambda_reg, + lyap_threshold=lyap_threshold, + ) + + test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_finetune_acc = max(best_finetune_acc, test_acc) + + metrics = TrainingMetrics( + epoch=pretrain_epochs + epoch, # Continue epoch numbering + train_loss=train_loss, + train_acc=train_acc, + test_loss=test_loss, + test_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + grad_max_sv=grad_max_sv, + grad_min_sv=grad_min_sv, + grad_condition=grad_cond, + lr=scheduler.get_last_lr()[0], + time_sec=dt, + ) + finetune_history.append(metrics) + + if epoch % 10 == 0 or epoch == finetune_epochs: + lyap_str = f"λ={lyap:.3f}" if lyap else "" + print(f" Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}") + + if np.isnan(train_loss): + print(f" DIVERGED at epoch {epoch}") + break + + print(f" Best finetune acc: {best_finetune_acc:.3f}") + print(f" Final λ: {finetune_history[-1].lyapunov:.3f}" if finetune_history[-1].lyapunov else "") + + return { + "depth": total_depth, + "pretrain_history": pretrain_history, + "finetune_history": finetune_history, + "best_pretrain_acc": best_pretrain_acc, + "best_finetune_acc": best_finetune_acc, + } + + +def main(): + parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning") + parser.add_argument("--dataset", type=str, default="cifar100", + choices=["mnist", "fashion_mnist", "cifar10", "cifar100"]) + parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16]) + parser.add_argument("--T", type=int, default=4) + parser.add_argument("--pretrain_epochs", type=int, default=100) + parser.add_argument("--finetune_epochs", type=int, default=50) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--finetune_lr", type=float, default=1e-4) + parser.add_argument("--lambda_reg", type=float, default=0.1) + parser.add_argument("--lambda_target", type=float, default=-0.1) + parser.add_argument("--reg_type", type=str, default="extreme") + parser.add_argument("--lyap_threshold", type=float, default=2.0) + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune") + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--no-progress", action="store_true") + + args = parser.parse_args() + device = torch.device(args.device) + + print("=" * 80) + print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT") + print("=" * 80) + print(f"Dataset: {args.dataset}") + print(f"Depths: {args.depths}") + print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})") + print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})") + print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}") + print("=" * 80) + + # Load data + train_loader, test_loader, num_classes, input_shape = get_dataset( + args.dataset, args.data_dir, args.batch_size + ) + in_channels = input_shape[0] + + # Convert depths to configs + depth_configs = [] + for d in args.depths: + if d <= 4: + depth_configs.append((d, 1)) + else: + depth_configs.append((4, d // 4)) + + # Run experiments + all_results = [] + for depth_config in depth_configs: + result = run_posthoc_experiment( + dataset_name=args.dataset, + depth_config=depth_config, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=args.T, + pretrain_epochs=args.pretrain_epochs, + finetune_epochs=args.finetune_epochs, + lr=args.lr, + finetune_lr=args.finetune_lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + reg_type=args.reg_type, + lyap_threshold=args.lyap_threshold, + progress=not args.no_progress, + ) + all_results.append(result) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}") + print("-" * 80) + + for r in all_results: + pre_acc = r["best_pretrain_acc"] + fine_acc = r["best_finetune_acc"] + change = fine_acc - pre_acc + final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None + lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A" + change_str = f"{change:+.3f}" + + print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}") + + print("=" * 80) + + # Save results + os.makedirs(args.out_dir, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json") + + serializable_results = [] + for r in all_results: + sr = { + "depth": r["depth"], + "best_pretrain_acc": r["best_pretrain_acc"], + "best_finetune_acc": r["best_finetune_acc"], + "pretrain_history": [asdict(m) for m in r["pretrain_history"]], + "finetune_history": [asdict(m) for m in r["finetune_history"]], + } + serializable_results.append(sr) + + with open(output_file, "w") as f: + json.dump({"config": vars(args), "results": serializable_results}, f, indent=2) + + print(f"\nResults saved to {output_file}") + + +if __name__ == "__main__": + main() diff --git a/files/experiments/scaled_reg_grid_search.py b/files/experiments/scaled_reg_grid_search.py new file mode 100644 index 0000000..928caff --- /dev/null +++ b/files/experiments/scaled_reg_grid_search.py @@ -0,0 +1,301 @@ +""" +Grid Search: Multiplier-Scaled Regularization Experiments + +Tests the new multiplier-scaled regularization approach: + loss = (λ_reg × g(relu(λ))) × relu(λ) + +Where g(x) is the multiplier scaling function: + - mult_linear: g(x) = x → loss = λ_reg × relu(λ)² + - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³ + - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ) + +Grid: + - λ_reg: 0.01, 0.05, 0.1, 0.3 + - reg_type: mult_linear, mult_squared, mult_log + - depth: specified via command line + +Usage: + python scaled_reg_grid_search.py --depth 4 + python scaled_reg_grid_search.py --depth 8 + python scaled_reg_grid_search.py --depth 12 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional +from itertools import product + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm.auto import tqdm + +# Import from main benchmark +from depth_scaling_benchmark import ( + SpikingVGG, + compute_lyap_reg_loss, +) + +import snntorch as snn +from snntorch import surrogate + + +@dataclass +class ExperimentResult: + depth: int + reg_type: str + lambda_reg: float + vanilla_acc: float + lyapunov_acc: float + final_lyap: Optional[float] + delta: float + + +def train_epoch(model, loader, optimizer, criterion, device, + use_lyapunov, lambda_reg, reg_type, progress=False): + """Train one epoch.""" + model.train() + total_loss = 0.0 + correct = 0 + total = 0 + lyap_vals = [] + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4) + loss = criterion(logits, y) + + if use_lyapunov and lyap_est is not None: + # Target is implicitly 0 for scaled reg types + lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0) + loss = loss + lambda_reg * lyap_reg + lyap_vals.append(lyap_est.item()) + + loss.backward() + optimizer.step() + + total_loss += loss.item() * x.size(0) + _, pred = logits.max(1) + correct += pred.eq(y).sum().item() + total += x.size(0) + + avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None + return total_loss / total, correct / total, avg_lyap + + +def evaluate(model, loader, device): + """Evaluate model.""" + model.eval() + correct = 0 + total = 0 + + with torch.no_grad(): + for x, y in loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False) + _, pred = logits.max(1) + correct += pred.eq(y).sum().item() + total += x.size(0) + + return correct / total + + +def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader, + device, epochs=100, lr=0.001): + """Run a single experiment configuration.""" + + # Determine blocks per stage based on depth + # depth = num_stages * blocks_per_stage, with num_stages=4 + blocks_per_stage = depth // 4 + + print(f"\n{'='*60}") + print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}") + print(f"{'='*60}") + + # --- Run Vanilla baseline --- + print(f" Training Vanilla...") + model_v = SpikingVGG( + num_classes=100, + blocks_per_stage=blocks_per_stage, + T=4, + ).to(device) + + optimizer_v = optim.Adam(model_v.parameters(), lr=lr) + criterion = nn.CrossEntropyLoss() + scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs) + + best_vanilla = 0.0 + for epoch in range(epochs): + train_epoch(model_v, train_loader, optimizer_v, criterion, device, + use_lyapunov=False, lambda_reg=0, reg_type="squared") + scheduler_v.step() + + if (epoch + 1) % 10 == 0 or epoch == epochs - 1: + acc = evaluate(model_v, test_loader, device) + best_vanilla = max(best_vanilla, acc) + print(f" Epoch {epoch+1:3d}: test={acc:.3f}") + + del model_v, optimizer_v, scheduler_v + torch.cuda.empty_cache() + + # --- Run Lyapunov version --- + print(f" Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...") + model_l = SpikingVGG( + num_classes=100, + blocks_per_stage=blocks_per_stage, + T=4, + ).to(device) + + optimizer_l = optim.Adam(model_l.parameters(), lr=lr) + scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs) + + best_lyap_acc = 0.0 + final_lyap = None + + for epoch in range(epochs): + _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device, + use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type) + scheduler_l.step() + final_lyap = lyap + + if (epoch + 1) % 10 == 0 or epoch == epochs - 1: + acc = evaluate(model_l, test_loader, device) + best_lyap_acc = max(best_lyap_acc, acc) + lyap_str = f"λ={lyap:.3f}" if lyap else "λ=N/A" + print(f" Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}") + + del model_l, optimizer_l, scheduler_l + torch.cuda.empty_cache() + + delta = best_lyap_acc - best_vanilla + + result = ExperimentResult( + depth=depth, + reg_type=reg_type, + lambda_reg=lambda_reg, + vanilla_acc=best_vanilla, + lyapunov_acc=best_lyap_acc, + final_lyap=final_lyap, + delta=delta, + ) + + print(f" Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}") + + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12]) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print("=" * 70) + print("SCALED REGULARIZATION GRID SEARCH") + print("=" * 70) + print(f"Depth: {args.depth}") + print(f"Epochs: {args.epochs}") + print(f"Device: {device}") + if device.type == "cuda": + print(f"GPU: {torch.cuda.get_device_name()}") + print("=" * 70) + + # Grid parameters + lambda_regs = [0.0005, 0.001, 0.002, 0.005] # smaller values for deeper networks + reg_types = ["mult_linear", "mult_log"] # mult_squared too aggressive, kills learning + + print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments") + print(f"λ_reg values: {lambda_regs}") + print(f"reg_types: {reg_types}") + + # Load data + print(f"\nLoading CIFAR-100...") + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + + train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train) + test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test) + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=4, pin_memory=True) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, + num_workers=4, pin_memory=True) + + print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}") + + # Run grid search + results = [] + + for lambda_reg, reg_type in product(lambda_regs, reg_types): + result = run_single_experiment( + depth=args.depth, + reg_type=reg_type, + lambda_reg=lambda_reg, + train_loader=train_loader, + test_loader=test_loader, + device=device, + epochs=args.epochs, + lr=args.lr, + ) + results.append(result) + + # Print summary table + print("\n" + "=" * 70) + print(f"SUMMARY: DEPTH = {args.depth}") + print("=" * 70) + print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}") + print("-" * 70) + + for r in results: + lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap else "N/A" + delta_str = f"{r.delta:+.3f}" + print(f"{r.reg_type:<16} {r.lambda_reg:>8.2f} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}") + + # Find best configuration + best = max(results, key=lambda x: x.lyapunov_acc) + print("-" * 70) + print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})") + + # Save results + os.makedirs(args.out_dir, exist_ok=True) + out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json") + with open(out_file, "w") as f: + json.dump([asdict(r) for r in results], f, indent=2) + print(f"\nResults saved to: {out_file}") + + print("\n" + "=" * 70) + print("GRID SEARCH COMPLETE") + print("=" * 70) + + +if __name__ == "__main__": + main() |
