summaryrefslogtreecommitdiff
path: root/files/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
commitcd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree59a233959932ca0e4f12f196275e07fcf443b33f /files/experiments
init commit
Diffstat (limited to 'files/experiments')
-rw-r--r--files/experiments/benchmark_experiment.py518
-rw-r--r--files/experiments/cifar10_conv_experiment.py448
-rw-r--r--files/experiments/depth_comparison.py542
-rw-r--r--files/experiments/depth_scaling_benchmark.py1035
-rw-r--r--files/experiments/hyperparameter_grid_search.py597
-rw-r--r--files/experiments/lyapunov_diffonly_benchmark.py590
-rw-r--r--files/experiments/lyapunov_speedup_benchmark.py638
-rw-r--r--files/experiments/plot_depth_comparison.py305
-rw-r--r--files/experiments/posthoc_finetune.py323
-rw-r--r--files/experiments/scaled_reg_grid_search.py301
10 files changed, 5297 insertions, 0 deletions
diff --git a/files/experiments/benchmark_experiment.py b/files/experiments/benchmark_experiment.py
new file mode 100644
index 0000000..fb01ff2
--- /dev/null
+++ b/files/experiments/benchmark_experiment.py
@@ -0,0 +1,518 @@
+"""
+Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.
+
+Datasets:
+- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks
+- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory
+- CIFAR-10: Rate-coded images, requires hierarchical features
+
+Usage:
+ python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8
+ python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.data_io.benchmark_datasets import get_benchmark_dataloader
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ grad_max_sv: Optional[float]
+ grad_min_sv: Optional[float]
+ grad_condition: Optional[float]
+ time_sec: float
+
+
+def compute_gradient_svs(model):
+ """Compute gradient singular value statistics."""
+ max_svs = []
+ min_svs = []
+
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ with torch.no_grad():
+ G = param.grad.detach()
+ try:
+ sv = torch.linalg.svdvals(G)
+ if len(sv) > 0:
+ max_svs.append(sv[0].item())
+ min_svs.append(sv[-1].item())
+ except Exception:
+ pass
+
+ if not max_svs:
+ return None, None, None
+
+ max_sv = max(max_svs)
+ min_sv = min(min_svs)
+ condition = max_sv / (min_sv + 1e-12)
+
+ return max_sv, min_sv, condition
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+ hidden_dims = [hidden_dim] * depth
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=1.0,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ compute_sv_every: int = 10,
+) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ grad_max_svs = []
+ grad_min_svs = []
+ grad_conditions = []
+
+ for batch_idx, (x, y) in enumerate(loader):
+ x, y = x.to(device), y.to(device)
+
+ # Handle different input shapes
+ if x.ndim == 2:
+ x = x.unsqueeze(-1) # (B, T) -> (B, T, 1)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=False,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan'), None, None, None
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ # Compute gradient SVs periodically
+ if batch_idx % compute_sv_every == 0:
+ max_sv, min_sv, cond = compute_gradient_svs(model)
+ if max_sv is not None:
+ grad_max_svs.append(max_sv)
+ grad_min_svs.append(min_sv)
+ grad_conditions.append(cond)
+
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+ avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+ avg_grad = np.mean(grad_norms)
+ avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None
+ avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None
+ avg_cond = np.mean(grad_conditions) if grad_conditions else None
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+
+ if x.ndim == 2:
+ x = x.unsqueeze(-1)
+
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ loss = ce_loss(logits, y)
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ depth: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment configuration."""
+ torch.manual_seed(seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=depth,
+ hidden_dim=hidden_dim,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ method = "Lyapunov" if use_lyapunov else "Vanilla"
+ metrics_history = []
+
+ iterator = range(1, epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ if progress:
+ lyap_str = f"λ={lyap:.2f}" if lyap else ""
+ iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str})
+
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ dataset_name: str,
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """Run comparison across depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ depth=depth,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ final = metrics[-1]
+ lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str}")
+
+ return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+ """Print summary table."""
+ print("\n" + "=" * 90)
+ print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy")
+ print("=" * 90)
+ print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}")
+ print("-" * 90)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A"
+ van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A"
+
+ print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}")
+
+ print("=" * 90)
+
+ # Gradient health analysis
+ print("\nGRADIENT HEALTH:")
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ van_cond = van.grad_condition if van.grad_condition else 0
+ if van_cond > 1e6:
+ print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})")
+ elif van_cond > 1e4:
+ print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN")
+
+ # Dataset
+ p.add_argument("--dataset", type=str, default="smnist",
+ choices=["smnist", "psmnist", "cifar10"],
+ help="Dataset to use")
+ p.add_argument("--data_dir", type=str, default="./data")
+
+ # Model
+ p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128)
+
+ # Training
+ p.add_argument("--epochs", type=int, default=30)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3,
+ help="Lyapunov regularization weight (higher for harder tasks)")
+ p.add_argument("--lambda_target", type=float, default=-0.1,
+ help="Target Lyapunov exponent (negative for stability)")
+ p.add_argument("--lyap_eps", type=float, default=1e-4)
+
+ # Dataset-specific
+ p.add_argument("--T", type=int, default=100,
+ help="Timesteps for CIFAR-10 (sMNIST uses 784)")
+ p.add_argument("--n_repeat", type=int, default=1,
+ help="Repeat each pixel n times for sMNIST")
+
+ # Other
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--out_dir", type=str, default="runs/benchmark")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load dataset
+ print(f"\nLoading {args.dataset} dataset...")
+
+ if args.dataset == "smnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "psmnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "cifar10":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ T=args.T,
+ )
+
+ print(f"Dataset info: {info}")
+ print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ dataset_name=args.dataset,
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=info["D"],
+ num_classes=info["classes"],
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results, args.dataset)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/cifar10_conv_experiment.py b/files/experiments/cifar10_conv_experiment.py
new file mode 100644
index 0000000..a582f9f
--- /dev/null
+++ b/files/experiments/cifar10_conv_experiment.py
@@ -0,0 +1,448 @@
+"""
+CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization.
+
+Uses proper convolutional architecture that preserves spatial structure.
+Tests whether Lyapunov regularization helps train deeper Conv-SNNs.
+
+Architecture:
+ Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output
+
+Usage:
+ python files/experiments/cifar10_conv_experiment.py --model simple --T 25
+ python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+from files.models.conv_snn import create_conv_snn
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ time_sec: float
+
+
+def get_cifar10_loaders(
+ data_dir: str = './data',
+ batch_size: int = 128,
+ num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader]:
+ """
+ Get CIFAR-10 dataloaders with standard normalization.
+
+ Images normalized to [0, 1] for rate encoding.
+ """
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ # Note: For rate encoding, we keep values in [0, 1]
+ # No normalization to negative values
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_dataset = datasets.CIFAR10(
+ root=data_dir, train=True, download=True, transform=transform_train
+ )
+ test_dataset = datasets.CIFAR10(
+ root=data_dir, train=False, download=True, transform=transform_test
+ )
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ num_workers=num_workers, pin_memory=True
+ )
+ test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True
+ )
+
+ return train_loader, test_loader
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ progress: bool = True,
+) -> Tuple[float, float, Optional[float], float]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan')
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if progress:
+ iterator.set_postfix({
+ "loss": f"{loss.item():.3f}",
+ "acc": f"{total_correct/total_samples:.3f}",
+ })
+
+ return (
+ total_loss / total_samples,
+ total_correct / total_samples,
+ np.mean(lyap_vals) if lyap_vals else None,
+ np.mean(grad_norms),
+ )
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+ progress: bool = True,
+) -> Tuple[float, float]:
+ """Evaluate on test set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+
+ loss = ce_loss(logits, y)
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ model_type: str,
+ channels: List[int],
+ T: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment."""
+ torch.manual_seed(seed)
+
+ model = create_conv_snn(
+ model_type=model_type,
+ in_channels=3,
+ num_classes=10,
+ channels=channels,
+ T=T,
+ encoding='rate',
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f" Model: {model_type}, params: {num_params:,}")
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ ce_loss = nn.CrossEntropyLoss()
+
+ metrics_history = []
+ best_acc = 0.0
+
+ for epoch in range(1, epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_acc = max(best_acc, test_acc)
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=test_loss,
+ val_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)")
+
+ if np.isnan(train_loss):
+ print(" Training diverged!")
+ break
+
+ print(f" Best test accuracy: {best_acc:.3f}")
+ return metrics_history
+
+
+def run_comparison(
+ model_type: str,
+ channels_configs: List[List[int]],
+ T: int,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool,
+) -> Dict:
+ """Compare vanilla vs Lyapunov across different depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for channels in channels_configs:
+ depth = len(channels)
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} conv layers, channels = {channels}")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ model_type=model_type,
+ channels=channels,
+ T=T,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=1e-4,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ return results
+
+
+def print_summary(results: Dict):
+ """Print comparison summary."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: CIFAR-10 Conv-SNN Results")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+
+ # Model
+ p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"])
+ p.add_argument("--channels", type=int, nargs="+", default=None,
+ help="Channel sizes (default: test multiple depths)")
+ p.add_argument("--T", type=int, default=25, help="Timesteps")
+
+ # Training
+ p.add_argument("--epochs", type=int, default=50)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3)
+ p.add_argument("--lambda_target", type=float, default=-0.1)
+
+ # Other
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--out_dir", type=str, default="runs/cifar10_conv")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("CIFAR-10 Conv-SNN Experiment")
+ print("=" * 70)
+ print(f"Model: {args.model}")
+ print(f"Timesteps: {args.T}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load data
+ print("\nLoading CIFAR-10...")
+ train_loader, test_loader = get_cifar10_loaders(
+ data_dir=args.data_dir,
+ batch_size=args.batch_size,
+ )
+ print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+ # Define depth configurations to test
+ if args.channels:
+ channels_configs = [args.channels]
+ else:
+ # Test increasing depths
+ channels_configs = [
+ [64, 128], # 2 conv layers (shallow)
+ [64, 128, 256], # 3 conv layers
+ [64, 128, 256, 512], # 4 conv layers (deep)
+ ]
+
+ # Run comparison
+ results = run_comparison(
+ model_type=args.model,
+ channels_configs=channels_configs,
+ T=args.T,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Summary
+ print_summary(results)
+
+ # Save
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py
new file mode 100644
index 0000000..48c62d8
--- /dev/null
+++ b/files/experiments/depth_comparison.py
@@ -0,0 +1,542 @@
+"""
+Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths.
+
+Hypothesis:
+- Shallow networks (1-2 layers): Both methods train successfully
+- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds
+
+Usage:
+ # Quick test (synthetic data)
+ python files/experiments/depth_comparison.py --synthetic --epochs 20
+
+ # Full experiment with SHD data
+ python files/experiments/depth_comparison.py --epochs 50
+
+ # Specific depths to test
+ python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.analysis.stability_monitor import StabilityMonitor
+
+
+@dataclass
+class ExperimentConfig:
+ """Configuration for a single experiment run."""
+ depth: int
+ hidden_dim: int
+ use_lyapunov: bool
+ lambda_reg: float
+ lambda_target: float
+ lyap_eps: float
+ epochs: int
+ lr: float
+ batch_size: int
+ beta: float
+ threshold: float
+ seed: int
+
+
+@dataclass
+class EpochMetrics:
+ """Metrics collected per epoch."""
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ firing_rate: float
+ dead_neurons: float
+ time_sec: float
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+) -> Tuple[DataLoader, DataLoader]:
+ """Create synthetic spike data for testing."""
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ # Generate class-conditional spike patterns
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+
+ for i in range(n_samples):
+ label = y[i].item()
+ # Each class has different firing rate pattern
+ base_rate = 0.05 + 0.02 * label
+ # Class-specific channels fire more
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
+ for t in range(T):
+ # Background activity
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ # Enhanced activity for class-specific channels
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
+ train_loader = DataLoader(
+ TensorDataset(x_train, y_train),
+ batch_size=64,
+ shuffle=True,
+ )
+ val_loader = DataLoader(
+ TensorDataset(x_val, y_val),
+ batch_size=64,
+ shuffle=False,
+ )
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+ # Create hidden dims list based on depth
+ # Gradually decrease size for deeper networks to keep param count reasonable
+ hidden_dims = []
+ current_dim = hidden_dim
+ for i in range(depth):
+ hidden_dims.append(current_dim)
+ # Optionally decrease dim in deeper layers
+ # current_dim = max(64, current_dim // 2)
+
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=threshold,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ monitor: StabilityMonitor,
+) -> Tuple[float, float, float, float, float, float]:
+ """Train for one epoch, return metrics."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ firing_rates = []
+ dead_fracs = []
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, recordings = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=True,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ # Check for NaN
+ if torch.isnan(loss):
+ return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0
+
+ loss.backward()
+
+ # Gradient clipping for stability comparison fairness
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ optimizer.step()
+
+ # Collect metrics
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ # Stability metrics
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ if recordings is not None:
+ spikes = recordings['spikes']
+ fr = spikes.mean().item()
+ dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item()
+ firing_rates.append(fr)
+ dead_fracs.append(dead)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+ avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+ avg_grad = np.mean(grad_norms)
+ avg_fr = np.mean(firing_rates) if firing_rates else 0.0
+ avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate model on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+
+ loss = ce_loss(logits, y)
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_single_experiment(
+ config: ExperimentConfig,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run a single experiment with given configuration."""
+ torch.manual_seed(config.seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=config.depth,
+ hidden_dim=config.hidden_dim,
+ beta=config.beta,
+ threshold=config.threshold,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=config.lr)
+ ce_loss = nn.CrossEntropyLoss()
+ monitor = StabilityMonitor()
+
+ metrics_history = []
+ method = "Lyapunov" if config.use_lyapunov else "Vanilla"
+
+ iterator = range(1, config.epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch(
+ model=model,
+ loader=train_loader,
+ optimizer=optimizer,
+ ce_loss=ce_loss,
+ device=device,
+ use_lyapunov=config.use_lyapunov,
+ lambda_reg=config.lambda_reg,
+ lambda_target=config.lambda_target,
+ lyap_eps=config.lyap_eps,
+ monitor=monitor,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ firing_rate=fr,
+ dead_neurons=dead,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ # Early stopping if training diverged
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ epochs: int = 30,
+ hidden_dim: int = 128,
+ lr: float = 1e-3,
+ lambda_reg: float = 0.1,
+ lambda_target: float = 0.0,
+ lyap_eps: float = 1e-4,
+ beta: float = 0.9,
+ seed: int = 42,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """
+ Run comparison experiments across depths.
+
+ Returns:
+ Dictionary with structure:
+ {
+ "vanilla": {1: [metrics...], 2: [metrics...], ...},
+ "lyapunov": {1: [metrics...], 2: [metrics...], ...}
+ }
+ """
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*50}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*50}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ config = ExperimentConfig(
+ depth=depth,
+ hidden_dim=hidden_dim,
+ use_lyapunov=use_lyap,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ epochs=epochs,
+ lr=lr,
+ batch_size=64,
+ beta=beta,
+ threshold=1.0,
+ seed=seed,
+ )
+
+ metrics = run_single_experiment(
+ config=config,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ device=device,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ # Print final metrics
+ final = metrics[-1]
+ lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}")
+
+ return results
+
+
+def save_results(results: Dict, output_dir: str, config: dict):
+ """Save experiment results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Convert metrics to dicts
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]):
+ """Print summary comparison table."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: Final Validation Accuracy by Depth")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van_metrics = results["vanilla"][depth]
+ lyap_metrics = results["lyapunov"][depth]
+
+ van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0
+ lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0
+
+ van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED"
+ lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+ # Gradient analysis
+ print("\nGradient Norm Analysis (final epoch):")
+ print("-" * 70)
+ print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}")
+ print("-" * 70)
+ for depth in depths:
+ van_grad = results["vanilla"][depth][-1].grad_norm
+ lyap_grad = results["lyapunov"][depth][-1].grad_norm
+ print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths")
+ p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer")
+ p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment")
+ p.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
+ p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov")
+ p.add_argument("--beta", type=float, default=0.9, help="Membrane decay")
+ p.add_argument("--seed", type=int, default=42, help="Random seed")
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
+ p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true", help="Disable progress bars")
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("Experiment: Vanilla vs Lyapunov-Regularized SNN")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Lambda_reg: {args.lambda_reg}")
+ print(f"Device: {device}")
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing SYNTHETIC data for quick testing")
+ train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed)
+ else:
+ print(f"\nLoading data from {args.cfg}")
+ from files.data_io.dataset_loader import get_dataloader
+ train_loader, val_loader = get_dataloader(args.cfg)
+ xb, _ = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = 20 # SHD has 20 classes
+
+ print(f"Data: T={T}, D={D}, classes={C}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ device=device,
+ epochs=args.epochs,
+ hidden_dim=args.hidden_dim,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ beta=args.beta,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py
new file mode 100644
index 0000000..efab140
--- /dev/null
+++ b/files/experiments/depth_scaling_benchmark.py
@@ -0,0 +1,1035 @@
+"""
+Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs.
+
+Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve.
+
+Key hypothesis (from literature):
+- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet)
+- Deep SNNs without regularization fail to train (gradient issues)
+- Deep SNNs WITH Lyapunov regularization achieve higher accuracy
+
+Reference results:
+- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI)
+- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS)
+- Spikformer on ImageNet: ~74.8% top-1 (arXiv)
+
+Usage:
+ python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+# =============================================================================
+# VGG-style Spiking Network (scalable depth)
+# =============================================================================
+
+class SpikingVGGBlock(nn.Module):
+ """Conv-BN-LIF block for VGG-style architecture."""
+
+ def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+ super().__init__()
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+ def forward(self, x, mem):
+ h = self.bn(self.conv(x))
+ spk, mem = self.lif(h, mem)
+ return spk, mem
+
+
+class SpikingVGG(nn.Module):
+ """
+ Scalable VGG-style Spiking Neural Network.
+
+ Architecture follows VGG pattern:
+ - Multiple conv blocks between pooling layers
+ - Depth controlled by num_blocks_per_stage
+
+ Args:
+ in_channels: Input channels (3 for RGB)
+ num_classes: Output classes
+ base_channels: Starting channel count (doubled each stage)
+ num_stages: Number of pooling stages (3-4 typical)
+ blocks_per_stage: Conv blocks per stage (controls depth)
+ T: Number of timesteps
+ beta: LIF membrane decay
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ base_channels: int = 64,
+ num_stages: int = 3,
+ blocks_per_stage: int = 2,
+ T: int = 4,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ dropout: float = 0.25,
+ stable_init: bool = False,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+ self.total_conv_layers = num_stages * blocks_per_stage
+ self.stable_init = stable_init
+
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ # Build stages
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ ch_in = in_channels
+ ch_out = base_channels
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for b in range(blocks_per_stage):
+ block_in = ch_in if b == 0 else ch_out
+ stage_blocks.append(
+ SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad)
+ )
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ ch_in = ch_out
+ ch_out = min(ch_out * 2, 512) # Cap at 512
+
+ # Calculate spatial size after pooling
+ # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages)
+ final_spatial = 32 // (2 ** num_stages)
+ final_channels = min(base_channels * (2 ** (num_stages - 1)), 512)
+ fc_input = final_channels * final_spatial * final_spatial
+
+ # Classifier
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(fc_input, num_classes)
+
+ if stable_init:
+ self._init_weights_stable()
+ else:
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_weights_stable(self):
+ """
+ Stability-aware initialization for SNNs.
+
+ Uses smaller weight magnitudes to produce less chaotic initial dynamics.
+ The key insight: Lyapunov exponent depends on weight magnitudes.
+ Smaller weights → smaller gradients → more stable dynamics.
+
+ Strategy:
+ - Use orthogonal init (preserves gradient magnitude across layers)
+ - Scale down by factor of 0.5 to reduce initial chaos
+ - This should produce λ closer to 0 from the start
+ """
+ scale_factor = 0.5 # Reduce weight magnitudes
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # Orthogonal init for conv (reshape to 2D, init, reshape back)
+ weight_shape = m.weight.shape
+ fan_out = weight_shape[0] * weight_shape[2] * weight_shape[3]
+ fan_in = weight_shape[1] * weight_shape[2] * weight_shape[3]
+
+ # Use smaller gain for stability
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ with torch.no_grad():
+ m.weight.mul_(scale_factor)
+
+ elif isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight, gain=scale_factor)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_mems(self, batch_size, device, dtype, P=1):
+ """Initialize membrane potentials for all LIF layers.
+
+ Args:
+ batch_size: Batch size B
+ device: torch device
+ dtype: torch dtype
+ P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov)
+
+ Returns:
+ List of membrane tensors with shape (P, B, C, H, W)
+ """
+ mems = []
+ H, W = 32, 32
+ ch = 64
+
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ ch = min(ch * 2, 512)
+
+ return mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+ """
+ Forward pass with optimized Lyapunov computation (Approach A: trajectory batching).
+
+ When compute_lyapunov=True, both original and perturbed trajectories are
+ processed together by batching them along a new dimension P=2. This avoids
+ redundant computation, especially for the first conv layer where inputs are identical.
+
+ Args:
+ x: (B, C, H, W) static image (will be repeated for T steps)
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude
+
+ Returns:
+ logits, lyap_est, recordings
+ """
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+
+ # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed)
+ P = 2 if compute_lyapunov else 1
+
+ # Initialize membrane potentials with shape (P, B, C, H, W)
+ mems = self._init_mems(B, device, dtype, P=P)
+
+ # Initialize perturbed trajectory
+ if compute_lyapunov:
+ for i in range(len(mems)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = None
+
+ # Time loop - repeat static input
+ for t in range(self.T):
+ mem_idx = 0
+ new_mems = []
+ is_first_block = True
+
+ # Process through stages
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ # First block: input x is identical for both trajectories
+ # Compute conv+bn ONCE, then expand to (P, B, C, H, W)
+ h_conv = block.bn(block.conv(x)) # (B, C, H, W)
+ h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy
+
+ # LIF with batched membrane states
+ # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
+
+ # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+
+ h = spk
+ new_mems.append(mem_new)
+ is_first_block = False
+ else:
+ # Subsequent blocks: inputs differ between trajectories
+ # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+
+ # Full block forward (conv+bn+lif)
+ h_conv = block.bn(block.conv(h_flat))
+ spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
+
+ # Reshape back
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+
+ h = spk
+ new_mems.append(mem_new)
+
+ mem_idx += 1
+
+ # Pool: apply to batched tensor
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ h_pooled = pool(h_flat)
+ h = h_pooled.view(P, B, *h_pooled.shape[1:])
+
+ mems = new_mems
+
+ # Accumulate final spikes from ORIGINAL trajectory only (index 0)
+ h_orig = h[0].view(B, -1) # (B, C*H*W)
+ if spike_sum is None:
+ spike_sum = h_orig
+ else:
+ spike_sum = spike_sum + h_orig
+
+ # Lyapunov divergence and renormalization (Option 1: global delta + global renorm)
+ # This is the textbook Benettin-style Lyapunov exponent estimator where
+ # the perturbation is treated as one vector in the concatenated state space.
+ if compute_lyapunov:
+ # Compute GLOBAL divergence across all layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W)
+ delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # GLOBAL renormalization: same scale factor for all layers
+ # This ensures ||perturbation||_global = eps after renorm
+ scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting
+
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ # Update perturbed trajectory: scale the diff to have global norm = eps
+ mems[i] = torch.stack([
+ new_mems[i][0],
+ new_mems[i][0] + diff * scale
+ ], dim=0)
+
+ # Readout
+ out = self.dropout(spike_sum)
+ logits = self.fc(out)
+
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est, None
+
+ @property
+ def depth(self):
+ return self.total_conv_layers
+
+
+# =============================================================================
+# Dataset Loading
+# =============================================================================
+
+def get_dataset(
+ name: str,
+ data_dir: str = './data',
+ batch_size: int = 128,
+ num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]:
+ """
+ Get train/test loaders for various datasets.
+
+ Returns:
+ train_loader, test_loader, num_classes, input_shape
+ """
+
+ if name == 'mnist':
+ transform = transforms.Compose([
+ transforms.Resize(32), # Resize to 32x32 for consistency
+ transforms.ToTensor(),
+ ])
+ train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+ test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
+ num_classes = 10
+ input_shape = (1, 32, 32)
+
+ elif name == 'fashion_mnist':
+ transform = transforms.Compose([
+ transforms.Resize(32),
+ transforms.ToTensor(),
+ ])
+ train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
+ test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
+ num_classes = 10
+ input_shape = (1, 32, 32)
+
+ elif name == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ])
+ transform_test = transforms.Compose([transforms.ToTensor()])
+ train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
+ test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test)
+ num_classes = 10
+ input_shape = (3, 32, 32)
+
+ elif name == 'cifar100':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ])
+ transform_test = transforms.Compose([transforms.ToTensor()])
+ train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
+ test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test)
+ num_classes = 100
+ input_shape = (3, 32, 32)
+
+ else:
+ raise ValueError(f"Unknown dataset: {name}")
+
+ train_loader = DataLoader(
+ train_ds, batch_size=batch_size, shuffle=True,
+ num_workers=num_workers, pin_memory=True
+ )
+ test_loader = DataLoader(
+ test_ds, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True
+ )
+
+ return train_loader, test_loader, num_classes, input_shape
+
+
+# =============================================================================
+# Training
+# =============================================================================
+
+@dataclass
+class TrainingMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ test_loss: float
+ test_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ grad_max_sv: Optional[float] # Max singular value of gradients
+ grad_min_sv: Optional[float] # Min singular value of gradients
+ grad_condition: Optional[float] # Condition number
+ lr: float
+ time_sec: float
+
+
+def compute_gradient_svs(model):
+ """Compute gradient singular value statistics for all weight matrices."""
+ max_svs = []
+ min_svs = []
+
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ with torch.no_grad():
+ G = param.grad.detach()
+ try:
+ sv = torch.linalg.svdvals(G)
+ if len(sv) > 0:
+ max_svs.append(sv[0].item())
+ min_svs.append(sv[-1].item())
+ except Exception:
+ pass
+
+ if not max_svs:
+ return None, None, None
+
+ max_sv = max(max_svs)
+ min_sv = min(min_svs)
+ condition = max_sv / (min_sv + 1e-12)
+
+ return max_sv, min_sv, condition
+
+
+def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float,
+ lyap_threshold: float = 2.0) -> torch.Tensor:
+ """
+ Compute Lyapunov regularization loss with different penalty types.
+
+ Args:
+ lyap_est: Estimated Lyapunov exponent (scalar tensor)
+ reg_type: Type of regularization:
+ - "squared": (λ - target)² - original, aggressive
+ - "hinge": max(0, λ - threshold)² - only penalize chaos
+ - "asymmetric": strong penalty for chaos, weak for collapse
+ - "extreme": only penalize when λ > lyap_threshold (configurable)
+ - "adaptive_linear": penalty scales linearly with excess over threshold
+ - "adaptive_exp": penalty grows exponentially for severe chaos
+ - "adaptive_sigmoid": smooth sigmoid transition around threshold
+ lambda_target: Target value (used for squared reg_type)
+ lyap_threshold: Threshold for adaptive/extreme reg_types (default 2.0)
+
+ Returns:
+ Regularization loss (scalar tensor)
+ """
+ if reg_type == "squared":
+ # Original: penalize any deviation from target
+ return (lyap_est - lambda_target) ** 2
+
+ elif reg_type == "hinge":
+ # Only penalize when λ > threshold (too chaotic)
+ # threshold = 0 means: only penalize positive Lyapunov (chaos)
+ threshold = 0.0
+ excess = torch.relu(lyap_est - threshold)
+ return excess ** 2
+
+ elif reg_type == "asymmetric":
+ # Strong penalty for chaos (λ > 0), weak penalty for collapse (λ < -1)
+ # This allows the network to be stable without being dead
+ chaos_penalty = torch.relu(lyap_est) ** 2 # Penalize λ > 0
+ collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2 # Weakly penalize λ < -1
+ return chaos_penalty + collapse_penalty
+
+ elif reg_type == "extreme":
+ # Only penalize when λ > threshold (VERY chaotic)
+ # This allows moderate chaos while preventing extreme instability
+ # Threshold is now configurable via lyap_threshold argument
+ excess = torch.relu(lyap_est - lyap_threshold)
+ return excess ** 2
+
+ elif reg_type == "adaptive_linear":
+ # Penalty scales linearly with how far above threshold we are
+ # loss = excess * excess² = excess³
+ # This naturally makes the penalty weaker for small excesses
+ # and much stronger for large excesses
+ excess = torch.relu(lyap_est - lyap_threshold)
+ return excess ** 3 # Cubic scaling: gentle near threshold, strong when chaotic
+
+ elif reg_type == "adaptive_exp":
+ # Exponential penalty for severe chaos
+ # loss = (exp(excess) - 1) * excess² for excess > 0
+ # This gives very weak penalty near threshold, explosive penalty for chaos
+ excess = torch.relu(lyap_est - lyap_threshold)
+ # Use exp(excess) - 1 to get 0 when excess=0, exponential growth after
+ exp_scale = torch.exp(excess) - 1.0
+ return exp_scale * excess # exp(excess) * excess - excess
+
+ elif reg_type == "adaptive_sigmoid":
+ # Smooth sigmoid transition around threshold
+ # The "sharpness" of transition is controlled by a temperature parameter
+ # weight(λ) = sigmoid((λ - threshold) / T) where T controls smoothness
+ # Using T=0.5 for moderately sharp transition
+ temperature = 0.5
+ weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature)
+ # Penalize deviation from target, weighted by how far past threshold
+ deviation = lyap_est - lambda_target
+ return weight * (deviation ** 2)
+
+ # =========================================================================
+ # SCALED MULTIPLIER REGULARIZATION
+ # loss = (λ_reg × g(relu(λ))) × relu(λ)
+ # └─────────────────┘ └──────┘
+ # scaled multiplier penalty toward target=0
+ #
+ # The multiplier itself scales with λ, making it mild when λ is small
+ # and aggressive when λ is large.
+ # =========================================================================
+
+ elif reg_type == "mult_linear":
+ # Multiplier scales linearly: g(x) = x
+ # loss = (λ_reg × relu(λ)) × relu(λ) = λ_reg × relu(λ)²
+ # λ=0.5 → 0.25, λ=1.0 → 1.0, λ=2.0 → 4.0, λ=3.0 → 9.0
+ pos_lyap = torch.relu(lyap_est)
+ return pos_lyap * pos_lyap # relu(λ)²
+
+ elif reg_type == "mult_squared":
+ # Multiplier scales quadratically: g(x) = x²
+ # loss = (λ_reg × relu(λ)²) × relu(λ) = λ_reg × relu(λ)³
+ # λ=0.5 → 0.125, λ=1.0 → 1.0, λ=2.0 → 8.0, λ=3.0 → 27.0
+ pos_lyap = torch.relu(lyap_est)
+ return pos_lyap * pos_lyap * pos_lyap # relu(λ)³
+
+ elif reg_type == "mult_log":
+ # Multiplier scales logarithmically: g(x) = log(1+x)
+ # loss = (λ_reg × log(1+relu(λ))) × relu(λ)
+ # λ=0.5 → 0.20, λ=1.0 → 0.69, λ=2.0 → 2.20, λ=3.0 → 4.16
+ pos_lyap = torch.relu(lyap_est)
+ return torch.log1p(pos_lyap) * pos_lyap # log(1+λ) × λ
+
+ else:
+ raise ValueError(f"Unknown reg_type: {reg_type}")
+
+
+def train_epoch(
+ model, loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+ progress=True, compute_sv_every=10,
+ reg_type="squared", current_lambda_reg=None,
+ lyap_threshold=2.0
+):
+ """
+ Train one epoch.
+
+ Args:
+ current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg.
+ reg_type: "squared", "hinge", "asymmetric", or "extreme"
+ lyap_threshold: Threshold for extreme reg_type
+ """
+ model.train()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+ lyap_vals = []
+ grad_norms = []
+ grad_max_svs = []
+ grad_min_svs = []
+ grad_conditions = []
+
+ # Use warmup value if provided
+ effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for batch_idx, (x, y) in enumerate(iterator):
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps)
+
+ loss = criterion(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold)
+ loss = loss + effective_lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan'), None, None, None
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5
+ grad_norms.append(grad_norm)
+
+ # Compute gradient SVs periodically (expensive)
+ if batch_idx % compute_sv_every == 0:
+ max_sv, min_sv, cond = compute_gradient_svs(model)
+ if max_sv is not None:
+ grad_max_svs.append(max_sv)
+ grad_min_svs.append(min_sv)
+ grad_conditions.append(cond)
+
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ return (
+ total_loss / total,
+ correct / total,
+ np.mean(lyap_vals) if lyap_vals else None,
+ np.mean(grad_norms),
+ np.mean(grad_max_svs) if grad_max_svs else None,
+ np.mean(grad_min_svs) if grad_min_svs else None,
+ np.mean(grad_conditions) if grad_conditions else None,
+ )
+
+
+@torch.no_grad()
+def evaluate(model, loader, criterion, device, progress=True):
+ model.eval()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+
+ iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+
+ loss = criterion(logits, y)
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ return total_loss / total, correct / total
+
+
+def run_single_config(
+ dataset_name: str,
+ depth_config: Tuple[int, int], # (num_stages, blocks_per_stage)
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+ reg_type: str = "squared",
+ warmup_epochs: int = 0,
+ stable_init: bool = False,
+ lyap_threshold: float = 2.0,
+) -> List[TrainingMetrics]:
+ """Run training for a single configuration."""
+ torch.manual_seed(seed)
+
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ model = SpikingVGG(
+ in_channels=in_channels,
+ num_classes=num_classes,
+ base_channels=64,
+ num_stages=num_stages,
+ blocks_per_stage=blocks_per_stage,
+ T=T,
+ stable_init=stable_init,
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ method = "Lyapunov" if use_lyapunov else "Vanilla"
+ print(f" {method}: depth={total_depth}, params={num_params:,}")
+
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ criterion = nn.CrossEntropyLoss()
+
+ history = []
+ best_acc = 0.0
+
+ for epoch in range(1, epochs + 1):
+ t0 = time.time()
+
+ # Warmup: gradually increase lambda_reg
+ if warmup_epochs > 0 and epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, lambda_target, 1e-4, progress,
+ reg_type=reg_type, current_lambda_reg=current_lambda_reg,
+ lyap_threshold=lyap_threshold
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_acc = max(best_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == epochs:
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv else ""
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}")
+
+ if np.isnan(train_loss):
+ print(f" DIVERGED at epoch {epoch}")
+ break
+
+ print(f" Best test acc: {best_acc:.3f}")
+ return history
+
+
+def run_depth_scaling_experiment(
+ dataset_name: str,
+ depth_configs: List[Tuple[int, int]],
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool,
+ reg_type: str = "squared",
+ warmup_epochs: int = 0,
+ stable_init: bool = False,
+ lyap_threshold: float = 2.0,
+) -> Dict:
+ """Run full depth scaling experiment."""
+
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ print(f"Regularization type: {reg_type}")
+ print(f"Warmup epochs: {warmup_epochs}")
+ print(f"Stable init: {stable_init}")
+ print(f"Lyapunov threshold: {lyap_threshold}")
+
+ for depth_config in depth_configs:
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ print(f"\n{'='*60}")
+ print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+
+ history = run_single_config(
+ dataset_name=dataset_name,
+ depth_config=depth_config,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=T,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ device=device,
+ seed=seed,
+ progress=progress,
+ reg_type=reg_type,
+ warmup_epochs=warmup_epochs,
+ stable_init=stable_init,
+ lyap_threshold=lyap_threshold,
+ )
+
+ results[method][total_depth] = history
+
+ return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+ """Print final summary table."""
+ print("\n" + "=" * 100)
+ print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}")
+ print("=" * 100)
+ print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}")
+ print("-" * 100)
+
+ depths = sorted(results["vanilla"].keys())
+
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.test_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.test_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}"
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+ lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov else "N/A"
+
+ van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A"
+ lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm else "N/A"
+ van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A"
+
+ print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}")
+
+ print("=" * 100)
+
+ # Gradient health analysis
+ print("\nGRADIENT HEALTH ANALYSIS:")
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_cond = van.grad_condition if van.grad_condition else 0
+ lyap_cond = lyap.grad_condition if lyap.grad_condition else 0
+
+ status = ""
+ if van_cond > 1e6:
+ status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)"
+ elif van_cond > 1e4:
+ status = "~ Vanilla has moderately ill-conditioned gradients"
+
+ if status:
+ print(f" Depth {depth}: {status}")
+
+ print("")
+
+ # Analysis
+ print("\nKEY OBSERVATIONS:")
+ shallow = min(depths)
+ deep = max(depths)
+
+ van_shallow = results["vanilla"][shallow][-1].test_acc
+ van_deep = results["vanilla"][deep][-1].test_acc
+ lyap_shallow = results["lyapunov"][shallow][-1].test_acc
+ lyap_deep = results["lyapunov"][deep][-1].test_acc
+
+ van_gain = van_deep - van_shallow
+ lyap_gain = lyap_deep - lyap_shallow
+
+ print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change")
+ print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change")
+
+ if lyap_gain > van_gain + 0.02:
+ print(f" ✓ Lyapunov regularization enables better depth scaling!")
+ elif lyap_gain > van_gain:
+ print(f" ~ Lyapunov shows slight improvement in depth scaling")
+ else:
+ print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, history in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in history]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs")
+
+ p.add_argument("--dataset", type=str, default="cifar100",
+ choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
+ p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16],
+ help="Total conv layer depths to test")
+ p.add_argument("--T", type=int, default=4, help="Timesteps")
+ p.add_argument("--epochs", type=int, default=100)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--lambda_reg", type=float, default=0.3)
+ p.add_argument("--lambda_target", type=float, default=-0.1)
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--out_dir", type=str, default="runs/depth_scaling")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--no-progress", action="store_true")
+ p.add_argument("--reg_type", type=str, default="squared",
+ choices=["squared", "hinge", "asymmetric", "extreme"],
+ help="Lyapunov regularization type")
+ p.add_argument("--warmup_epochs", type=int, default=0,
+ help="Epochs to warmup lambda_reg (0 = no warmup)")
+ p.add_argument("--stable_init", action="store_true",
+ help="Use stability-aware weight initialization")
+ p.add_argument("--lyap_threshold", type=float, default=2.0,
+ help="Threshold for extreme reg_type (only penalize λ > threshold)")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("DEPTH SCALING BENCHMARK")
+ print("=" * 80)
+ print(f"Dataset: {args.dataset}")
+ print(f"Depths: {args.depths}")
+ print(f"Timesteps: {args.T}")
+ print(f"Epochs: {args.epochs}")
+ print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+ print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Load data
+ print(f"\nLoading {args.dataset}...")
+ train_loader, test_loader, num_classes, input_shape = get_dataset(
+ args.dataset, args.data_dir, args.batch_size
+ )
+ in_channels = input_shape[0]
+ print(f"Classes: {num_classes}, Input: {input_shape}")
+ print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+ # Convert depths to (num_stages, blocks_per_stage) configs
+ # We use 4 stages (3 for smaller nets), adjust blocks_per_stage
+ depth_configs = []
+ for d in args.depths:
+ if d <= 4:
+ depth_configs.append((d, 1)) # d stages, 1 block each
+ elif d <= 8:
+ depth_configs.append((4, d // 4)) # 4 stages
+ else:
+ depth_configs.append((4, d // 4)) # 4 stages, more blocks
+
+ print(f"\nDepth configurations: {[(d, f'{s}×{b}') for d, (s, b) in zip(args.depths, depth_configs)]}")
+
+ # Run experiment
+ results = run_depth_scaling_experiment(
+ dataset_name=args.dataset,
+ depth_configs=depth_configs,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=args.T,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ reg_type=args.reg_type,
+ warmup_epochs=args.warmup_epochs,
+ stable_init=args.stable_init,
+ lyap_threshold=args.lyap_threshold,
+ )
+
+ # Summary
+ print_summary(results, args.dataset)
+
+ # Save
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py
new file mode 100644
index 0000000..011387f
--- /dev/null
+++ b/files/experiments/hyperparameter_grid_search.py
@@ -0,0 +1,597 @@
+"""
+Hyperparameter Grid Search for Lyapunov-Regularized SNNs.
+
+Goal: Find optimal (lambda_reg, lambda_target) for each network depth
+ and derive an adaptive curve for automatic hyperparameter selection.
+
+Usage:
+ python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Tuple
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+
+
+@dataclass
+class GridSearchResult:
+ """Result from a single grid search configuration."""
+ depth: int
+ lambda_reg: float
+ lambda_target: float
+ final_train_acc: float
+ final_val_acc: float
+ final_lyapunov: float
+ final_grad_norm: float
+ converged: bool # Did training succeed (not NaN)?
+ epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never)
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+ batch_size: int = 128,
+) -> Tuple[DataLoader, DataLoader, int, int, int]:
+ """Create synthetic spike data."""
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+ for i in range(n_samples):
+ label = y[i].item()
+ base_rate = 0.05 + 0.02 * label
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
+ for t in range(T):
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
+ train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False)
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def train_and_evaluate(
+ depth: int,
+ lambda_reg: float,
+ lambda_target: float,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early
+) -> GridSearchResult:
+ """Train a single configuration and return results."""
+ torch.manual_seed(seed)
+
+ # Create model
+ hidden_dims = [hidden_dim] * depth
+ model = LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=0.9,
+ threshold=1.0,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ best_val_acc = 0.0
+ epochs_to_90 = -1
+ final_lyap = 0.0
+ final_grad = 0.0
+ converged = True
+
+ for epoch in range(1, epochs + 1):
+ # Warmup: gradually increase lambda_reg
+ if epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ # Training
+ model.train()
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False)
+
+ ce = ce_loss(logits, y)
+ if lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + current_lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ converged = False
+ break
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if not converged:
+ break
+
+ train_acc = total_correct / total_samples
+ final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0
+ final_grad = np.mean(grad_norms) if grad_norms else 0.0
+
+ # Track epochs to 90% accuracy
+ if epochs_to_90 < 0 and train_acc >= 0.9:
+ epochs_to_90 = epoch
+
+ # Validation
+ model.eval()
+ val_correct = 0
+ val_total = 0
+ with torch.no_grad():
+ for x, y in val_loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ preds = logits.argmax(dim=1)
+ val_correct += (preds == y).sum().item()
+ val_total += x.size(0)
+ val_acc = val_correct / val_total
+ best_val_acc = max(best_val_acc, val_acc)
+
+ return GridSearchResult(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ final_train_acc=train_acc if converged else 0.0,
+ final_val_acc=best_val_acc if converged else 0.0,
+ final_lyapunov=final_lyap,
+ final_grad_norm=final_grad,
+ converged=converged,
+ epochs_to_90pct=epochs_to_90,
+ )
+
+
+def run_grid_search(
+ depths: List[int],
+ lambda_regs: List[float],
+ lambda_targets: List[float],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ progress: bool = True,
+) -> List[GridSearchResult]:
+ """Run full grid search."""
+ results = []
+
+ # Total configurations
+ configs = list(product(depths, lambda_regs, lambda_targets))
+ total = len(configs)
+
+ iterator = tqdm(configs, desc="Grid Search", disable=not progress)
+
+ for depth, lambda_reg, lambda_target in iterator:
+ if progress:
+ iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target})
+
+ result = train_and_evaluate(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ device=device,
+ seed=seed,
+ )
+ results.append(result)
+
+ if progress:
+ iterator.set_postfix({
+ "d": depth,
+ "λr": lambda_reg,
+ "λt": lambda_target,
+ "acc": f"{result.final_val_acc:.2f}"
+ })
+
+ return results
+
+
+def analyze_results(results: List[GridSearchResult]) -> Dict:
+ """Analyze grid search results and find optimal hyperparameters per depth."""
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ analysis = {
+ "optimal_per_depth": {},
+ "all_results": [asdict(r) for r in results],
+ }
+
+ print("\n" + "=" * 80)
+ print("GRID SEARCH ANALYSIS")
+ print("=" * 80)
+
+ # Find optimal for each depth
+ print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}")
+ print("-" * 80)
+
+ optimal_lambda_regs = []
+ optimal_lambda_targets = []
+ depths_list = []
+
+ for depth in sorted(by_depth.keys()):
+ depth_results = by_depth[depth]
+ # Find best by validation accuracy
+ best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0)
+
+ analysis["optimal_per_depth"][depth] = {
+ "lambda_reg": best.lambda_reg,
+ "lambda_target": best.lambda_target,
+ "val_acc": best.final_val_acc,
+ "lyapunov": best.final_lyapunov,
+ "epochs_to_90": best.epochs_to_90pct,
+ }
+
+ print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} "
+ f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}")
+
+ if best.final_val_acc > 0.5: # Only use successful runs for curve fitting
+ depths_list.append(depth)
+ optimal_lambda_regs.append(best.lambda_reg)
+ optimal_lambda_targets.append(best.lambda_target)
+
+ # Fit adaptive curves
+ print("\n" + "=" * 80)
+ print("ADAPTIVE HYPERPARAMETER CURVES")
+ print("=" * 80)
+
+ if len(depths_list) >= 3:
+ # Fit polynomial curves
+ depths_arr = np.array(depths_list)
+ lambda_regs_arr = np.array(optimal_lambda_regs)
+ lambda_targets_arr = np.array(optimal_lambda_targets)
+
+ # Fit lambda_reg vs depth (expect increasing with depth)
+ try:
+ reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=min(2, len(depths_arr) - 1))
+ reg_poly = np.poly1d(reg_coeffs)
+ print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}"
+ if len(reg_coeffs) == 3 else f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}")
+ analysis["lambda_reg_curve"] = reg_coeffs.tolist()
+ except Exception as e:
+ print(f"Could not fit λ_reg curve: {e}")
+
+ # Fit lambda_target vs depth (expect decreasing / more negative with depth)
+ try:
+ target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=min(2, len(depths_arr) - 1))
+ target_poly = np.poly1d(target_coeffs)
+ print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}"
+ if len(target_coeffs) == 3 else f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}")
+ analysis["lambda_target_curve"] = target_coeffs.tolist()
+ except Exception as e:
+ print(f"Could not fit λ_target curve: {e}")
+
+ # Print recommendations
+ print("\n" + "-" * 80)
+ print("RECOMMENDED HYPERPARAMETERS BY DEPTH:")
+ print("-" * 80)
+ for d in [2, 4, 6, 8, 10, 12, 14, 16]:
+ rec_reg = max(0.01, reg_poly(d))
+ rec_target = min(0.0, target_poly(d))
+ print(f" Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}")
+
+ else:
+ print("Not enough successful runs to fit curves")
+
+ return analysis
+
+
+def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict):
+ """Save grid search results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f:
+ json.dump(analysis, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def plot_grid_search(results: List[GridSearchResult], output_dir: str):
+ """Generate visualization of grid search results."""
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError:
+ print("matplotlib not available, skipping plots")
+ return
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ depths = sorted(by_depth.keys())
+
+ # Get unique lambda values
+ lambda_regs = sorted(set(r.lambda_reg for r in results))
+ lambda_targets = sorted(set(r.lambda_target for r in results))
+
+ # Create heatmaps for each depth
+ n_depths = len(depths)
+ fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10))
+ axes = axes.flatten()
+
+ for idx, depth in enumerate(depths):
+ ax = axes[idx]
+ depth_results = by_depth[depth]
+
+ # Create accuracy matrix
+ acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs)))
+ for r in depth_results:
+ i = lambda_targets.index(r.lambda_target)
+ j = lambda_regs.index(r.lambda_reg)
+ acc_matrix[i, j] = r.final_val_acc
+
+ im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
+ ax.set_xticks(range(len(lambda_regs)))
+ ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45)
+ ax.set_yticks(range(len(lambda_targets)))
+ ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets])
+ ax.set_xlabel("λ_reg")
+ ax.set_ylabel("λ_target")
+ ax.set_title(f"Depth {depth}")
+
+ # Mark best
+ best = max(depth_results, key=lambda r: r.final_val_acc)
+ bi = lambda_targets.index(best.lambda_target)
+ bj = lambda_regs.index(best.lambda_reg)
+ ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2)
+
+ # Add colorbar
+ plt.colorbar(im, ax=ax, label='Val Acc')
+
+ # Hide unused subplots
+ for idx in range(len(depths), len(axes)):
+ axes[idx].axis('off')
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # Plot optimal hyperparameters vs depth
+ fig, axes = plt.subplots(1, 3, figsize=(15, 4))
+
+ optimal_regs = []
+ optimal_targets = []
+ optimal_accs = []
+ for depth in depths:
+ best = max(by_depth[depth], key=lambda r: r.final_val_acc)
+ optimal_regs.append(best.lambda_reg)
+ optimal_targets.append(best.lambda_target)
+ optimal_accs.append(best.final_val_acc)
+
+ axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8)
+ axes[0].set_xlabel("Network Depth")
+ axes[0].set_ylabel("Optimal λ_reg")
+ axes[0].set_title("Optimal Regularization Strength vs Depth")
+ axes[0].grid(True, alpha=0.3)
+
+ axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange')
+ axes[1].set_xlabel("Network Depth")
+ axes[1].set_ylabel("Optimal λ_target")
+ axes[1].set_title("Optimal Target Lyapunov vs Depth")
+ axes[1].grid(True, alpha=0.3)
+
+ axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green')
+ axes[2].set_xlabel("Network Depth")
+ axes[2].set_ylabel("Best Validation Accuracy")
+ axes[2].set_title("Best Achievable Accuracy vs Depth")
+ axes[2].set_ylim(0, 1.05)
+ axes[2].grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ print(f"Plots saved to {output_dir}")
+
+
+def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'):
+ """Get CIFAR-10 dataloaders with rate encoding for SNN."""
+ from torchvision import datasets, transforms
+
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
+ val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
+
+ # Rate encoding: convert images to spike sequences
+ class RateEncodedDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset, T):
+ self.dataset = dataset
+ self.T = T
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ img, label = self.dataset[idx]
+ # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D)
+ flat = img.view(-1) # (3072,)
+ # Rate encoding: probability of spike = pixel intensity
+ spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float()
+ return spikes, label
+
+ train_encoded = RateEncodedDataset(train_ds, T)
+ val_encoded = RateEncodedDataset(val_ds, T)
+
+ train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, T, 3072, 10 # T, D, num_classes
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN")
+
+ # Grid search parameters
+ p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10],
+ help="Network depths to test")
+ p.add_argument("--lambda_regs", type=float, nargs="+",
+ default=[0.01, 0.05, 0.1, 0.2, 0.3],
+ help="Lambda_reg values to test")
+ p.add_argument("--lambda_targets", type=float, nargs="+",
+ default=[0.0, -0.05, -0.1, -0.2],
+ help="Lambda_target values to test")
+
+ # Model parameters
+ p.add_argument("--hidden_dim", type=int, default=256)
+ p.add_argument("--epochs", type=int, default=15)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--seed", type=int, default=42)
+
+ # Data
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)")
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding")
+
+ # Output
+ p.add_argument("--out_dir", type=str, default="runs/grid_search")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("HYPERPARAMETER GRID SEARCH")
+ print("=" * 80)
+ print(f"Depths: {args.depths}")
+ print(f"λ_reg values: {args.lambda_regs}")
+ print(f"λ_target values: {args.lambda_targets}")
+ print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}")
+ print(f"Epochs per config: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing synthetic data")
+ train_loader, val_loader, T, D, C = create_synthetic_data(
+ seed=args.seed, batch_size=args.batch_size
+ )
+ else:
+ print("\nUsing CIFAR-10 with rate encoding")
+ train_loader, val_loader, T, D, C = get_cifar10_loaders(
+ batch_size=args.batch_size,
+ T=args.T,
+ data_dir=args.data_dir
+ )
+
+ print(f"Data: T={T}, D={D}, classes={C}\n")
+
+ # Run grid search
+ results = run_grid_search(
+ depths=args.depths,
+ lambda_regs=args.lambda_regs,
+ lambda_targets=args.lambda_targets,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Analyze results
+ analysis = analyze_results(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, analysis, output_dir, vars(args))
+
+ # Generate plots
+ plot_grid_search(results, output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py
new file mode 100644
index 0000000..05dbcd2
--- /dev/null
+++ b/files/experiments/lyapunov_diffonly_benchmark.py
@@ -0,0 +1,590 @@
+"""
+Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation.
+
+Optimization B: Instead of storing two full membrane trajectories:
+ mems[i][0] = base trajectory
+ mems[i][1] = perturbed trajectory
+
+Store only:
+ base_mems[i] = base trajectory
+ delta_mems[i] = perturbation (perturbed - base)
+
+Benefits:
+ - ~2x less memory for membrane states
+ - Fewer memory reads/writes during renormalization
+ - Better cache utilization
+"""
+
+import os
+import sys
+import time
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional, List
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+class SpikingVGGBlock(nn.Module):
+ """Conv-BN-LIF block."""
+
+ def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+ super().__init__()
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+ def forward(self, x, mem):
+ h = self.bn(self.conv(x))
+ spk, mem = self.lif(h, mem)
+ return spk, mem
+
+
+class SpikingVGG_Original(nn.Module):
+ """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W)."""
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ # Build stages
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32 # CIFAR
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype, P=1):
+ mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ P = 2 if compute_lyapunov else 1
+
+ mems = self._init_mems(B, device, dtype, P=P)
+
+ if compute_lyapunov:
+ for i in range(len(mems)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_mems = []
+ is_first_block = True
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ h_conv = block.bn(block.conv(x))
+ h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ is_first_block = False
+ else:
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ h_conv = block.bn(block.conv(h_flat))
+ spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ mem_idx += 1
+
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ h_pooled = pool(h_flat)
+ h = h_pooled.view(P, B, *h_pooled.shape[1:])
+
+ mems = new_mems
+
+ h_orig = h[0].view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_orig
+ else:
+ spike_sum = spike_sum + h_orig
+
+ if compute_lyapunov:
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ scale = (lyap_eps / delta).view(B, 1, 1, 1)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ mems[i] = torch.stack([
+ new_mems[i][0],
+ new_mems[i][0] + diff * scale
+ ], dim=0)
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+class SpikingVGG_DiffOnly(nn.Module):
+ """
+ Optimized implementation: stores base + diff instead of 2 full trajectories.
+
+ Memory layout:
+ base_mems[i]: (B, C, H, W) - base trajectory membrane
+ delta_mems[i]: (B, C, H, W) - perturbation vector
+
+ Perturbed trajectory is materialized as (base + delta) only when needed.
+ """
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype):
+ """Initialize base membrane states (B, C, H, W)."""
+ base_mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return base_mems
+
+ def _init_deltas(self, base_mems, lyap_eps):
+ """Initialize perturbation vectors δ with ||δ||_global = eps."""
+ delta_mems = []
+ for base in base_mems:
+ delta_mems.append(lyap_eps * torch.randn_like(base))
+ return delta_mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+
+ # Initialize base membrane states
+ base_mems = self._init_mems(B, device, dtype)
+
+ # Initialize perturbations if computing Lyapunov
+ if compute_lyapunov:
+ delta_mems = self._init_deltas(base_mems, lyap_eps)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+ else:
+ delta_mems = None
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_base_mems = []
+ new_delta_mems = [] if compute_lyapunov else None
+
+ # Track spikes for base and perturbed (if computing Lyapunov)
+ h_base = None
+ h_delta = None # Will store (h_perturbed - h_base)
+ is_first_block = True
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ # First block: input x is same for both trajectories
+ h_conv = block.bn(block.conv(x)) # (B, C, H, W)
+
+ # Base trajectory
+ spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+ h_base = spk_base
+
+ if compute_lyapunov:
+ # Perturbed trajectory: mem = base + delta
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed)
+ # Store delta for new membrane
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ # Store spike difference for propagation
+ h_delta = spk_perturbed - spk_base
+
+ is_first_block = False
+ else:
+ # Subsequent blocks: inputs differ
+ # Base trajectory
+ h_conv_base = block.bn(block.conv(h_base))
+ spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+
+ if compute_lyapunov:
+ # Perturbed trajectory: h_perturbed = h_base + h_delta
+ h_perturbed = h_base + h_delta
+ h_conv_perturbed = block.bn(block.conv(h_perturbed))
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed)
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ h_delta = spk_perturbed - spk_base
+
+ h_base = spk_base
+
+ mem_idx += 1
+
+ # Pooling
+ h_base = pool(h_base)
+ if compute_lyapunov:
+ # Pool both and compute new delta
+ h_perturbed = h_base + pool(h_delta) # Note: pool(base+delta) ≠ pool(base) + pool(delta) in general
+ # But for AvgPool, it's linear so this is fine
+ h_delta = h_perturbed - h_base # This simplifies to pool(h_delta) for AvgPool
+ h_delta = pool(h_delta) # Actually just pool the delta directly (AvgPool is linear)
+
+ # Update membrane states
+ base_mems = new_base_mems
+
+ # Accumulate spikes from base trajectory
+ h_flat = h_base.view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_flat
+ else:
+ spike_sum = spike_sum + h_flat
+
+ # Lyapunov: compute global divergence and renormalize
+ if compute_lyapunov:
+ # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||²
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for delta in new_delta_mems:
+ delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3))
+
+ delta_norm = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)
+
+ # Renormalize: scale all deltas so ||δ||_global = eps
+ scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
+ delta_mems = [delta * scale for delta in new_delta_mems]
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20):
+ """Benchmark forward pass time."""
+ device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times, lyap
+
+
+def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov,
+ lambda_reg=0.3, num_warmup=5, num_runs=20):
+ """Benchmark forward + backward pass time."""
+ device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ model.zero_grad()
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+ torch.cuda.synchronize()
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+ model.zero_grad()
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+ torch.cuda.synchronize()
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times
+
+
+def measure_memory(model, x, compute_lyapunov):
+ """Measure peak GPU memory during forward pass."""
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.synchronize()
+
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+ peak_mem = torch.cuda.max_memory_allocated() / 1024**2 # MB
+ return peak_mem
+
+
+def run_benchmark():
+ print("=" * 70)
+ print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage")
+ print("=" * 70)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+
+ # Test configurations
+ configs = [
+ {"depth": 4, "blocks_per_stage": 1, "batch_size": 64},
+ {"depth": 8, "blocks_per_stage": 2, "batch_size": 64},
+ {"depth": 12, "blocks_per_stage": 3, "batch_size": 32},
+ ]
+
+ print("\n" + "=" * 70)
+
+ for cfg in configs:
+ depth = cfg["depth"]
+ blocks = cfg["blocks_per_stage"]
+ batch_size = cfg["batch_size"]
+
+ print(f"\n{'='*70}")
+ print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}")
+ print(f"{'='*70}")
+
+ # Create models
+ model_orig = SpikingVGG_Original(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ model_diff = SpikingVGG_DiffOnly(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ # Copy weights from original to diff-only
+ model_diff.load_state_dict(model_orig.state_dict())
+
+ print(f"Parameters: {count_parameters(model_orig):,}")
+
+ # Create input
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 100, (batch_size,), device=device)
+ criterion = nn.CrossEntropyLoss()
+
+ # ============================================================
+ # Test 1: Verify outputs match
+ # ============================================================
+ print("\n--- Output Verification ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5)
+ lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1 # Allow some difference due to different implementations
+
+ print(f"Logits match: {logits_match}")
+ print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}")
+ print(f"Lyapunov close (within 0.1): {lyap_close}")
+
+ # ============================================================
+ # Test 2: Forward-only speed (no grad)
+ # ============================================================
+ print("\n--- Forward Speed (no_grad) ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ # Without Lyapunov
+ times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False)
+ times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False)
+
+ mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000
+ mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mean_orig:.2f} ms")
+ print(f" Diff-only: {mean_diff:.2f} ms")
+
+ # With Lyapunov
+ times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True)
+ times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True)
+
+ mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000
+ mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000
+ speedup = mean_orig_ly / mean_diff_ly
+
+ print(f" With Lyapunov:")
+ print(f" Original: {mean_orig_ly:.2f} ms")
+ print(f" Diff-only: {mean_diff_ly:.2f} ms")
+ print(f" Speedup: {speedup:.2f}x")
+
+ # ============================================================
+ # Test 3: Forward + Backward speed (training mode)
+ # ============================================================
+ print("\n--- Forward+Backward Speed (training) ---")
+ model_orig.train()
+ model_diff.train()
+
+ times_orig_train = benchmark_forward_backward(
+ model_orig, x, y, criterion, compute_lyapunov=True
+ )
+ times_diff_train = benchmark_forward_backward(
+ model_diff, x, y, criterion, compute_lyapunov=True
+ )
+
+ mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000
+ mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000
+ speedup_train = mean_orig_train / mean_diff_train
+
+ print(f" With Lyapunov + backward:")
+ print(f" Original: {mean_orig_train:.2f} ms")
+ print(f" Diff-only: {mean_diff_train:.2f} ms")
+ print(f" Speedup: {speedup_train:.2f}x")
+
+ # ============================================================
+ # Test 4: Memory usage
+ # ============================================================
+ if device.type == "cuda":
+ print("\n--- Peak GPU Memory ---")
+
+ mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False)
+ mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False)
+
+ mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True)
+ mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True)
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mem_orig_noly:.1f} MB")
+ print(f" Diff-only: {mem_diff_noly:.1f} MB")
+ print(f" With Lyapunov:")
+ print(f" Original: {mem_orig_ly:.1f} MB")
+ print(f" Diff-only: {mem_diff_ly:.1f} MB")
+ print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)")
+
+ # Cleanup
+ del model_orig, model_diff, x, y
+ torch.cuda.empty_cache()
+
+ print("\n" + "=" * 70)
+ print("BENCHMARK COMPLETE")
+ print("=" * 70)
+
+
+if __name__ == "__main__":
+ run_benchmark()
diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py
new file mode 100644
index 0000000..117009b
--- /dev/null
+++ b/files/experiments/lyapunov_speedup_benchmark.py
@@ -0,0 +1,638 @@
+"""
+Lyapunov Computation Speedup Benchmark
+
+Tests different optimization approaches for computing Lyapunov exponents
+during SNN training. All approaches should produce equivalent results
+(within numerical precision) but with different performance characteristics.
+
+Approaches tested:
+- Baseline: Current sequential implementation
+- Approach A: Trajectory-as-batch (P=2), share first Linear
+- Approach B: Global-norm divergence + single-scale renorm
+- Approach C: torch.compile the time loop
+- Combined: A + B + C together
+"""
+
+import os
+import sys
+import time
+from typing import Tuple, Optional, List
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+# Ensure we can import from project
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+
+# =============================================================================
+# Baseline Implementation (Current)
+# =============================================================================
+
+class BaselineSNN(nn.Module):
+ """Current implementation: sequential perturbed trajectory."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ # Simple feedforward for benchmarking (not full VGG)
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims # Flattened input
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1) # Flatten
+
+ # Init membrane potentials
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory (SEPARATE PASS - this is slow)
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # Divergence (per-layer norms, then sum)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize (per-layer - SLOW)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A: Trajectory-as-batch (P=2), share first Linear
+# =============================================================================
+
+class ApproachA_SNN(nn.Module):
+ """Batch both trajectories together, share Linear_1."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State layout: (P, B, H) where P=2 for [original, perturbed]
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ # Initialize perturbed state
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: compute Linear ONCE, expand to (P, B, H1)
+ h1 = self.linears[0](x) # (B, H1) - computed ONCE
+
+ if compute_lyapunov:
+ h = h1.unsqueeze(0).expand(P, -1, -1) # (P, B, H1) - zero-copy view
+ else:
+ h = h1.unsqueeze(0) # (1, B, H1)
+
+ # LIF layer 1
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+: different inputs for each trajectory
+ for i in range(1, len(self.hidden_dims)):
+ # Reshape to (P*B, H) for batched Linear
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ # Accumulate spikes from original trajectory only
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence across all layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0] # (B, H_i)
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # Renormalize with global scale (per-layer still, but simpler)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ mems[i][1] = mems[i][0] + lyap_eps * diff / norm
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach B: Global-norm divergence + single-scale renorm
+# =============================================================================
+
+class ApproachB_SNN(nn.Module):
+ """Global norm for divergence, single scale factor for renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # GLOBAL divergence (one delta per batch element)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # SINGLE SCALE renormalization (key optimization)
+ scale = (lyap_eps / delta).unsqueeze(-1) # (B, 1)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ new_mems_p[i] = new_mems[i] + diff * scale
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A+B Combined: Batched trajectories + global renorm
+# =============================================================================
+
+class ApproachAB_SNN(nn.Module):
+ """Combined: trajectory-as-batch + global-norm renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State: (P, B, H)
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: Linear computed ONCE
+ h1 = self.linears[0](x)
+ h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)
+
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+
+ for i in range(1, len(self.hidden_dims)):
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # Global scale renorm
+ scale = (lyap_eps / delta).unsqueeze(-1)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ mems[i][1] = mems[i][0] + diff * scale
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach C: torch.compile wrapper
+# =============================================================================
+
+def make_compiled_model(model_class, *args, **kwargs):
+ """Create a model and compile its forward pass."""
+ model = model_class(*args, **kwargs)
+ # Compile the forward method
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ return model
+
+
+# =============================================================================
+# Benchmarking
+# =============================================================================
+
+@dataclass
+class BenchmarkResult:
+ name: str
+ forward_time_ms: float
+ backward_time_ms: float
+ total_time_ms: float
+ lyap_value: float
+ memory_mb: float
+
+ def __str__(self):
+ return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
+ f"Bwd: {self.backward_time_ms:7.2f}ms | "
+ f"Total: {self.total_time_ms:7.2f}ms | "
+ f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")
+
+
+def benchmark_model(
+ model: nn.Module,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ name: str,
+ warmup_iters: int = 5,
+ bench_iters: int = 20,
+) -> BenchmarkResult:
+ """Benchmark a single model configuration."""
+
+ device = x.device
+ criterion = nn.CrossEntropyLoss()
+
+ # Warmup
+ for _ in range(warmup_iters):
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+ loss.backward()
+ model.zero_grad()
+
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ fwd_times = []
+ bwd_times = []
+ lyap_vals = []
+
+ for _ in range(bench_iters):
+ # Forward
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+
+ torch.cuda.synchronize()
+ t1 = time.perf_counter()
+
+ # Backward
+ loss.backward()
+
+ torch.cuda.synchronize()
+ t2 = time.perf_counter()
+
+ fwd_times.append((t1 - t0) * 1000)
+ bwd_times.append((t2 - t1) * 1000)
+ if lyap is not None:
+ lyap_vals.append(lyap.item())
+
+ model.zero_grad()
+
+ peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
+
+ return BenchmarkResult(
+ name=name,
+ forward_time_ms=sum(fwd_times) / len(fwd_times),
+ backward_time_ms=sum(bwd_times) / len(bwd_times),
+ total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times),
+ lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
+ memory_mb=peak_mem,
+ )
+
+
+def run_benchmarks(
+ batch_size: int = 64,
+ T: int = 4,
+ hidden_dims: List[int] = [64, 128, 256],
+ device: str = "cuda",
+):
+ """Run all benchmarks and compare."""
+
+ print("=" * 80)
+ print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK")
+ print("=" * 80)
+ print(f"Batch size: {batch_size}")
+ print(f"Timesteps: {T}")
+ print(f"Hidden dims: {hidden_dims}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Create dummy data
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (batch_size,), device=device)
+
+ results = []
+
+ # 1. Baseline
+ print("\n[1/6] Benchmarking Baseline...")
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "Baseline"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 2. Approach A (batched trajectories)
+ print("[2/6] Benchmarking Approach A (batched)...")
+ model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A: Batched trajectories"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 3. Approach B (global renorm)
+ print("[3/6] Benchmarking Approach B (global renorm)...")
+ model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "B: Global renorm"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 4. Approach A+B combined
+ print("[4/6] Benchmarking Approach A+B (combined)...")
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A+B: Combined"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 5. Approach C (torch.compile on baseline)
+ print("[5/6] Benchmarking Approach C (compiled baseline)...")
+ try:
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0))
+
+ # 6. A+B+C (all combined)
+ print("[6/6] Benchmarking A+B+C (all optimizations)...")
+ try:
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0))
+
+ # Print results
+ print("\n" + "=" * 80)
+ print("RESULTS")
+ print("=" * 80)
+
+ baseline_time = results[0].total_time_ms
+
+ for r in results:
+ print(r)
+
+ print("\n" + "-" * 80)
+ print("SPEEDUP vs BASELINE:")
+ print("-" * 80)
+
+ for r in results[1:]:
+ if r.total_time_ms > 0:
+ speedup = baseline_time / r.total_time_ms
+ print(f" {r.name:<25}: {speedup:.2f}x")
+
+ # Verify Lyapunov values are consistent
+ print("\n" + "-" * 80)
+ print("LYAPUNOV VALUE CONSISTENCY CHECK:")
+ print("-" * 80)
+
+ base_lyap = results[0].lyap_value
+ for r in results[1:]:
+ if r.lyap_value != 0:
+ diff = abs(r.lyap_value - base_lyap)
+ status = "✓" if diff < 0.1 else "✗"
+ print(f" {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}")
+
+ return results
+
+
+def run_scaling_test(device: str = "cuda"):
+ """Test how approaches scale with batch size and timesteps."""
+
+ print("\n" + "=" * 80)
+ print("SCALING TESTS")
+ print("=" * 80)
+
+ configs = [
+ {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]}, # Larger model
+ ]
+
+ print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
+ print("-" * 80)
+
+ for cfg in configs:
+ x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (cfg["batch_size"],), device=device)
+
+ # Baseline
+ model_base = BaselineSNN(**cfg).to(device)
+ r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
+ del model_base
+
+ # A+B
+ model_ab = ApproachAB_SNN(**cfg).to(device)
+ r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
+ del model_ab
+
+ torch.cuda.empty_cache()
+
+ speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0
+
+ cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
+ print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--T", type=int, default=4)
+ parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256])
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--scaling", action="store_true", help="Run scaling tests")
+ args = parser.parse_args()
+
+ if not torch.cuda.is_available():
+ print("CUDA not available, using CPU (results will not be representative)")
+ args.device = "cpu"
+
+ # Main benchmark
+ results = run_benchmarks(
+ batch_size=args.batch_size,
+ T=args.T,
+ hidden_dims=args.hidden_dims,
+ device=args.device,
+ )
+
+ # Scaling tests
+ if args.scaling:
+ run_scaling_test(args.device)
diff --git a/files/experiments/plot_depth_comparison.py b/files/experiments/plot_depth_comparison.py
new file mode 100644
index 0000000..2222b7b
--- /dev/null
+++ b/files/experiments/plot_depth_comparison.py
@@ -0,0 +1,305 @@
+"""
+Visualization for depth comparison experiments.
+
+Usage:
+ python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP
+"""
+
+import os
+import sys
+import json
+import argparse
+from typing import Dict, List
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+
+def load_results(results_dir: str) -> Dict:
+ """Load results from JSON file."""
+ with open(os.path.join(results_dir, "results.json"), "r") as f:
+ return json.load(f)
+
+
+def load_config(results_dir: str) -> Dict:
+ """Load config from JSON file."""
+ config_path = os.path.join(results_dir, "config.json")
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ return json.load(f)
+ return {}
+
+
+def plot_training_curves(results: Dict, output_path: str):
+ """
+ Plot training curves for each depth.
+
+ Creates a figure with subplots for each depth showing:
+ - Training loss
+ - Validation accuracy
+ - Lyapunov exponent (if available)
+ - Gradient norm
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+ n_depths = len(depths)
+
+ fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths))
+ if n_depths == 1:
+ axes = axes.reshape(1, -1)
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+ labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"}
+
+ for i, depth in enumerate(depths):
+ for method in ["vanilla", "lyapunov"]:
+ metrics = results[method][str(depth)]
+ epochs = [m["epoch"] for m in metrics]
+
+ # Training Loss
+ train_loss = [m["train_loss"] for m in metrics]
+ axes[i, 0].plot(epochs, train_loss, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 0].set_ylabel("Train Loss")
+ axes[i, 0].set_title(f"Depth={depth}: Training Loss")
+ axes[i, 0].set_yscale("log")
+ axes[i, 0].grid(True, alpha=0.3)
+
+ # Validation Accuracy
+ val_acc = [m["val_acc"] for m in metrics]
+ axes[i, 1].plot(epochs, val_acc, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 1].set_ylabel("Val Accuracy")
+ axes[i, 1].set_title(f"Depth={depth}: Validation Accuracy")
+ axes[i, 1].set_ylim(0, 1)
+ axes[i, 1].grid(True, alpha=0.3)
+
+ # Lyapunov Exponent
+ lyap = [m["lyapunov"] for m in metrics if m["lyapunov"] is not None]
+ lyap_epochs = [m["epoch"] for m in metrics if m["lyapunov"] is not None]
+ if lyap:
+ axes[i, 2].plot(lyap_epochs, lyap, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ axes[i, 2].set_ylabel("Lyapunov λ")
+ axes[i, 2].set_title(f"Depth={depth}: Lyapunov Exponent")
+ axes[i, 2].grid(True, alpha=0.3)
+
+ # Gradient Norm
+ grad_norm = [m["grad_norm"] for m in metrics]
+ axes[i, 3].plot(epochs, grad_norm, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 3].set_ylabel("Gradient Norm")
+ axes[i, 3].set_title(f"Depth={depth}: Gradient Norm")
+ axes[i, 3].set_yscale("log")
+ axes[i, 3].grid(True, alpha=0.3)
+
+ # Add legend to first row
+ if i == 0:
+ for ax in axes[i]:
+ ax.legend(loc="upper right")
+
+ # Set x-labels on bottom row
+ for ax in axes[-1]:
+ ax.set_xlabel("Epoch")
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved training curves to {output_path}")
+
+
+def plot_depth_summary(results: Dict, output_path: str):
+ """
+ Plot summary comparing methods across depths.
+
+ Creates a figure showing:
+ - Final validation accuracy vs depth
+ - Final gradient norm vs depth
+ - Final Lyapunov exponent vs depth
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+
+ fig, axes = plt.subplots(1, 3, figsize=(14, 4))
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+ markers = {"vanilla": "o", "lyapunov": "s"}
+
+ # Collect final metrics
+ van_acc = []
+ lyap_acc = []
+ van_grad = []
+ lyap_grad = []
+ lyap_lambda = []
+
+ for depth in depths:
+ van_metrics = results["vanilla"][str(depth)][-1]
+ lyap_metrics = results["lyapunov"][str(depth)][-1]
+
+ van_acc.append(van_metrics["val_acc"] if not np.isnan(van_metrics["val_acc"]) else 0)
+ lyap_acc.append(lyap_metrics["val_acc"] if not np.isnan(lyap_metrics["val_acc"]) else 0)
+
+ van_grad.append(van_metrics["grad_norm"] if not np.isnan(van_metrics["grad_norm"]) else 0)
+ lyap_grad.append(lyap_metrics["grad_norm"] if not np.isnan(lyap_metrics["grad_norm"]) else 0)
+
+ if lyap_metrics["lyapunov"] is not None:
+ lyap_lambda.append(lyap_metrics["lyapunov"])
+ else:
+ lyap_lambda.append(0)
+
+ # Plot 1: Validation Accuracy vs Depth
+ ax = axes[0]
+ ax.plot(depths, van_acc, 'o-', color=colors["vanilla"],
+ label="Vanilla", linewidth=2, markersize=8)
+ ax.plot(depths, lyap_acc, 's-', color=colors["lyapunov"],
+ label="Lyapunov", linewidth=2, markersize=8)
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Validation Accuracy")
+ ax.set_title("Accuracy vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05)
+
+ # Plot 2: Gradient Norm vs Depth
+ ax = axes[1]
+ ax.plot(depths, van_grad, 'o-', color=colors["vanilla"],
+ label="Vanilla", linewidth=2, markersize=8)
+ ax.plot(depths, lyap_grad, 's-', color=colors["lyapunov"],
+ label="Lyapunov", linewidth=2, markersize=8)
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Gradient Norm")
+ ax.set_title("Gradient Stability vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_yscale("log")
+
+ # Plot 3: Lyapunov Exponent vs Depth
+ ax = axes[2]
+ ax.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"],
+ linewidth=2, markersize=8)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)")
+ ax.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region")
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Lyapunov Exponent")
+ ax.set_title("Lyapunov Exponent vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved depth summary to {output_path}")
+
+
+def plot_stability_comparison(results: Dict, output_path: str):
+ """
+ Plot stability metrics comparison.
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+
+ # Collect metrics over training
+ for depth in depths:
+ van_metrics = results["vanilla"][str(depth)]
+ lyap_metrics = results["lyapunov"][str(depth)]
+
+ van_epochs = [m["epoch"] for m in van_metrics]
+ lyap_epochs = [m["epoch"] for m in lyap_metrics]
+
+ # Firing rate
+ van_fr = [m["firing_rate"] for m in van_metrics]
+ lyap_fr = [m["firing_rate"] for m in lyap_metrics]
+ axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+ axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+
+ # Dead neurons
+ van_dead = [m["dead_neurons"] for m in van_metrics]
+ lyap_dead = [m["dead_neurons"] for m in lyap_metrics]
+ axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+ axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+
+ axes[0, 0].set_xlabel("Epoch")
+ axes[0, 0].set_ylabel("Firing Rate")
+ axes[0, 0].set_title("Firing Rate Over Training")
+ axes[0, 0].grid(True, alpha=0.3)
+
+ axes[0, 1].set_xlabel("Epoch")
+ axes[0, 1].set_ylabel("Dead Neuron Fraction")
+ axes[0, 1].set_title("Dead Neurons Over Training")
+ axes[0, 1].grid(True, alpha=0.3)
+
+ # Final metrics bar chart
+ van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths]
+ lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths]
+
+ x = np.arange(len(depths))
+ width = 0.35
+
+ axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"])
+ axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"])
+ axes[1, 0].set_xlabel("Network Depth")
+ axes[1, 0].set_ylabel("Final Validation Accuracy")
+ axes[1, 0].set_title("Final Accuracy Comparison")
+ axes[1, 0].set_xticks(x)
+ axes[1, 0].set_xticklabels(depths)
+ axes[1, 0].legend()
+ axes[1, 0].grid(True, alpha=0.3, axis='y')
+
+ # Improvement percentage
+ improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)]
+ colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]
+
+ axes[1, 1].bar(x, improvements, color=colors_bar)
+ axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5)
+ axes[1, 1].set_xlabel("Network Depth")
+ axes[1, 1].set_ylabel("Accuracy Improvement")
+ axes[1, 1].set_title("Lyapunov Improvement over Vanilla")
+ axes[1, 1].set_xticks(x)
+ axes[1, 1].set_xticklabels(depths)
+ axes[1, 1].grid(True, alpha=0.3, axis='y')
+
+ # Add legend for line plots
+ custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2),
+ Line2D([0], [0], color=colors["lyapunov"], lw=2)]
+ axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov'])
+ axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov'])
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved stability comparison to {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--results_dir", type=str, required=True,
+ help="Directory containing results.json")
+ parser.add_argument("--output_dir", type=str, default=None,
+ help="Output directory for plots (default: same as results_dir)")
+ args = parser.parse_args()
+
+ output_dir = args.output_dir or args.results_dir
+
+ print(f"Loading results from {args.results_dir}")
+ results = load_results(args.results_dir)
+ config = load_config(args.results_dir)
+
+ print(f"Config: {config}")
+
+ # Generate plots
+ plot_training_curves(results, os.path.join(output_dir, "training_curves.png"))
+ plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png"))
+ plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png"))
+
+ print(f"\nAll plots saved to {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/posthoc_finetune.py b/files/experiments/posthoc_finetune.py
new file mode 100644
index 0000000..3f3bf6c
--- /dev/null
+++ b/files/experiments/posthoc_finetune.py
@@ -0,0 +1,323 @@
+"""
+Post-hoc Lyapunov Fine-tuning Experiment
+
+Strategy:
+1. Train network with vanilla (no Lyapunov) for N epochs
+2. Then fine-tune with Lyapunov regularization for M epochs
+
+This allows the network to learn task-relevant features first,
+then stabilize dynamics without starting from chaotic initialization.
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.experiments.depth_scaling_benchmark import (
+ SpikingVGG,
+ get_dataset,
+ train_epoch,
+ evaluate,
+ TrainingMetrics,
+ compute_lyap_reg_loss,
+)
+
+
+def run_posthoc_experiment(
+ dataset_name: str,
+ depth_config: Tuple[int, int],
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ pretrain_epochs: int,
+ finetune_epochs: int,
+ lr: float,
+ finetune_lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ reg_type: str = "extreme",
+ lyap_threshold: float = 2.0,
+ progress: bool = True,
+) -> Dict:
+ """Run post-hoc fine-tuning experiment."""
+ torch.manual_seed(seed)
+
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ print(f"\n{'='*60}")
+ print(f"POST-HOC FINE-TUNING: Depth = {total_depth}")
+ print(f"Pretrain: {pretrain_epochs} epochs (vanilla)")
+ print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})")
+ print(f"{'='*60}")
+
+ model = SpikingVGG(
+ in_channels=in_channels,
+ num_classes=num_classes,
+ base_channels=64,
+ num_stages=num_stages,
+ blocks_per_stage=blocks_per_stage,
+ T=T,
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Parameters: {num_params:,}")
+
+ criterion = nn.CrossEntropyLoss()
+
+ # Phase 1: Vanilla pre-training
+ print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---")
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)
+
+ pretrain_history = []
+ best_pretrain_acc = 0.0
+
+ for epoch in range(1, pretrain_epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov=False, # No Lyapunov during pre-training
+ lambda_reg=0, lambda_target=0, lyap_eps=1e-4,
+ progress=progress,
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_pretrain_acc = max(best_pretrain_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ pretrain_history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == pretrain_epochs:
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}")
+
+ print(f" Best pretrain acc: {best_pretrain_acc:.3f}")
+
+ # Phase 2: Lyapunov fine-tuning
+ print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---")
+ print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}")
+
+ # Reset optimizer with lower learning rate for fine-tuning
+ optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)
+
+ finetune_history = []
+ best_finetune_acc = 0.0
+
+ for epoch in range(1, finetune_epochs + 1):
+ t0 = time.time()
+
+ # Warmup lambda_reg over first 10 epochs of fine-tuning
+ warmup_epochs = 10
+ if epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov=True,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=1e-4,
+ progress=progress,
+ reg_type=reg_type,
+ current_lambda_reg=current_lambda_reg,
+ lyap_threshold=lyap_threshold,
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_finetune_acc = max(best_finetune_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=pretrain_epochs + epoch, # Continue epoch numbering
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ finetune_history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == finetune_epochs:
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ print(f" Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")
+
+ if np.isnan(train_loss):
+ print(f" DIVERGED at epoch {epoch}")
+ break
+
+ print(f" Best finetune acc: {best_finetune_acc:.3f}")
+ print(f" Final λ: {finetune_history[-1].lyapunov:.3f}" if finetune_history[-1].lyapunov else "")
+
+ return {
+ "depth": total_depth,
+ "pretrain_history": pretrain_history,
+ "finetune_history": finetune_history,
+ "best_pretrain_acc": best_pretrain_acc,
+ "best_finetune_acc": best_finetune_acc,
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning")
+ parser.add_argument("--dataset", type=str, default="cifar100",
+ choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
+ parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16])
+ parser.add_argument("--T", type=int, default=4)
+ parser.add_argument("--pretrain_epochs", type=int, default=100)
+ parser.add_argument("--finetune_epochs", type=int, default=50)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--lr", type=float, default=1e-3)
+ parser.add_argument("--finetune_lr", type=float, default=1e-4)
+ parser.add_argument("--lambda_reg", type=float, default=0.1)
+ parser.add_argument("--lambda_target", type=float, default=-0.1)
+ parser.add_argument("--reg_type", type=str, default="extreme")
+ parser.add_argument("--lyap_threshold", type=float, default=2.0)
+ parser.add_argument("--data_dir", type=str, default="./data")
+ parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--no-progress", action="store_true")
+
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT")
+ print("=" * 80)
+ print(f"Dataset: {args.dataset}")
+ print(f"Depths: {args.depths}")
+ print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})")
+ print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})")
+ print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}")
+ print("=" * 80)
+
+ # Load data
+ train_loader, test_loader, num_classes, input_shape = get_dataset(
+ args.dataset, args.data_dir, args.batch_size
+ )
+ in_channels = input_shape[0]
+
+ # Convert depths to configs
+ depth_configs = []
+ for d in args.depths:
+ if d <= 4:
+ depth_configs.append((d, 1))
+ else:
+ depth_configs.append((4, d // 4))
+
+ # Run experiments
+ all_results = []
+ for depth_config in depth_configs:
+ result = run_posthoc_experiment(
+ dataset_name=args.dataset,
+ depth_config=depth_config,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=args.T,
+ pretrain_epochs=args.pretrain_epochs,
+ finetune_epochs=args.finetune_epochs,
+ lr=args.lr,
+ finetune_lr=args.finetune_lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ reg_type=args.reg_type,
+ lyap_threshold=args.lyap_threshold,
+ progress=not args.no_progress,
+ )
+ all_results.append(result)
+
+ # Summary
+ print("\n" + "=" * 80)
+ print("SUMMARY")
+ print("=" * 80)
+ print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}")
+ print("-" * 80)
+
+ for r in all_results:
+ pre_acc = r["best_pretrain_acc"]
+ fine_acc = r["best_finetune_acc"]
+ change = fine_acc - pre_acc
+ final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None
+ lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A"
+ change_str = f"{change:+.3f}"
+
+ print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}")
+
+ print("=" * 80)
+
+ # Save results
+ os.makedirs(args.out_dir, exist_ok=True)
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json")
+
+ serializable_results = []
+ for r in all_results:
+ sr = {
+ "depth": r["depth"],
+ "best_pretrain_acc": r["best_pretrain_acc"],
+ "best_finetune_acc": r["best_finetune_acc"],
+ "pretrain_history": [asdict(m) for m in r["pretrain_history"]],
+ "finetune_history": [asdict(m) for m in r["finetune_history"]],
+ }
+ serializable_results.append(sr)
+
+ with open(output_file, "w") as f:
+ json.dump({"config": vars(args), "results": serializable_results}, f, indent=2)
+
+ print(f"\nResults saved to {output_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/scaled_reg_grid_search.py b/files/experiments/scaled_reg_grid_search.py
new file mode 100644
index 0000000..928caff
--- /dev/null
+++ b/files/experiments/scaled_reg_grid_search.py
@@ -0,0 +1,301 @@
+"""
+Grid Search: Multiplier-Scaled Regularization Experiments
+
+Tests the new multiplier-scaled regularization approach:
+ loss = (λ_reg × g(relu(λ))) × relu(λ)
+
+Where g(x) is the multiplier scaling function:
+ - mult_linear: g(x) = x → loss = λ_reg × relu(λ)²
+ - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³
+ - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ)
+
+Grid:
+ - λ_reg: 0.01, 0.05, 0.1, 0.3
+ - reg_type: mult_linear, mult_squared, mult_log
+ - depth: specified via command line
+
+Usage:
+ python scaled_reg_grid_search.py --depth 4
+ python scaled_reg_grid_search.py --depth 8
+ python scaled_reg_grid_search.py --depth 12
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+# Import from main benchmark
+from depth_scaling_benchmark import (
+ SpikingVGG,
+ compute_lyap_reg_loss,
+)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+@dataclass
+class ExperimentResult:
+ depth: int
+ reg_type: str
+ lambda_reg: float
+ vanilla_acc: float
+ lyapunov_acc: float
+ final_lyap: Optional[float]
+ delta: float
+
+
+def train_epoch(model, loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, reg_type, progress=False):
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4)
+ loss = criterion(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ # Target is implicitly 0 for scaled reg types
+ lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0)
+ loss = loss + lambda_reg * lyap_reg
+ lyap_vals.append(lyap_est.item())
+
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return total_loss / total, correct / total, avg_lyap
+
+
+def evaluate(model, loader, device):
+ """Evaluate model."""
+ model.eval()
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ return correct / total
+
+
+def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader,
+ device, epochs=100, lr=0.001):
+ """Run a single experiment configuration."""
+
+ # Determine blocks per stage based on depth
+ # depth = num_stages * blocks_per_stage, with num_stages=4
+ blocks_per_stage = depth // 4
+
+ print(f"\n{'='*60}")
+ print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}")
+ print(f"{'='*60}")
+
+ # --- Run Vanilla baseline ---
+ print(f" Training Vanilla...")
+ model_v = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_v = optim.Adam(model_v.parameters(), lr=lr)
+ criterion = nn.CrossEntropyLoss()
+ scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs)
+
+ best_vanilla = 0.0
+ for epoch in range(epochs):
+ train_epoch(model_v, train_loader, optimizer_v, criterion, device,
+ use_lyapunov=False, lambda_reg=0, reg_type="squared")
+ scheduler_v.step()
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_v, test_loader, device)
+ best_vanilla = max(best_vanilla, acc)
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f}")
+
+ del model_v, optimizer_v, scheduler_v
+ torch.cuda.empty_cache()
+
+ # --- Run Lyapunov version ---
+ print(f" Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...")
+ model_l = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_l = optim.Adam(model_l.parameters(), lr=lr)
+ scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs)
+
+ best_lyap_acc = 0.0
+ final_lyap = None
+
+ for epoch in range(epochs):
+ _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device,
+ use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type)
+ scheduler_l.step()
+ final_lyap = lyap
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_l, test_loader, device)
+ best_lyap_acc = max(best_lyap_acc, acc)
+ lyap_str = f"λ={lyap:.3f}" if lyap else "λ=N/A"
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}")
+
+ del model_l, optimizer_l, scheduler_l
+ torch.cuda.empty_cache()
+
+ delta = best_lyap_acc - best_vanilla
+
+ result = ExperimentResult(
+ depth=depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ vanilla_acc=best_vanilla,
+ lyapunov_acc=best_lyap_acc,
+ final_lyap=final_lyap,
+ delta=delta,
+ )
+
+ print(f" Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}")
+
+ return result
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12])
+ parser.add_argument("--epochs", type=int, default=100)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--lr", type=float, default=0.001)
+ parser.add_argument("--data_dir", type=str, default="./data")
+ parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid")
+ args = parser.parse_args()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ print("=" * 70)
+ print("SCALED REGULARIZATION GRID SEARCH")
+ print("=" * 70)
+ print(f"Depth: {args.depth}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+ print("=" * 70)
+
+ # Grid parameters
+ lambda_regs = [0.0005, 0.001, 0.002, 0.005] # smaller values for deeper networks
+ reg_types = ["mult_linear", "mult_log"] # mult_squared too aggressive, kills learning
+
+ print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments")
+ print(f"λ_reg values: {lambda_regs}")
+ print(f"reg_types: {reg_types}")
+
+ # Load data
+ print(f"\nLoading CIFAR-100...")
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+
+ train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train)
+ test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test)
+
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=4, pin_memory=True)
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=4, pin_memory=True)
+
+ print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
+
+ # Run grid search
+ results = []
+
+ for lambda_reg, reg_type in product(lambda_regs, reg_types):
+ result = run_single_experiment(
+ depth=args.depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ device=device,
+ epochs=args.epochs,
+ lr=args.lr,
+ )
+ results.append(result)
+
+ # Print summary table
+ print("\n" + "=" * 70)
+ print(f"SUMMARY: DEPTH = {args.depth}")
+ print("=" * 70)
+ print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
+ print("-" * 70)
+
+ for r in results:
+ lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap else "N/A"
+ delta_str = f"{r.delta:+.3f}"
+ print(f"{r.reg_type:<16} {r.lambda_reg:>8.2f} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")
+
+ # Find best configuration
+ best = max(results, key=lambda x: x.lyapunov_acc)
+ print("-" * 70)
+ print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")
+
+ # Save results
+ os.makedirs(args.out_dir, exist_ok=True)
+ out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
+ with open(out_file, "w") as f:
+ json.dump([asdict(r) for r in results], f, indent=2)
+ print(f"\nResults saved to: {out_file}")
+
+ print("\n" + "=" * 70)
+ print("GRID SEARCH COMPLETE")
+ print("=" * 70)
+
+
+if __name__ == "__main__":
+ main()