summaryrefslogtreecommitdiff
path: root/files
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
commit00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files
parentc53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parentcd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files')
-rw-r--r--files/__init__.py0
-rw-r--r--files/analysis/plot_mvp.py139
-rw-r--r--files/analysis/stability_monitor.py395
-rw-r--r--files/data_io/__init__.py3
-rw-r--r--files/data_io/benchmark_datasets.py360
-rw-r--r--files/data_io/configs/__init__.py0
-rw-r--r--files/data_io/configs/dvs.yaml0
-rw-r--r--files/data_io/configs/shd.yaml11
-rw-r--r--files/data_io/configs/ssc.yaml0
-rw-r--r--files/data_io/dataset_loader.py281
-rw-r--r--files/data_io/encoders/__init__.py1
-rw-r--r--files/data_io/encoders/base_encoder.py10
-rw-r--r--files/data_io/encoders/latency_encoder.py13
-rw-r--r--files/data_io/encoders/poisson_encoder.py91
-rw-r--r--files/data_io/encoders/rank_order_encoder.py10
-rw-r--r--files/data_io/transforms/__init__.py1
-rw-r--r--files/data_io/transforms/normalization.py52
-rw-r--r--files/data_io/transforms/spike_augmentation.py10
-rw-r--r--files/data_io/utils/__init__.py1
-rw-r--r--files/data_io/utils/file_utils.py15
-rw-r--r--files/data_io/utils/spike_tools.py10
-rw-r--r--files/data_io/utils/visualize.py19
-rw-r--r--files/experiments/benchmark_experiment.py518
-rw-r--r--files/experiments/cifar10_conv_experiment.py448
-rw-r--r--files/experiments/depth_comparison.py542
-rw-r--r--files/experiments/depth_scaling_benchmark.py1035
-rw-r--r--files/experiments/hyperparameter_grid_search.py597
-rw-r--r--files/experiments/lyapunov_diffonly_benchmark.py590
-rw-r--r--files/experiments/lyapunov_speedup_benchmark.py638
-rw-r--r--files/experiments/plot_depth_comparison.py305
-rw-r--r--files/experiments/posthoc_finetune.py323
-rw-r--r--files/experiments/scaled_reg_grid_search.py301
-rw-r--r--files/models/conv_snn.py483
-rw-r--r--files/models/snn.py141
-rw-r--r--files/models/snn_snntorch.py398
-rw-r--r--files/tests/conftest.py11
-rw-r--r--files/tests/test_data_io.py11
-rw-r--r--files/tests/test_shd_loader_properties.py40
-rw-r--r--files/tests/test_train_smoke.py33
-rw-r--r--files/tests/test_transforms_normalize.py53
-rw-r--r--files/train_mvp.py202
-rw-r--r--files/train_snntorch.py345
42 files changed, 8436 insertions, 0 deletions
diff --git a/files/__init__.py b/files/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/__init__.py
diff --git a/files/analysis/plot_mvp.py b/files/analysis/plot_mvp.py
new file mode 100644
index 0000000..495a9c2
--- /dev/null
+++ b/files/analysis/plot_mvp.py
@@ -0,0 +1,139 @@
+import argparse
+import csv
+import json
+import os
+from glob import glob
+from typing import Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+
+
+def load_run(run_dir: str) -> Dict:
+ args_path = os.path.join(run_dir, "args.json")
+ metrics_path = os.path.join(run_dir, "metrics.csv")
+ if not (os.path.exists(args_path) and os.path.exists(metrics_path)):
+ return {}
+ with open(args_path, "r") as f:
+ args = json.load(f)
+ epochs = []
+ loss = []
+ acc = []
+ lyap = []
+ with open(metrics_path, "r") as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ if row.get("step", "") != "epoch":
+ continue
+ try:
+ e = int(row.get("epoch", "0"))
+ l = float(row.get("loss", "nan"))
+ a = float(row.get("acc", "nan"))
+ y = row.get("lyap", "nan")
+ y = float("nan") if (y is None or y == "" or str(y).lower() == "nan") else float(y)
+ except Exception:
+ continue
+ epochs.append(e)
+ loss.append(l)
+ acc.append(a)
+ lyap.append(y)
+ if not epochs:
+ return {}
+ return {
+ "args": args,
+ "epochs": epochs,
+ "loss": loss,
+ "acc": acc,
+ "lyap": lyap,
+ "run_dir": run_dir,
+ }
+
+
+def label_for_run(run: Dict) -> str:
+ args = run["args"]
+ if args.get("lyapunov", False):
+ lam = args.get("lambda_reg", None)
+ tgt = args.get("lambda_target", None)
+ hid = args.get("hidden", None)
+ return f"Lyap λ={lam}, tgt={tgt}, H={hid}"
+ else:
+ hid = args.get("hidden", None)
+ return f"Baseline H={hid}"
+
+
+def gather_runs(base_dir: str) -> List[Dict]:
+ cand = sorted(glob(os.path.join(base_dir, "*")))
+ runs = []
+ for rd in cand:
+ data = load_run(rd)
+ if data:
+ runs.append(data)
+ return runs
+
+
+def plot_runs(runs: List[Dict], out_path: str):
+ if not runs:
+ raise SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)")
+ # Split baseline vs lyapunov
+ base = [r for r in runs if not r["args"].get("lyapunov", False)]
+ lyap = [r for r in runs if r["args"].get("lyapunov", False)]
+
+ fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
+
+ # Loss
+ ax = axes[0]
+ for r in base:
+ ax.plot(r["epochs"], r["loss"], label=label_for_run(r), alpha=0.9)
+ for r in lyap:
+ ax.plot(r["epochs"], r["loss"], label=label_for_run(r), linestyle="--", alpha=0.9)
+ ax.set_title("Training loss")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Loss")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+
+ # Accuracy
+ ax = axes[1]
+ for r in base:
+ ax.plot(r["epochs"], r["acc"], label=label_for_run(r), alpha=0.9)
+ for r in lyap:
+ ax.plot(r["epochs"], r["acc"], label=label_for_run(r), linestyle="--", alpha=0.9)
+ ax.set_title("Training accuracy")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Accuracy")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+
+ # Lyapunov estimate (only lyap runs)
+ ax = axes[2]
+ if lyap:
+ for r in lyap:
+ ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9)
+ ax.set_title("Surrogate Lyapunov estimate")
+ ax.set_xlabel("Epoch")
+ ax.set_ylabel("Avg log growth")
+ ax.grid(True, alpha=0.3)
+ ax.legend(fontsize=8)
+ else:
+ ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes)
+ ax.set_axis_off()
+
+ fig.tight_layout()
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ fig.savefig(out_path, dpi=150)
+ print(f"Saved figure to {out_path}")
+
+
+def main():
+ ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison")
+ ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders")
+ ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path")
+ args = ap.parse_args()
+
+ runs = gather_runs(args.base_dir)
+ plot_runs(runs, args.out)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/files/analysis/stability_monitor.py b/files/analysis/stability_monitor.py
new file mode 100644
index 0000000..18cad0f
--- /dev/null
+++ b/files/analysis/stability_monitor.py
@@ -0,0 +1,395 @@
+"""
+Stability monitoring utilities for SNN training.
+
+Provides metrics to diagnose training stability:
+- Lyapunov exponent (trajectory divergence)
+- Gradient norms (vanishing/exploding)
+- Firing rates (dead/saturated neurons)
+- Membrane potential statistics
+"""
+
+from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+@dataclass
+class StabilityMetrics:
+ """Container for stability measurements."""
+ lyapunov: Optional[float] = None
+ grad_norm: Optional[float] = None
+ grad_max_sv: Optional[float] = None # Max singular value of gradients
+ grad_min_sv: Optional[float] = None # Min singular value of gradients
+ grad_condition: Optional[float] = None # Condition number (max_sv / min_sv)
+ firing_rate_mean: Optional[float] = None
+ firing_rate_std: Optional[float] = None
+ dead_neuron_frac: Optional[float] = None
+ saturated_neuron_frac: Optional[float] = None
+ membrane_mean: Optional[float] = None
+ membrane_std: Optional[float] = None
+
+ def to_dict(self) -> Dict[str, float]:
+ return {k: v for k, v in self.__dict__.items() if v is not None}
+
+ def __str__(self) -> str:
+ parts = []
+ if self.lyapunov is not None:
+ parts.append(f"λ={self.lyapunov:.4f}")
+ if self.grad_norm is not None:
+ parts.append(f"∇={self.grad_norm:.4f}")
+ if self.firing_rate_mean is not None:
+ parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}")
+ if self.dead_neuron_frac is not None:
+ parts.append(f"dead={self.dead_neuron_frac:.1%}")
+ if self.saturated_neuron_frac is not None:
+ parts.append(f"sat={self.saturated_neuron_frac:.1%}")
+ return " | ".join(parts)
+
+
+class StabilityMonitor:
+ """
+ Monitor SNN stability during training.
+
+ Usage:
+ monitor = StabilityMonitor()
+
+ # During training
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss.backward()
+
+ metrics = monitor.compute(
+ model=model,
+ lyapunov=lyap,
+ spikes=spike_recordings, # optional
+ membrane=membrane_recordings, # optional
+ )
+ print(metrics)
+ """
+
+ def __init__(
+ self,
+ dead_threshold: float = 0.01,
+ saturated_threshold: float = 0.9,
+ history_size: int = 100,
+ ):
+ """
+ Args:
+ dead_threshold: Firing rate below this = dead neuron
+ saturated_threshold: Firing rate above this = saturated neuron
+ history_size: Number of batches to track for moving averages
+ """
+ self.dead_threshold = dead_threshold
+ self.saturated_threshold = saturated_threshold
+ self.history_size = history_size
+
+ # History tracking
+ self.lyap_history: List[float] = []
+ self.grad_history: List[float] = []
+ self.fr_history: List[float] = []
+
+ def compute_gradient_norm(self, model: nn.Module) -> float:
+ """Compute total gradient norm across all parameters."""
+ total_norm = 0.0
+ for p in model.parameters():
+ if p.grad is not None:
+ total_norm += p.grad.data.norm(2).item() ** 2
+ return total_norm ** 0.5
+
+ def compute_gradient_singular_values(
+ self,
+ model: nn.Module,
+ top_k: int = 5,
+ ) -> Dict[str, Tuple[np.ndarray, float]]:
+ """
+ Compute singular values of gradient matrices.
+
+ Per Gradient Flossing paper: singular value spectrum reveals
+ gradient pathologies (vanishing/exploding/rank collapse).
+
+ Args:
+ model: The SNN model
+ top_k: Number of top/bottom singular values to return
+
+ Returns:
+ Dict mapping layer name to (singular_values, condition_number)
+ """
+ results = {}
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ # Only compute for weight matrices (2D)
+ with torch.no_grad():
+ G = param.grad.detach().cpu()
+ try:
+ # Full SVD is expensive; use truncated for large matrices
+ if G.shape[0] * G.shape[1] > 1e6:
+ # For very large matrices, just compute extremes
+ U, S, V = torch.svd_lowrank(G, q=min(top_k, min(G.shape)))
+ sv = S.numpy()
+ else:
+ sv = torch.linalg.svdvals(G).numpy()
+
+ max_sv = sv[0] if len(sv) > 0 else 0
+ min_sv = sv[-1] if len(sv) > 0 else 0
+ condition = max_sv / (min_sv + 1e-12)
+
+ results[name] = (sv[:top_k], condition)
+ except Exception:
+ pass # Skip if SVD fails
+ return results
+
+ def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]:
+ """
+ Get aggregate gradient singular value statistics.
+
+ Returns:
+ (max_sv, min_sv, avg_condition_number) across all layers
+ """
+ sv_results = self.compute_gradient_singular_values(model)
+ if not sv_results:
+ return 0.0, 0.0, 1.0
+
+ max_svs = []
+ min_svs = []
+ conditions = []
+
+ for name, (sv, cond) in sv_results.items():
+ if len(sv) > 0:
+ max_svs.append(sv[0])
+ min_svs.append(sv[-1])
+ conditions.append(cond)
+
+ if not max_svs:
+ return 0.0, 0.0, 1.0
+
+ return (
+ float(np.max(max_svs)),
+ float(np.min(min_svs)),
+ float(np.mean(conditions))
+ )
+
+ def compute_firing_stats(
+ self,
+ spikes: torch.Tensor,
+ ) -> Tuple[float, float, float, float]:
+ """
+ Compute firing rate statistics.
+
+ Args:
+ spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H)
+ Values should be 0/1.
+
+ Returns:
+ (mean_rate, std_rate, dead_frac, saturated_frac)
+ """
+ with torch.no_grad():
+ # Flatten to (num_samples, num_neurons) if needed
+ if spikes.ndim == 3:
+ # (B, T, H) -> compute rate per neuron per sample
+ rates = spikes.float().mean(dim=1) # (B, H)
+ elif spikes.ndim == 2:
+ # Could be (T, H) or (B, H) - assume (T, H) for single sample
+ rates = spikes.float().mean(dim=0, keepdim=True) # (1, H)
+ else:
+ rates = spikes.float().unsqueeze(0)
+
+ # Per-neuron average rate across batch
+ neuron_rates = rates.mean(dim=0) # (H,)
+
+ mean_rate = neuron_rates.mean().item()
+ std_rate = neuron_rates.std().item()
+
+ dead_frac = (neuron_rates < self.dead_threshold).float().mean().item()
+ saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item()
+
+ return mean_rate, std_rate, dead_frac, saturated_frac
+
+ def compute_membrane_stats(
+ self,
+ membrane: torch.Tensor,
+ ) -> Tuple[float, float]:
+ """
+ Compute membrane potential statistics.
+
+ Args:
+ membrane: Membrane potential tensor, any shape
+
+ Returns:
+ (mean, std)
+ """
+ with torch.no_grad():
+ return membrane.mean().item(), membrane.std().item()
+
+ def compute(
+ self,
+ model: nn.Module,
+ lyapunov: Optional[torch.Tensor] = None,
+ spikes: Optional[torch.Tensor] = None,
+ membrane: Optional[torch.Tensor] = None,
+ compute_sv: bool = False,
+ ) -> StabilityMetrics:
+ """
+ Compute all available stability metrics.
+
+ Args:
+ model: The SNN model (for gradient norms)
+ lyapunov: Lyapunov exponent from forward pass
+ spikes: Recorded spikes (optional)
+ membrane: Recorded membrane potentials (optional)
+ compute_sv: Whether to compute gradient singular values (expensive)
+
+ Returns:
+ StabilityMetrics object
+ """
+ metrics = StabilityMetrics()
+
+ # Lyapunov exponent
+ if lyapunov is not None:
+ if isinstance(lyapunov, torch.Tensor):
+ lyapunov = lyapunov.item()
+ metrics.lyapunov = lyapunov
+ self.lyap_history.append(lyapunov)
+ if len(self.lyap_history) > self.history_size:
+ self.lyap_history.pop(0)
+
+ # Gradient norm
+ grad_norm = self.compute_gradient_norm(model)
+ metrics.grad_norm = grad_norm
+ self.grad_history.append(grad_norm)
+ if len(self.grad_history) > self.history_size:
+ self.grad_history.pop(0)
+
+ # Gradient singular values (optional, expensive)
+ if compute_sv:
+ max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model)
+ metrics.grad_max_sv = max_sv
+ metrics.grad_min_sv = min_sv
+ metrics.grad_condition = avg_cond
+
+ # Firing rate statistics
+ if spikes is not None:
+ fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes)
+ metrics.firing_rate_mean = fr_mean
+ metrics.firing_rate_std = fr_std
+ metrics.dead_neuron_frac = dead_frac
+ metrics.saturated_neuron_frac = sat_frac
+ self.fr_history.append(fr_mean)
+ if len(self.fr_history) > self.history_size:
+ self.fr_history.pop(0)
+
+ # Membrane potential statistics
+ if membrane is not None:
+ mem_mean, mem_std = self.compute_membrane_stats(membrane)
+ metrics.membrane_mean = mem_mean
+ metrics.membrane_std = mem_std
+
+ return metrics
+
+ def get_trends(self) -> Dict[str, str]:
+ """
+ Analyze trends in stability metrics.
+
+ Returns:
+ Dictionary with trend analysis for each metric.
+ """
+ trends = {}
+
+ if len(self.lyap_history) >= 10:
+ recent = np.mean(self.lyap_history[-10:])
+ older = np.mean(self.lyap_history[:10])
+ if recent > older + 0.1:
+ trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)"
+ elif recent < older - 0.1:
+ trends["lyapunov"] = "✓ DECREASING (stabilizing)"
+ else:
+ trends["lyapunov"] = "→ STABLE"
+
+ if len(self.grad_history) >= 10:
+ recent = np.mean(self.grad_history[-10:])
+ older = np.mean(self.grad_history[:10])
+ ratio = recent / (older + 1e-8)
+ if ratio > 10:
+ trends["gradients"] = "⚠️ EXPLODING"
+ elif ratio < 0.1:
+ trends["gradients"] = "⚠️ VANISHING"
+ else:
+ trends["gradients"] = "✓ STABLE"
+
+ if len(self.fr_history) >= 10:
+ recent = np.mean(self.fr_history[-10:])
+ if recent < 0.01:
+ trends["firing"] = "⚠️ DEAD NETWORK"
+ elif recent > 0.8:
+ trends["firing"] = "⚠️ SATURATED"
+ else:
+ trends["firing"] = "✓ HEALTHY"
+
+ return trends
+
+ def diagnose(self) -> str:
+ """Generate a diagnostic summary."""
+ trends = self.get_trends()
+
+ lines = ["=== Stability Diagnosis ==="]
+
+ if self.lyap_history:
+ avg_lyap = np.mean(self.lyap_history[-20:])
+ lines.append(f"Lyapunov exponent: {avg_lyap:.4f}")
+ if avg_lyap > 0.5:
+ lines.append(" → Network is CHAOTIC (trajectories diverge quickly)")
+ lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate")
+ elif avg_lyap < -0.5:
+ lines.append(" → Network is OVER-STABLE (trajectories collapse)")
+ lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization")
+ else:
+ lines.append(" → Network is at EDGE OF CHAOS (good for learning)")
+
+ if self.grad_history:
+ avg_grad = np.mean(self.grad_history[-20:])
+ max_grad = max(self.grad_history[-20:])
+ lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}")
+ if "gradients" in trends:
+ lines.append(f" → {trends['gradients']}")
+
+ if self.fr_history:
+ avg_fr = np.mean(self.fr_history[-20:])
+ lines.append(f"Firing rate: {avg_fr:.4f}")
+ if "firing" in trends:
+ lines.append(f" → {trends['firing']}")
+
+ return "\n".join(lines)
+
+
+def compute_spectral_radius(weight_matrix: torch.Tensor) -> float:
+ """
+ Compute spectral radius of a weight matrix.
+
+ For recurrent networks, spectral radius > 1 indicates potential instability.
+
+ Args:
+ weight_matrix: 2D weight tensor
+
+ Returns:
+ Spectral radius (largest absolute eigenvalue)
+ """
+ with torch.no_grad():
+ W = weight_matrix.detach().cpu().numpy()
+ eigenvalues = np.linalg.eigvals(W)
+ return float(np.max(np.abs(eigenvalues)))
+
+
+def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]:
+ """
+ Analyze spectral properties of all weight matrices.
+
+ Returns:
+ Dictionary mapping layer names to spectral radii.
+ """
+ results = {}
+ for name, param in model.named_parameters():
+ if "weight" in name and param.ndim == 2:
+ if param.shape[0] == param.shape[1]: # Square matrix (recurrent)
+ results[name] = compute_spectral_radius(param)
+ return results
diff --git a/files/data_io/__init__.py b/files/data_io/__init__.py
new file mode 100644
index 0000000..de423db
--- /dev/null
+++ b/files/data_io/__init__.py
@@ -0,0 +1,3 @@
+"""
+data_io package: unified data loading, encoding, and preprocessing for spiking neural networks.
+"""
diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py
new file mode 100644
index 0000000..302ed51
--- /dev/null
+++ b/files/data_io/benchmark_datasets.py
@@ -0,0 +1,360 @@
+"""
+Challenging benchmark datasets for deep SNN evaluation.
+
+Datasets:
+1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps
+2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order
+3. CIFAR-10 with rate coding
+4. DVS-CIFAR10 (requires tonic library)
+
+These benchmarks are harder than SHD and benefit from deeper networks.
+"""
+
+import os
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class SequentialMNIST(Dataset):
+ """
+ Sequential MNIST - feed pixels one at a time.
+
+ Each 28x28 image becomes a sequence of 784 timesteps,
+ each with a single pixel intensity converted to spike probability.
+
+ This is MUCH harder than standard MNIST because:
+ - Network must remember information across 784 timesteps
+ - Tests long-range temporal dependencies
+ - Shallow networks fail due to vanishing gradients
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ permute: If True, use fixed random permutation (psMNIST)
+ spike_encoding: 'rate' (Poisson) or 'latency' or 'direct' (raw intensity)
+ max_rate: Maximum firing rate for rate coding
+ seed: Random seed for permutation
+ """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ permute: bool = False,
+ spike_encoding: str = "rate",
+ max_rate: float = 100.0,
+ n_repeat: int = 1, # Repeat each pixel n times for more spikes
+ seed: int = 42,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.train = train
+ self.permute = permute
+ self.spike_encoding = spike_encoding
+ self.max_rate = max_rate
+ self.n_repeat = n_repeat
+
+ # Load MNIST
+ self.mnist = datasets.MNIST(
+ root=root,
+ train=train,
+ download=download,
+ transform=transforms.ToTensor(),
+ )
+
+ # Create fixed permutation for psMNIST
+ if permute:
+ rng = np.random.RandomState(seed)
+ self.perm = torch.from_numpy(rng.permutation(784))
+ else:
+ self.perm = None
+
+ def __len__(self):
+ return len(self.mnist)
+
+ def __getitem__(self, idx):
+ img, label = self.mnist[idx]
+
+ # Flatten to (784,)
+ pixels = img.view(-1)
+
+ # Apply permutation
+ if self.perm is not None:
+ pixels = pixels[self.perm]
+
+ # Convert to spike sequence (T, 1) where T = 784 * n_repeat
+ T = 784 * self.n_repeat
+
+ if self.spike_encoding == "direct":
+ # Direct intensity: repeat each pixel n_repeat times
+ spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+
+ elif self.spike_encoding == "rate":
+ # Rate coding: Poisson spikes based on intensity
+ probs = pixels * (self.max_rate / 1000.0) # Assuming 1ms bins
+ probs = probs.clamp(0, 1)
+ # Repeat and sample
+ probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+ spikes = torch.bernoulli(probs_expanded)
+
+ elif self.spike_encoding == "latency":
+ # Latency coding: spike time proportional to intensity
+ # High intensity = early spike, low = late spike
+ spikes = torch.zeros(T, 1)
+ for i, p in enumerate(pixels):
+ if p > 0.1: # Threshold for spiking
+ # Spike time: higher intensity = earlier
+ spike_time = int((1 - p) * (self.n_repeat - 1))
+ t = i * self.n_repeat + spike_time
+ spikes[t, 0] = 1.0
+
+ return spikes, label
+
+
+class RateCodingCIFAR10(Dataset):
+ """
+ CIFAR-10 with rate coding for SNNs.
+
+ Converts 32x32x3 images to spike trains:
+ - Each pixel channel becomes a Poisson spike train
+ - Total input dimension: 32*32*3 = 3072
+ - Sequence length: T timesteps
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ T: Number of timesteps
+ max_rate: Maximum firing rate (Hz)
+ flatten: If True, flatten spatial dimensions
+ """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ max_rate: float = 200.0,
+ flatten: bool = True,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.T = T
+ self.max_rate = max_rate
+ self.flatten = flatten
+
+ # Normalize to [0, 1]
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ self.cifar = datasets.CIFAR10(
+ root=root,
+ train=train,
+ download=download,
+ transform=transform,
+ )
+
+ def __len__(self):
+ return len(self.cifar)
+
+ def __getitem__(self, idx):
+ img, label = self.cifar[idx] # (3, 32, 32)
+
+ if self.flatten:
+ img = img.view(-1) # (3072,)
+
+ # Rate coding
+ prob_per_step = img * (self.max_rate / 1000.0) # Assuming 1ms steps
+ prob_per_step = prob_per_step.clamp(0, 1)
+
+ # Generate spikes for T timesteps
+ if self.flatten:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1) # (T, 3072)
+ else:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1) # (T, 3, 32, 32)
+
+ spikes = torch.bernoulli(probs)
+
+ return spikes, label
+
+
+class DVSCIFAR10(Dataset):
+ """
+ DVS-CIFAR10 dataset wrapper.
+
+ Requires the 'tonic' library for neuromorphic datasets:
+ pip install tonic
+
+ DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching
+ CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark.
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ T: Number of time bins
+ spatial_downsample: Downsample spatial resolution
+ """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ dt_ms: float = 10.0,
+ download: bool = True,
+ ):
+ try:
+ import tonic
+ from tonic import transforms as tonic_transforms
+ except ImportError:
+ raise ImportError(
+ "tonic library required for DVS datasets: pip install tonic"
+ )
+
+ self.T = T
+
+ # Time binning transform
+ sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
+ frame_transform = tonic_transforms.ToFrame(
+ sensor_size=sensor_size,
+ time_window=dt_ms * 1000, # Convert to microseconds
+ )
+
+ self.dataset = tonic.datasets.CIFAR10DVS(
+ save_to=root,
+ train=train,
+ transform=frame_transform,
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ frames, label = self.dataset[idx] # (T', 2, H, W) - 2 polarities
+
+ # Convert to tensor and flatten spatial dims
+ frames = torch.from_numpy(frames).float()
+
+ # Adjust to target T
+ T_actual = frames.shape[0]
+ if T_actual > self.T:
+ # Subsample
+ indices = torch.linspace(0, T_actual - 1, self.T).long()
+ frames = frames[indices]
+ elif T_actual < self.T:
+ # Pad with zeros
+ pad = torch.zeros(self.T - T_actual, *frames.shape[1:])
+ frames = torch.cat([frames, pad], dim=0)
+
+ # Flatten: (T, 2, H, W) -> (T, 2*H*W)
+ frames = frames.view(self.T, -1)
+
+ return frames, label
+
+
+def get_benchmark_dataloader(
+ dataset_name: str,
+ batch_size: int = 64,
+ root: str = "./data",
+ **kwargs,
+) -> Tuple[DataLoader, DataLoader, dict]:
+ """
+ Get train and validation dataloaders for a benchmark dataset.
+
+ Args:
+ dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10'
+ batch_size: Batch size
+ root: Data directory
+ **kwargs: Additional arguments passed to dataset
+
+ Returns:
+ train_loader, val_loader, info_dict
+ """
+
+ if dataset_name == "smnist":
+ train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"}
+
+ elif dataset_name == "psmnist":
+ train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"}
+
+ elif dataset_name == "cifar10":
+ T = kwargs.pop("T", 100)
+ train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs)
+ val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs)
+ info = {"T": T, "D": 3072, "classes": 10,
+ "description": "CIFAR-10 with rate coding"}
+
+ elif dataset_name == "dvs_cifar10":
+ train_ds = DVSCIFAR10(root, train=True, **kwargs)
+ val_ds = DVSCIFAR10(root, train=False, **kwargs)
+ info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10,
+ "description": "DVS-CIFAR10 neuromorphic dataset"}
+
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}. "
+ f"Options: smnist, psmnist, cifar10, dvs_cifar10")
+
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, info
+
+
+# Quick test
+if __name__ == "__main__":
+ print("Testing benchmark datasets...\n")
+
+ # Test sMNIST
+ print("1. Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test psMNIST
+ print("\n2. Permuted Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test CIFAR-10
+ print("\n3. CIFAR-10 (rate coded)")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10", batch_size=32, T=50
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ print("\nDone!")
diff --git a/files/data_io/configs/__init__.py b/files/data_io/configs/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/__init__.py
diff --git a/files/data_io/configs/dvs.yaml b/files/data_io/configs/dvs.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/dvs.yaml
diff --git a/files/data_io/configs/shd.yaml b/files/data_io/configs/shd.yaml
new file mode 100644
index 0000000..c786934
--- /dev/null
+++ b/files/data_io/configs/shd.yaml
@@ -0,0 +1,11 @@
+dataset: SHD
+data_dir: /u/yurenh2/ml-projects/snn-training/files/data
+shuffle: true
+encoder:
+ type: poisson
+ max_rate: 50
+ dt_ms: 1.0
+ seed: 42 # 也可不写,默认 42
+transforms:
+ normalize: true
+ spike_jitter: 0.01
diff --git a/files/data_io/configs/ssc.yaml b/files/data_io/configs/ssc.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/ssc.yaml
diff --git a/files/data_io/dataset_loader.py b/files/data_io/dataset_loader.py
new file mode 100644
index 0000000..983b67b
--- /dev/null
+++ b/files/data_io/dataset_loader.py
@@ -0,0 +1,281 @@
+from typing import Optional
+import yaml
+import torch
+from torch.utils.data import Dataset, DataLoader
+import os
+import h5py
+import numpy as np
+from .encoders import poisson_encoder, latency_encoder, rank_order_encoder
+from .transforms import normalization, spike_augmentation
+# from .utils import file_utils # 目前未使用可先注释
+
+
+
+# -----
+# Base synthetic dataset placeholder
+# -----
+class BaseDataset(Dataset):
+ """
+ Abstract base class for datasets.
+ Each subclass must implement __getitem__ and __len__.
+
+ This base implementation uses synthetic placeholders for quick smoke tests.
+ """
+
+ def __init__(self, data_dir, encoder, transforms=None):
+ self.data_dir = data_dir
+ self.encoder = encoder
+ self.transforms = transforms or []
+ self.samples = list(range(200)) # placeholder synthetic samples
+
+ def __getitem__(self, idx):
+ # synthetic static input -> encode to (T, D) via encoder
+ raw_data = torch.rand(128) # 128-dim intensity
+ label = torch.randint(0, 10, (1,)).item()
+ encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data
+ x = torch.as_tensor(encoded)
+ for t in self.transforms:
+ x = t(x)
+ return x, label
+
+ def __len__(self):
+ return len(self.samples)
+
+
+# -----
+# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D
+# -----
+import h5py # noqa: E402
+
+
+class SHDDataset(Dataset):
+ """
+ SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T.
+
+ H5 structure (Zenke Lab convention):
+ - f["labels"][i] -> scalar label for sample i
+ - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i
+ - f["spikes"]["units"][i] -> 1D array of channel ids for sample i
+
+ We:
+ (1) scan once to determine global T = ceil(max_time / dt_ms)
+ (2) decide D from max unit id (fallback to default_D=700)
+ (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D)
+ """
+
+ def __init__(
+ self,
+ data_dir: str,
+ encoder=None, # ignored for SHD (already spiking events)
+ transforms=None,
+ split: str = "train",
+ dt_ms: float = 1.0,
+ seed: Optional[int] = None,
+ default_D: int = 700
+ ):
+ super().__init__()
+ self.data_dir = data_dir
+ self.transforms = transforms or []
+ self.dt_ms = float(dt_ms)
+ self.seed = 42 if seed is None else int(seed)
+ self.encoder = None # IMPORTANT: do not apply intensity encoders to event data
+ self.default_D = int(default_D)
+
+ fname = f"shd_{split}.h5"
+ self.path = os.path.join(self.data_dir, fname)
+ if not os.path.exists(self.path):
+ raise FileNotFoundError(f"SHD file not found: {self.path}")
+
+ with h5py.File(self.path, "r") as f:
+ # labels is dense array
+ self.labels = np.array(f["labels"], dtype=np.int64)
+ self.N = int(self.labels.shape[0])
+
+ # ragged datasets for events
+ times_ds = f["spikes"]["times"]
+ units_ds = f["spikes"]["units"]
+
+ # scan once to compute global T and adaptive D
+ t_max_global = 0.0
+ max_unit = -1
+ for i in range(self.N):
+ ti = times_ds[i]
+ ui = units_ds[i]
+ if ti.size > 0:
+ last_t = float(ti[-1]) # ms
+ if last_t > t_max_global:
+ t_max_global = last_t
+ if ui.size > 0:
+ uimax = int(ui.max())
+ if uimax > max_unit:
+ max_unit = uimax
+
+ # decide D
+ if max_unit >= 0:
+ self.D = max(max_unit + 1, self.default_D)
+ else:
+ self.D = self.default_D
+
+ # decide T from global max time
+ self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1
+
+ # rng in case transforms need it
+ self._rng = np.random.default_rng(self.seed)
+
+ def __len__(self):
+ return self.N
+
+ def __getitem__(self, idx: int):
+ # open file per-sample for worker safety
+ with h5py.File(self.path, "r") as f:
+ ti = f["spikes"]["times"][idx][:]
+ ui = f["spikes"]["units"][idx][:]
+ y = int(f["labels"][idx])
+
+ # bin events to (T, D)
+ spikes = np.zeros((self.T, self.D), dtype=np.float32)
+ if ti.size > 0:
+ bins = (ti / self.dt_ms).astype(np.int64)
+ bins = np.clip(bins, 0, self.T - 1)
+ ui = np.clip(ui.astype(np.int64), 0, self.D - 1)
+ spikes[bins, ui] = 1.0 # presence; if you prefer counts: +=1 then clip to 1
+
+ x = torch.from_numpy(spikes)
+
+ # apply transforms (on torch tensor)
+ for tr in self.transforms:
+ x = tr(x)
+
+ return x, y
+
+
+# -----
+# SSC / DVS placeholders (still synthetic; implement real readers later)
+# -----
+class SSCDataset(BaseDataset):
+ """Placeholder SSC dataset (synthetic)."""
+ pass
+
+
+class DVSDataset(BaseDataset):
+ """Placeholder DVS dataset (synthetic)."""
+ pass
+
+
+# -----
+# Helpers: encoders / transforms / cfg path resolution
+# -----
+def build_encoder(cfg):
+ """
+ Build encoder from config dict.
+
+ Expected schema:
+ encoder:
+ type: poisson | latency | rank_order
+ # Poisson-only optional fields:
+ max_rate: 50
+ T: 64
+ dt_ms: 1.0
+ seed: 123
+ """
+ etype = cfg["type"].lower()
+ if etype == "poisson":
+ return poisson_encoder.PoissonEncoder(
+ max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
+ T=cfg.get("T", 50),
+ dt_ms=cfg.get("dt_ms", 1.0),
+ seed=cfg.get("seed", None),
+ )
+ elif etype == "latency":
+ return latency_encoder.LatencyEncoder()
+ elif etype == "rank_order":
+ return rank_order_encoder.RankOrderEncoder()
+ else:
+ raise ValueError(f"Unknown encoder type: {etype}")
+
+
+def build_transforms(cfg):
+ tlist = []
+ if cfg.get("normalize", False):
+ tlist.append(normalization.Normalize())
+ if cfg.get("spike_jitter", None) is not None:
+ tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
+ return tlist
+
+
+def _resolve_cfg_path(cfg_path: str) -> str:
+ """
+ Resolve cfg_path against:
+ 1) as-is (absolute or CWD-relative)
+ 2) relative to this package directory
+ 3) <pkg_dir>/configs/<basename>
+ """
+ if os.path.isabs(cfg_path) and os.path.exists(cfg_path):
+ return cfg_path
+ if os.path.exists(cfg_path):
+ return cfg_path
+ pkg_dir = os.path.dirname(__file__)
+ cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
+ if os.path.exists(cand2):
+ return cand2
+ cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
+ if os.path.exists(cand3):
+ return cand3
+ raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")
+
+
+# -----
+# Entry: get_dataloader
+# -----
+def get_dataloader(cfg_path):
+ """
+ Create train/val DataLoader from YAML config.
+ Handles SHD as true event dataset (encoder=None), others as synthetic placeholders.
+ """
+ cfg_path_resolved = _resolve_cfg_path(cfg_path)
+ with open(cfg_path_resolved, "r") as f:
+ cfg = yaml.safe_load(f)
+
+ dataset_name = cfg["dataset"].lower()
+ data_dir = cfg["data_dir"]
+ transforms = build_transforms(cfg.get("transforms", {}))
+
+ if dataset_name == "shd":
+ # event dataset: do NOT use intensity encoders here
+ dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0)
+ seed = cfg.get("encoder", {}).get("seed", 42)
+ ds_train = SHDDataset(
+ data_dir, encoder=None, transforms=transforms, split="train",
+ dt_ms=dt_ms, seed=seed
+ )
+ ds_val = SHDDataset(
+ data_dir, encoder=None, transforms=transforms, split="test",
+ dt_ms=dt_ms, seed=seed
+ )
+ elif dataset_name == "ssc":
+ # placeholder path; later implement true SSC reader
+ encoder = build_encoder(cfg["encoder"])
+ ds_train = SSCDataset(data_dir, encoder, transforms)
+ ds_val = SSCDataset(data_dir, encoder, transforms)
+ elif dataset_name == "dvs":
+ encoder = build_encoder(cfg["encoder"])
+ ds_train = DVSDataset(data_dir, encoder, transforms)
+ ds_val = DVSDataset(data_dir, encoder, transforms)
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}")
+
+ train_loader = DataLoader(
+ ds_train,
+ batch_size=cfg.get("batch_size", 16),
+ shuffle=cfg.get("shuffle", True),
+ num_workers=cfg.get("num_workers", 0),
+ pin_memory=cfg.get("pin_memory", False),
+ )
+ val_loader = DataLoader(
+ ds_val,
+ batch_size=cfg.get("batch_size", 16),
+ shuffle=False,
+ num_workers=cfg.get("num_workers", 0),
+ pin_memory=cfg.get("pin_memory", False),
+ )
+ return train_loader, val_loader \ No newline at end of file
diff --git a/files/data_io/encoders/__init__.py b/files/data_io/encoders/__init__.py
new file mode 100644
index 0000000..b2c5dc3
--- /dev/null
+++ b/files/data_io/encoders/__init__.py
@@ -0,0 +1 @@
+"""Encoder submodule: provides various spike encoding strategies."""
diff --git a/files/data_io/encoders/base_encoder.py b/files/data_io/encoders/base_encoder.py
new file mode 100644
index 0000000..e40451f
--- /dev/null
+++ b/files/data_io/encoders/base_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+
+class BaseEncoder:
+ """Abstract base class for all encoders."""
+ def encode(self, data: np.ndarray) -> torch.Tensor:
+ """
+ Convert static data (e.g., image, waveform) into spike tensor (T, input_dim).
+ """
+ raise NotImplementedError
diff --git a/files/data_io/encoders/latency_encoder.py b/files/data_io/encoders/latency_encoder.py
new file mode 100644
index 0000000..a7804ae
--- /dev/null
+++ b/files/data_io/encoders/latency_encoder.py
@@ -0,0 +1,13 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class LatencyEncoder(BaseEncoder):
+ """Encode input intensity into spike latency."""
+ def __init__(self):
+ pass
+
+ def encode(self, data: np.ndarray) -> torch.Tensor:
+ # TODO: map value→time delay
+ spikes = torch.zeros(10, data.size) # placeholder
+ return spikes
diff --git a/files/data_io/encoders/poisson_encoder.py b/files/data_io/encoders/poisson_encoder.py
new file mode 100644
index 0000000..0b404f7
--- /dev/null
+++ b/files/data_io/encoders/poisson_encoder.py
@@ -0,0 +1,91 @@
+from typing import Optional, Union
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class PoissonEncoder(BaseEncoder):
+ r"""
+ PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.
+
+ Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
+ by sampling at each time step t and dimension d:
+ S[t, d] ~ Bernoulli( p[d] ), where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).
+
+ Parameters
+ ----------
+ max_rate : float
+ Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
+ T : int
+ Number of discrete time steps in the encoded spike train. Default: 50.
+ dt_ms : float
+ Time resolution per step in milliseconds. Default: 1.0 ms.
+ Effective per-step probability uses factor (dt_ms/1000) to convert Hz to per-step probability.
+ seed : int or None
+ Optional RNG seed for reproducibility. If None, uses global RNG state.
+
+ Notes
+ -----
+ - Input `data` is expected to be a NumPy 1D array (shape [D]) or 2D array ([B, D]).
+ If 1D, we return S with shape (T, D).
+ If 2D (batched), we broadcast probabilities across batch and return (T, B, D).
+ - Intensities outside [0,1] will be clipped to [0,1].
+ - Device of returned tensor follows torch default device (CPU) unless you move it later.
+ """
+
+ def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
+ super().__init__()
+ self.max_rate = float(max_rate)
+ self.T = int(T)
+ self.dt_ms = float(dt_ms)
+ self.seed = seed
+ # local generator for reproducibility if seed is provided
+ self._g = torch.Generator()
+ if seed is not None:
+ self._g.manual_seed(int(seed))
+
+ def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+ if isinstance(data, torch.Tensor):
+ data = data.detach().cpu().numpy()
+ return np.asarray(data)
+
+ def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """
+ Convert input intensities to Poisson spike trains.
+
+ Parameters
+ ----------
+ data : np.ndarray or torch.Tensor
+ Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).
+
+ Returns
+ -------
+ spikes : torch.Tensor
+ Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.
+ """
+ x = self._ensure_numpy(data)
+ # clip to [0,1]
+ x = np.clip(x, 0.0, 1.0)
+
+ # compute per-step probability
+ p_step = float(self.max_rate) * (self.dt_ms / 1000.0) # probability per step for unit intensity
+ # probability tensor (broadcast-friendly)
+ probs = x * p_step
+ probs = np.clip(probs, 0.0, 1.0)
+
+ # create Bernoulli samples for T steps
+ # If input is 1D: probs.shape = (D,) -> output (T, D)
+ # If input is 2D: probs.shape = (B, D) -> output (T, B, D)
+ if probs.ndim == 1:
+ D = probs.shape[0]
+ probs_t = np.broadcast_to(probs, (self.T, D)) # (T, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ elif probs.ndim == 2:
+ B, D = probs.shape
+ probs_t = np.broadcast_to(probs, (self.T, B, D)) # (T, B, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ else:
+ raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}")
diff --git a/files/data_io/encoders/rank_order_encoder.py b/files/data_io/encoders/rank_order_encoder.py
new file mode 100644
index 0000000..9102e90
--- /dev/null
+++ b/files/data_io/encoders/rank_order_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class RankOrderEncoder(BaseEncoder):
+ """Encode by rank order of input features."""
+ def encode(self, data: np.ndarray) -> torch.Tensor:
+ # TODO: implement rank order conversion
+ spikes = torch.zeros(10, data.size) # placeholder
+ return spikes
diff --git a/files/data_io/transforms/__init__.py b/files/data_io/transforms/__init__.py
new file mode 100644
index 0000000..f6aa496
--- /dev/null
+++ b/files/data_io/transforms/__init__.py
@@ -0,0 +1 @@
+"""Transforms for spike data (normalization, augmentation, etc.)."""
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
+class Normalize:
+ """Normalize spike input (z-score or min-max)."""
+ def __init__(self, mode: str = "zscore", eps: float = 1e-6):
+ """
+ Parameters
+ ----------
+ mode : str
+ One of {"zscore", "minmax"}.
+ - "zscore": per-sample, per-channel standardization across time.
+ - "minmax": per-sample, per-channel min-max scaling across time.
+ eps : float
+ Small constant to avoid division by zero.
+ """
+ mode_l = str(mode).lower()
+ if mode_l not in {"zscore", "minmax"}:
+ raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
+ self.mode = mode_l
+ self.eps = float(eps)
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Apply normalization.
+ Accepts tensors of shape (T, D) or (B, T, D).
+ Normalization is computed per-sample and per-channel over the time axis T.
+ """
+ if not isinstance(x, torch.Tensor):
+ raise TypeError("Normalize expects a torch.Tensor")
+ if x.ndim == 2:
+ # (T, D) -> time dim = 0
+ time_dim = 0
+ keep_dims = True
+ elif x.ndim == 3:
+ # (B, T, D) -> time dim = 1
+ time_dim = 1
+ keep_dims = True
+ else:
+ raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
+
+ if self.mode == "zscore":
+ mean_t = x.mean(dim=time_dim, keepdim=keep_dims)
+ # population std (unbiased=False) to avoid NaNs for small T
+ std_t = x.std(dim=time_dim, keepdim=keep_dims, unbiased=False)
+ x_norm = (x - mean_t) / (std_t + self.eps)
+ return x_norm
+ else: # "minmax"
+ min_t = x.amin(dim=time_dim, keepdim=keep_dims)
+ max_t = x.amax(dim=time_dim, keepdim=keep_dims)
+ denom = (max_t - min_t).clamp_min(self.eps)
+ x_scaled = (x - min_t) / denom
+ return x_scaled
diff --git a/files/data_io/transforms/spike_augmentation.py b/files/data_io/transforms/spike_augmentation.py
new file mode 100644
index 0000000..9b7b687
--- /dev/null
+++ b/files/data_io/transforms/spike_augmentation.py
@@ -0,0 +1,10 @@
+import torch
+
+class SpikeJitter:
+ """Add temporal jitter noise to spikes."""
+ def __init__(self, std=0.01):
+ self.std = std
+
+ def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
+ # TODO: add random jitter to spike timings
+ return spikes
diff --git a/files/data_io/utils/__init__.py b/files/data_io/utils/__init__.py
new file mode 100644
index 0000000..ee3ab2f
--- /dev/null
+++ b/files/data_io/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility functions for file management, spike tools, and visualization."""
diff --git a/files/data_io/utils/file_utils.py b/files/data_io/utils/file_utils.py
new file mode 100644
index 0000000..0a1e846
--- /dev/null
+++ b/files/data_io/utils/file_utils.py
@@ -0,0 +1,15 @@
+import os
+
+def ensure_dir(path: str):
+ """Ensure that a directory exists."""
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def list_files(root: str, suffix: str):
+ """Recursively list files ending with suffix."""
+ matches = []
+ for dirpath, _, filenames in os.walk(root):
+ for f in filenames:
+ if f.endswith(suffix):
+ matches.append(os.path.join(dirpath, f))
+ return matches
diff --git a/files/data_io/utils/spike_tools.py b/files/data_io/utils/spike_tools.py
new file mode 100644
index 0000000..968ee72
--- /dev/null
+++ b/files/data_io/utils/spike_tools.py
@@ -0,0 +1,10 @@
+import torch
+import numpy as np
+
+def to_raster(spikes: torch.Tensor) -> np.ndarray:
+ """Convert spike tensor (T,B,N) to raster array (T,N)."""
+ return spikes.detach().cpu().numpy().mean(axis=1)
+
+def firing_rate(spikes: torch.Tensor, dt=1.0):
+ """Compute firing rate per neuron."""
+ return spikes.sum(dim=0) / (spikes.shape[0] * dt)
diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py
new file mode 100644
index 0000000..6f0de95
--- /dev/null
+++ b/files/data_io/utils/visualize.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import torch
+
+def plot_raster(spikes: torch.Tensor, title=None):
+ """
+ Plot raster diagram of spike activity (T,B,N) or (T,N).
+ """
+ s = spikes.detach().cpu()
+ if s.ndim == 3:
+ s = s[:, 0, :] # take first batch
+ t, n = s.shape
+ for i in range(n):
+ times = torch.nonzero(s[:, i]).squeeze().numpy()
+ plt.scatter(times, i * np.ones_like(times), s=2, c='black')
+ plt.xlabel("Time step")
+ plt.ylabel("Neuron index")
+ if title:
+ plt.title(title)
+ plt.show()
diff --git a/files/experiments/benchmark_experiment.py b/files/experiments/benchmark_experiment.py
new file mode 100644
index 0000000..fb01ff2
--- /dev/null
+++ b/files/experiments/benchmark_experiment.py
@@ -0,0 +1,518 @@
+"""
+Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.
+
+Datasets:
+- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks
+- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory
+- CIFAR-10: Rate-coded images, requires hierarchical features
+
+Usage:
+ python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8
+ python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.data_io.benchmark_datasets import get_benchmark_dataloader
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ grad_max_sv: Optional[float]
+ grad_min_sv: Optional[float]
+ grad_condition: Optional[float]
+ time_sec: float
+
+
+def compute_gradient_svs(model):
+ """Compute gradient singular value statistics."""
+ max_svs = []
+ min_svs = []
+
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ with torch.no_grad():
+ G = param.grad.detach()
+ try:
+ sv = torch.linalg.svdvals(G)
+ if len(sv) > 0:
+ max_svs.append(sv[0].item())
+ min_svs.append(sv[-1].item())
+ except Exception:
+ pass
+
+ if not max_svs:
+ return None, None, None
+
+ max_sv = max(max_svs)
+ min_sv = min(min_svs)
+ condition = max_sv / (min_sv + 1e-12)
+
+ return max_sv, min_sv, condition
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+ hidden_dims = [hidden_dim] * depth
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=1.0,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ compute_sv_every: int = 10,
+) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ grad_max_svs = []
+ grad_min_svs = []
+ grad_conditions = []
+
+ for batch_idx, (x, y) in enumerate(loader):
+ x, y = x.to(device), y.to(device)
+
+ # Handle different input shapes
+ if x.ndim == 2:
+ x = x.unsqueeze(-1) # (B, T) -> (B, T, 1)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=False,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan'), None, None, None
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ # Compute gradient SVs periodically
+ if batch_idx % compute_sv_every == 0:
+ max_sv, min_sv, cond = compute_gradient_svs(model)
+ if max_sv is not None:
+ grad_max_svs.append(max_sv)
+ grad_min_svs.append(min_sv)
+ grad_conditions.append(cond)
+
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+ avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+ avg_grad = np.mean(grad_norms)
+ avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None
+ avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None
+ avg_cond = np.mean(grad_conditions) if grad_conditions else None
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+
+ if x.ndim == 2:
+ x = x.unsqueeze(-1)
+
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ loss = ce_loss(logits, y)
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ depth: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment configuration."""
+ torch.manual_seed(seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=depth,
+ hidden_dim=hidden_dim,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ method = "Lyapunov" if use_lyapunov else "Vanilla"
+ metrics_history = []
+
+ iterator = range(1, epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ if progress:
+ lyap_str = f"λ={lyap:.2f}" if lyap else ""
+ iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str})
+
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ dataset_name: str,
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """Run comparison across depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ depth=depth,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ final = metrics[-1]
+ lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str}")
+
+ return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+ """Print summary table."""
+ print("\n" + "=" * 90)
+ print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy")
+ print("=" * 90)
+ print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}")
+ print("-" * 90)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A"
+ van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A"
+
+ print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}")
+
+ print("=" * 90)
+
+ # Gradient health analysis
+ print("\nGRADIENT HEALTH:")
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ van_cond = van.grad_condition if van.grad_condition else 0
+ if van_cond > 1e6:
+ print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})")
+ elif van_cond > 1e4:
+ print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN")
+
+ # Dataset
+ p.add_argument("--dataset", type=str, default="smnist",
+ choices=["smnist", "psmnist", "cifar10"],
+ help="Dataset to use")
+ p.add_argument("--data_dir", type=str, default="./data")
+
+ # Model
+ p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128)
+
+ # Training
+ p.add_argument("--epochs", type=int, default=30)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3,
+ help="Lyapunov regularization weight (higher for harder tasks)")
+ p.add_argument("--lambda_target", type=float, default=-0.1,
+ help="Target Lyapunov exponent (negative for stability)")
+ p.add_argument("--lyap_eps", type=float, default=1e-4)
+
+ # Dataset-specific
+ p.add_argument("--T", type=int, default=100,
+ help="Timesteps for CIFAR-10 (sMNIST uses 784)")
+ p.add_argument("--n_repeat", type=int, default=1,
+ help="Repeat each pixel n times for sMNIST")
+
+ # Other
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--out_dir", type=str, default="runs/benchmark")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load dataset
+ print(f"\nLoading {args.dataset} dataset...")
+
+ if args.dataset == "smnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "psmnist":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ n_repeat=args.n_repeat,
+ spike_encoding="direct",
+ )
+ elif args.dataset == "cifar10":
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10",
+ batch_size=args.batch_size,
+ root=args.data_dir,
+ T=args.T,
+ )
+
+ print(f"Dataset info: {info}")
+ print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ dataset_name=args.dataset,
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=info["D"],
+ num_classes=info["classes"],
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results, args.dataset)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/cifar10_conv_experiment.py b/files/experiments/cifar10_conv_experiment.py
new file mode 100644
index 0000000..a582f9f
--- /dev/null
+++ b/files/experiments/cifar10_conv_experiment.py
@@ -0,0 +1,448 @@
+"""
+CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization.
+
+Uses proper convolutional architecture that preserves spatial structure.
+Tests whether Lyapunov regularization helps train deeper Conv-SNNs.
+
+Architecture:
+ Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output
+
+Usage:
+ python files/experiments/cifar10_conv_experiment.py --model simple --T 25
+ python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+from files.models.conv_snn import create_conv_snn
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ time_sec: float
+
+
+def get_cifar10_loaders(
+ data_dir: str = './data',
+ batch_size: int = 128,
+ num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader]:
+ """
+ Get CIFAR-10 dataloaders with standard normalization.
+
+ Images normalized to [0, 1] for rate encoding.
+ """
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ # Note: For rate encoding, we keep values in [0, 1]
+ # No normalization to negative values
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_dataset = datasets.CIFAR10(
+ root=data_dir, train=True, download=True, transform=transform_train
+ )
+ test_dataset = datasets.CIFAR10(
+ root=data_dir, train=False, download=True, transform=transform_test
+ )
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ num_workers=num_workers, pin_memory=True
+ )
+ test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True
+ )
+
+ return train_loader, test_loader
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ progress: bool = True,
+) -> Tuple[float, float, Optional[float], float]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan')
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if progress:
+ iterator.set_postfix({
+ "loss": f"{loss.item():.3f}",
+ "acc": f"{total_correct/total_samples:.3f}",
+ })
+
+ return (
+ total_loss / total_samples,
+ total_correct / total_samples,
+ np.mean(lyap_vals) if lyap_vals else None,
+ np.mean(grad_norms),
+ )
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+ progress: bool = True,
+) -> Tuple[float, float]:
+ """Evaluate on test set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+
+ loss = ce_loss(logits, y)
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ model_type: str,
+ channels: List[int],
+ T: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment."""
+ torch.manual_seed(seed)
+
+ model = create_conv_snn(
+ model_type=model_type,
+ in_channels=3,
+ num_classes=10,
+ channels=channels,
+ T=T,
+ encoding='rate',
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f" Model: {model_type}, params: {num_params:,}")
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ ce_loss = nn.CrossEntropyLoss()
+
+ metrics_history = []
+ best_acc = 0.0
+
+ for epoch in range(1, epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_acc = max(best_acc, test_acc)
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=test_loss,
+ val_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)")
+
+ if np.isnan(train_loss):
+ print(" Training diverged!")
+ break
+
+ print(f" Best test accuracy: {best_acc:.3f}")
+ return metrics_history
+
+
+def run_comparison(
+ model_type: str,
+ channels_configs: List[List[int]],
+ T: int,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool,
+) -> Dict:
+ """Compare vanilla vs Lyapunov across different depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for channels in channels_configs:
+ depth = len(channels)
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} conv layers, channels = {channels}")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ model_type=model_type,
+ channels=channels,
+ T=T,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=1e-4,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ return results
+
+
+def print_summary(results: Dict):
+ """Print comparison summary."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: CIFAR-10 Conv-SNN Results")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+
+ # Model
+ p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"])
+ p.add_argument("--channels", type=int, nargs="+", default=None,
+ help="Channel sizes (default: test multiple depths)")
+ p.add_argument("--T", type=int, default=25, help="Timesteps")
+
+ # Training
+ p.add_argument("--epochs", type=int, default=50)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3)
+ p.add_argument("--lambda_target", type=float, default=-0.1)
+
+ # Other
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--out_dir", type=str, default="runs/cifar10_conv")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("CIFAR-10 Conv-SNN Experiment")
+ print("=" * 70)
+ print(f"Model: {args.model}")
+ print(f"Timesteps: {args.T}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load data
+ print("\nLoading CIFAR-10...")
+ train_loader, test_loader = get_cifar10_loaders(
+ data_dir=args.data_dir,
+ batch_size=args.batch_size,
+ )
+ print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+ # Define depth configurations to test
+ if args.channels:
+ channels_configs = [args.channels]
+ else:
+ # Test increasing depths
+ channels_configs = [
+ [64, 128], # 2 conv layers (shallow)
+ [64, 128, 256], # 3 conv layers
+ [64, 128, 256, 512], # 4 conv layers (deep)
+ ]
+
+ # Run comparison
+ results = run_comparison(
+ model_type=args.model,
+ channels_configs=channels_configs,
+ T=args.T,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Summary
+ print_summary(results)
+
+ # Save
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py
new file mode 100644
index 0000000..48c62d8
--- /dev/null
+++ b/files/experiments/depth_comparison.py
@@ -0,0 +1,542 @@
+"""
+Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths.
+
+Hypothesis:
+- Shallow networks (1-2 layers): Both methods train successfully
+- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds
+
+Usage:
+ # Quick test (synthetic data)
+ python files/experiments/depth_comparison.py --synthetic --epochs 20
+
+ # Full experiment with SHD data
+ python files/experiments/depth_comparison.py --epochs 50
+
+ # Specific depths to test
+ python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.analysis.stability_monitor import StabilityMonitor
+
+
+@dataclass
+class ExperimentConfig:
+ """Configuration for a single experiment run."""
+ depth: int
+ hidden_dim: int
+ use_lyapunov: bool
+ lambda_reg: float
+ lambda_target: float
+ lyap_eps: float
+ epochs: int
+ lr: float
+ batch_size: int
+ beta: float
+ threshold: float
+ seed: int
+
+
+@dataclass
+class EpochMetrics:
+ """Metrics collected per epoch."""
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ firing_rate: float
+ dead_neurons: float
+ time_sec: float
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+) -> Tuple[DataLoader, DataLoader]:
+ """Create synthetic spike data for testing."""
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ # Generate class-conditional spike patterns
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+
+ for i in range(n_samples):
+ label = y[i].item()
+ # Each class has different firing rate pattern
+ base_rate = 0.05 + 0.02 * label
+ # Class-specific channels fire more
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
+ for t in range(T):
+ # Background activity
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ # Enhanced activity for class-specific channels
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
+ train_loader = DataLoader(
+ TensorDataset(x_train, y_train),
+ batch_size=64,
+ shuffle=True,
+ )
+ val_loader = DataLoader(
+ TensorDataset(x_val, y_val),
+ batch_size=64,
+ shuffle=False,
+ )
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+ # Create hidden dims list based on depth
+ # Gradually decrease size for deeper networks to keep param count reasonable
+ hidden_dims = []
+ current_dim = hidden_dim
+ for i in range(depth):
+ hidden_dims.append(current_dim)
+ # Optionally decrease dim in deeper layers
+ # current_dim = max(64, current_dim // 2)
+
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=threshold,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ monitor: StabilityMonitor,
+) -> Tuple[float, float, float, float, float, float]:
+ """Train for one epoch, return metrics."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ firing_rates = []
+ dead_fracs = []
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, recordings = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=True,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ # Check for NaN
+ if torch.isnan(loss):
+ return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0
+
+ loss.backward()
+
+ # Gradient clipping for stability comparison fairness
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ optimizer.step()
+
+ # Collect metrics
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ # Stability metrics
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ if recordings is not None:
+ spikes = recordings['spikes']
+ fr = spikes.mean().item()
+ dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item()
+ firing_rates.append(fr)
+ dead_fracs.append(dead)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+ avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+ avg_grad = np.mean(grad_norms)
+ avg_fr = np.mean(firing_rates) if firing_rates else 0.0
+ avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate model on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+
+ loss = ce_loss(logits, y)
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_single_experiment(
+ config: ExperimentConfig,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run a single experiment with given configuration."""
+ torch.manual_seed(config.seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=config.depth,
+ hidden_dim=config.hidden_dim,
+ beta=config.beta,
+ threshold=config.threshold,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=config.lr)
+ ce_loss = nn.CrossEntropyLoss()
+ monitor = StabilityMonitor()
+
+ metrics_history = []
+ method = "Lyapunov" if config.use_lyapunov else "Vanilla"
+
+ iterator = range(1, config.epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch(
+ model=model,
+ loader=train_loader,
+ optimizer=optimizer,
+ ce_loss=ce_loss,
+ device=device,
+ use_lyapunov=config.use_lyapunov,
+ lambda_reg=config.lambda_reg,
+ lambda_target=config.lambda_target,
+ lyap_eps=config.lyap_eps,
+ monitor=monitor,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ firing_rate=fr,
+ dead_neurons=dead,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ # Early stopping if training diverged
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ epochs: int = 30,
+ hidden_dim: int = 128,
+ lr: float = 1e-3,
+ lambda_reg: float = 0.1,
+ lambda_target: float = 0.0,
+ lyap_eps: float = 1e-4,
+ beta: float = 0.9,
+ seed: int = 42,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """
+ Run comparison experiments across depths.
+
+ Returns:
+ Dictionary with structure:
+ {
+ "vanilla": {1: [metrics...], 2: [metrics...], ...},
+ "lyapunov": {1: [metrics...], 2: [metrics...], ...}
+ }
+ """
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*50}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*50}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ config = ExperimentConfig(
+ depth=depth,
+ hidden_dim=hidden_dim,
+ use_lyapunov=use_lyap,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ epochs=epochs,
+ lr=lr,
+ batch_size=64,
+ beta=beta,
+ threshold=1.0,
+ seed=seed,
+ )
+
+ metrics = run_single_experiment(
+ config=config,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ device=device,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ # Print final metrics
+ final = metrics[-1]
+ lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}")
+
+ return results
+
+
+def save_results(results: Dict, output_dir: str, config: dict):
+ """Save experiment results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Convert metrics to dicts
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]):
+ """Print summary comparison table."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: Final Validation Accuracy by Depth")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van_metrics = results["vanilla"][depth]
+ lyap_metrics = results["lyapunov"][depth]
+
+ van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0
+ lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0
+
+ van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED"
+ lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+ # Gradient analysis
+ print("\nGradient Norm Analysis (final epoch):")
+ print("-" * 70)
+ print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}")
+ print("-" * 70)
+ for depth in depths:
+ van_grad = results["vanilla"][depth][-1].grad_norm
+ lyap_grad = results["lyapunov"][depth][-1].grad_norm
+ print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths")
+ p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer")
+ p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment")
+ p.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
+ p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov")
+ p.add_argument("--beta", type=float, default=0.9, help="Membrane decay")
+ p.add_argument("--seed", type=int, default=42, help="Random seed")
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
+ p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true", help="Disable progress bars")
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("Experiment: Vanilla vs Lyapunov-Regularized SNN")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Lambda_reg: {args.lambda_reg}")
+ print(f"Device: {device}")
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing SYNTHETIC data for quick testing")
+ train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed)
+ else:
+ print(f"\nLoading data from {args.cfg}")
+ from files.data_io.dataset_loader import get_dataloader
+ train_loader, val_loader = get_dataloader(args.cfg)
+ xb, _ = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = 20 # SHD has 20 classes
+
+ print(f"Data: T={T}, D={D}, classes={C}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ device=device,
+ epochs=args.epochs,
+ hidden_dim=args.hidden_dim,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ beta=args.beta,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py
new file mode 100644
index 0000000..efab140
--- /dev/null
+++ b/files/experiments/depth_scaling_benchmark.py
@@ -0,0 +1,1035 @@
+"""
+Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs.
+
+Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve.
+
+Key hypothesis (from literature):
+- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet)
+- Deep SNNs without regularization fail to train (gradient issues)
+- Deep SNNs WITH Lyapunov regularization achieve higher accuracy
+
+Reference results:
+- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI)
+- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS)
+- Spikformer on ImageNet: ~74.8% top-1 (arXiv)
+
+Usage:
+ python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+# =============================================================================
+# VGG-style Spiking Network (scalable depth)
+# =============================================================================
+
+class SpikingVGGBlock(nn.Module):
+ """Conv-BN-LIF block for VGG-style architecture."""
+
+ def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+ super().__init__()
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+ def forward(self, x, mem):
+ h = self.bn(self.conv(x))
+ spk, mem = self.lif(h, mem)
+ return spk, mem
+
+
+class SpikingVGG(nn.Module):
+ """
+ Scalable VGG-style Spiking Neural Network.
+
+ Architecture follows VGG pattern:
+ - Multiple conv blocks between pooling layers
+ - Depth controlled by num_blocks_per_stage
+
+ Args:
+ in_channels: Input channels (3 for RGB)
+ num_classes: Output classes
+ base_channels: Starting channel count (doubled each stage)
+ num_stages: Number of pooling stages (3-4 typical)
+ blocks_per_stage: Conv blocks per stage (controls depth)
+ T: Number of timesteps
+ beta: LIF membrane decay
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ base_channels: int = 64,
+ num_stages: int = 3,
+ blocks_per_stage: int = 2,
+ T: int = 4,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ dropout: float = 0.25,
+ stable_init: bool = False,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+ self.total_conv_layers = num_stages * blocks_per_stage
+ self.stable_init = stable_init
+
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ # Build stages
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ ch_in = in_channels
+ ch_out = base_channels
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for b in range(blocks_per_stage):
+ block_in = ch_in if b == 0 else ch_out
+ stage_blocks.append(
+ SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad)
+ )
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ ch_in = ch_out
+ ch_out = min(ch_out * 2, 512) # Cap at 512
+
+ # Calculate spatial size after pooling
+ # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages)
+ final_spatial = 32 // (2 ** num_stages)
+ final_channels = min(base_channels * (2 ** (num_stages - 1)), 512)
+ fc_input = final_channels * final_spatial * final_spatial
+
+ # Classifier
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(fc_input, num_classes)
+
+ if stable_init:
+ self._init_weights_stable()
+ else:
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_weights_stable(self):
+ """
+ Stability-aware initialization for SNNs.
+
+ Uses smaller weight magnitudes to produce less chaotic initial dynamics.
+ The key insight: Lyapunov exponent depends on weight magnitudes.
+ Smaller weights → smaller gradients → more stable dynamics.
+
+ Strategy:
+ - Use orthogonal init (preserves gradient magnitude across layers)
+ - Scale down by factor of 0.5 to reduce initial chaos
+ - This should produce λ closer to 0 from the start
+ """
+ scale_factor = 0.5 # Reduce weight magnitudes
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # Orthogonal init for conv (reshape to 2D, init, reshape back)
+ weight_shape = m.weight.shape
+ fan_out = weight_shape[0] * weight_shape[2] * weight_shape[3]
+ fan_in = weight_shape[1] * weight_shape[2] * weight_shape[3]
+
+ # Use smaller gain for stability
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ with torch.no_grad():
+ m.weight.mul_(scale_factor)
+
+ elif isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight, gain=scale_factor)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_mems(self, batch_size, device, dtype, P=1):
+ """Initialize membrane potentials for all LIF layers.
+
+ Args:
+ batch_size: Batch size B
+ device: torch device
+ dtype: torch dtype
+ P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov)
+
+ Returns:
+ List of membrane tensors with shape (P, B, C, H, W)
+ """
+ mems = []
+ H, W = 32, 32
+ ch = 64
+
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ ch = min(ch * 2, 512)
+
+ return mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+ """
+ Forward pass with optimized Lyapunov computation (Approach A: trajectory batching).
+
+ When compute_lyapunov=True, both original and perturbed trajectories are
+ processed together by batching them along a new dimension P=2. This avoids
+ redundant computation, especially for the first conv layer where inputs are identical.
+
+ Args:
+ x: (B, C, H, W) static image (will be repeated for T steps)
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude
+
+ Returns:
+ logits, lyap_est, recordings
+ """
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+
+ # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed)
+ P = 2 if compute_lyapunov else 1
+
+ # Initialize membrane potentials with shape (P, B, C, H, W)
+ mems = self._init_mems(B, device, dtype, P=P)
+
+ # Initialize perturbed trajectory
+ if compute_lyapunov:
+ for i in range(len(mems)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = None
+
+ # Time loop - repeat static input
+ for t in range(self.T):
+ mem_idx = 0
+ new_mems = []
+ is_first_block = True
+
+ # Process through stages
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ # First block: input x is identical for both trajectories
+ # Compute conv+bn ONCE, then expand to (P, B, C, H, W)
+ h_conv = block.bn(block.conv(x)) # (B, C, H, W)
+ h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy
+
+ # LIF with batched membrane states
+ # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
+
+ # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+
+ h = spk
+ new_mems.append(mem_new)
+ is_first_block = False
+ else:
+ # Subsequent blocks: inputs differ between trajectories
+ # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+
+ # Full block forward (conv+bn+lif)
+ h_conv = block.bn(block.conv(h_flat))
+ spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
+
+ # Reshape back
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+
+ h = spk
+ new_mems.append(mem_new)
+
+ mem_idx += 1
+
+ # Pool: apply to batched tensor
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ h_pooled = pool(h_flat)
+ h = h_pooled.view(P, B, *h_pooled.shape[1:])
+
+ mems = new_mems
+
+ # Accumulate final spikes from ORIGINAL trajectory only (index 0)
+ h_orig = h[0].view(B, -1) # (B, C*H*W)
+ if spike_sum is None:
+ spike_sum = h_orig
+ else:
+ spike_sum = spike_sum + h_orig
+
+ # Lyapunov divergence and renormalization (Option 1: global delta + global renorm)
+ # This is the textbook Benettin-style Lyapunov exponent estimator where
+ # the perturbation is treated as one vector in the concatenated state space.
+ if compute_lyapunov:
+ # Compute GLOBAL divergence across all layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W)
+ delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # GLOBAL renormalization: same scale factor for all layers
+ # This ensures ||perturbation||_global = eps after renorm
+ scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting
+
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ # Update perturbed trajectory: scale the diff to have global norm = eps
+ mems[i] = torch.stack([
+ new_mems[i][0],
+ new_mems[i][0] + diff * scale
+ ], dim=0)
+
+ # Readout
+ out = self.dropout(spike_sum)
+ logits = self.fc(out)
+
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est, None
+
+ @property
+ def depth(self):
+ return self.total_conv_layers
+
+
+# =============================================================================
+# Dataset Loading
+# =============================================================================
+
+def get_dataset(
+ name: str,
+ data_dir: str = './data',
+ batch_size: int = 128,
+ num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]:
+ """
+ Get train/test loaders for various datasets.
+
+ Returns:
+ train_loader, test_loader, num_classes, input_shape
+ """
+
+ if name == 'mnist':
+ transform = transforms.Compose([
+ transforms.Resize(32), # Resize to 32x32 for consistency
+ transforms.ToTensor(),
+ ])
+ train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+ test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
+ num_classes = 10
+ input_shape = (1, 32, 32)
+
+ elif name == 'fashion_mnist':
+ transform = transforms.Compose([
+ transforms.Resize(32),
+ transforms.ToTensor(),
+ ])
+ train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
+ test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
+ num_classes = 10
+ input_shape = (1, 32, 32)
+
+ elif name == 'cifar10':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ])
+ transform_test = transforms.Compose([transforms.ToTensor()])
+ train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train)
+ test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test)
+ num_classes = 10
+ input_shape = (3, 32, 32)
+
+ elif name == 'cifar100':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ ])
+ transform_test = transforms.Compose([transforms.ToTensor()])
+ train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train)
+ test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test)
+ num_classes = 100
+ input_shape = (3, 32, 32)
+
+ else:
+ raise ValueError(f"Unknown dataset: {name}")
+
+ train_loader = DataLoader(
+ train_ds, batch_size=batch_size, shuffle=True,
+ num_workers=num_workers, pin_memory=True
+ )
+ test_loader = DataLoader(
+ test_ds, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True
+ )
+
+ return train_loader, test_loader, num_classes, input_shape
+
+
+# =============================================================================
+# Training
+# =============================================================================
+
+@dataclass
+class TrainingMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ test_loss: float
+ test_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ grad_max_sv: Optional[float] # Max singular value of gradients
+ grad_min_sv: Optional[float] # Min singular value of gradients
+ grad_condition: Optional[float] # Condition number
+ lr: float
+ time_sec: float
+
+
+def compute_gradient_svs(model):
+ """Compute gradient singular value statistics for all weight matrices."""
+ max_svs = []
+ min_svs = []
+
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.ndim == 2:
+ with torch.no_grad():
+ G = param.grad.detach()
+ try:
+ sv = torch.linalg.svdvals(G)
+ if len(sv) > 0:
+ max_svs.append(sv[0].item())
+ min_svs.append(sv[-1].item())
+ except Exception:
+ pass
+
+ if not max_svs:
+ return None, None, None
+
+ max_sv = max(max_svs)
+ min_sv = min(min_svs)
+ condition = max_sv / (min_sv + 1e-12)
+
+ return max_sv, min_sv, condition
+
+
+def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float,
+ lyap_threshold: float = 2.0) -> torch.Tensor:
+ """
+ Compute Lyapunov regularization loss with different penalty types.
+
+ Args:
+ lyap_est: Estimated Lyapunov exponent (scalar tensor)
+ reg_type: Type of regularization:
+ - "squared": (λ - target)² - original, aggressive
+ - "hinge": max(0, λ - threshold)² - only penalize chaos
+ - "asymmetric": strong penalty for chaos, weak for collapse
+ - "extreme": only penalize when λ > lyap_threshold (configurable)
+ - "adaptive_linear": penalty scales linearly with excess over threshold
+ - "adaptive_exp": penalty grows exponentially for severe chaos
+ - "adaptive_sigmoid": smooth sigmoid transition around threshold
+ lambda_target: Target value (used for squared reg_type)
+ lyap_threshold: Threshold for adaptive/extreme reg_types (default 2.0)
+
+ Returns:
+ Regularization loss (scalar tensor)
+ """
+ if reg_type == "squared":
+ # Original: penalize any deviation from target
+ return (lyap_est - lambda_target) ** 2
+
+ elif reg_type == "hinge":
+ # Only penalize when λ > threshold (too chaotic)
+ # threshold = 0 means: only penalize positive Lyapunov (chaos)
+ threshold = 0.0
+ excess = torch.relu(lyap_est - threshold)
+ return excess ** 2
+
+ elif reg_type == "asymmetric":
+ # Strong penalty for chaos (λ > 0), weak penalty for collapse (λ < -1)
+ # This allows the network to be stable without being dead
+ chaos_penalty = torch.relu(lyap_est) ** 2 # Penalize λ > 0
+ collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2 # Weakly penalize λ < -1
+ return chaos_penalty + collapse_penalty
+
+ elif reg_type == "extreme":
+ # Only penalize when λ > threshold (VERY chaotic)
+ # This allows moderate chaos while preventing extreme instability
+ # Threshold is now configurable via lyap_threshold argument
+ excess = torch.relu(lyap_est - lyap_threshold)
+ return excess ** 2
+
+ elif reg_type == "adaptive_linear":
+ # Penalty scales linearly with how far above threshold we are
+ # loss = excess * excess² = excess³
+ # This naturally makes the penalty weaker for small excesses
+ # and much stronger for large excesses
+ excess = torch.relu(lyap_est - lyap_threshold)
+ return excess ** 3 # Cubic scaling: gentle near threshold, strong when chaotic
+
+ elif reg_type == "adaptive_exp":
+ # Exponential penalty for severe chaos
+ # loss = (exp(excess) - 1) * excess² for excess > 0
+ # This gives very weak penalty near threshold, explosive penalty for chaos
+ excess = torch.relu(lyap_est - lyap_threshold)
+ # Use exp(excess) - 1 to get 0 when excess=0, exponential growth after
+ exp_scale = torch.exp(excess) - 1.0
+ return exp_scale * excess # exp(excess) * excess - excess
+
+ elif reg_type == "adaptive_sigmoid":
+ # Smooth sigmoid transition around threshold
+ # The "sharpness" of transition is controlled by a temperature parameter
+ # weight(λ) = sigmoid((λ - threshold) / T) where T controls smoothness
+ # Using T=0.5 for moderately sharp transition
+ temperature = 0.5
+ weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature)
+ # Penalize deviation from target, weighted by how far past threshold
+ deviation = lyap_est - lambda_target
+ return weight * (deviation ** 2)
+
+ # =========================================================================
+ # SCALED MULTIPLIER REGULARIZATION
+ # loss = (λ_reg × g(relu(λ))) × relu(λ)
+ # └─────────────────┘ └──────┘
+ # scaled multiplier penalty toward target=0
+ #
+ # The multiplier itself scales with λ, making it mild when λ is small
+ # and aggressive when λ is large.
+ # =========================================================================
+
+ elif reg_type == "mult_linear":
+ # Multiplier scales linearly: g(x) = x
+ # loss = (λ_reg × relu(λ)) × relu(λ) = λ_reg × relu(λ)²
+ # λ=0.5 → 0.25, λ=1.0 → 1.0, λ=2.0 → 4.0, λ=3.0 → 9.0
+ pos_lyap = torch.relu(lyap_est)
+ return pos_lyap * pos_lyap # relu(λ)²
+
+ elif reg_type == "mult_squared":
+ # Multiplier scales quadratically: g(x) = x²
+ # loss = (λ_reg × relu(λ)²) × relu(λ) = λ_reg × relu(λ)³
+ # λ=0.5 → 0.125, λ=1.0 → 1.0, λ=2.0 → 8.0, λ=3.0 → 27.0
+ pos_lyap = torch.relu(lyap_est)
+ return pos_lyap * pos_lyap * pos_lyap # relu(λ)³
+
+ elif reg_type == "mult_log":
+ # Multiplier scales logarithmically: g(x) = log(1+x)
+ # loss = (λ_reg × log(1+relu(λ))) × relu(λ)
+ # λ=0.5 → 0.20, λ=1.0 → 0.69, λ=2.0 → 2.20, λ=3.0 → 4.16
+ pos_lyap = torch.relu(lyap_est)
+ return torch.log1p(pos_lyap) * pos_lyap # log(1+λ) × λ
+
+ else:
+ raise ValueError(f"Unknown reg_type: {reg_type}")
+
+
+def train_epoch(
+ model, loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps,
+ progress=True, compute_sv_every=10,
+ reg_type="squared", current_lambda_reg=None,
+ lyap_threshold=2.0
+):
+ """
+ Train one epoch.
+
+ Args:
+ current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg.
+ reg_type: "squared", "hinge", "asymmetric", or "extreme"
+ lyap_threshold: Threshold for extreme reg_type
+ """
+ model.train()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+ lyap_vals = []
+ grad_norms = []
+ grad_max_svs = []
+ grad_min_svs = []
+ grad_conditions = []
+
+ # Use warmup value if provided
+ effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for batch_idx, (x, y) in enumerate(iterator):
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps)
+
+ loss = criterion(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold)
+ loss = loss + effective_lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan'), None, None, None
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5
+ grad_norms.append(grad_norm)
+
+ # Compute gradient SVs periodically (expensive)
+ if batch_idx % compute_sv_every == 0:
+ max_sv, min_sv, cond = compute_gradient_svs(model)
+ if max_sv is not None:
+ grad_max_svs.append(max_sv)
+ grad_min_svs.append(min_sv)
+ grad_conditions.append(cond)
+
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ return (
+ total_loss / total,
+ correct / total,
+ np.mean(lyap_vals) if lyap_vals else None,
+ np.mean(grad_norms),
+ np.mean(grad_max_svs) if grad_max_svs else None,
+ np.mean(grad_min_svs) if grad_min_svs else None,
+ np.mean(grad_conditions) if grad_conditions else None,
+ )
+
+
+@torch.no_grad()
+def evaluate(model, loader, criterion, device, progress=True):
+ model.eval()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+
+ iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+
+ loss = criterion(logits, y)
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+
+ return total_loss / total, correct / total
+
+
+def run_single_config(
+ dataset_name: str,
+ depth_config: Tuple[int, int], # (num_stages, blocks_per_stage)
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+ reg_type: str = "squared",
+ warmup_epochs: int = 0,
+ stable_init: bool = False,
+ lyap_threshold: float = 2.0,
+) -> List[TrainingMetrics]:
+ """Run training for a single configuration."""
+ torch.manual_seed(seed)
+
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ model = SpikingVGG(
+ in_channels=in_channels,
+ num_classes=num_classes,
+ base_channels=64,
+ num_stages=num_stages,
+ blocks_per_stage=blocks_per_stage,
+ T=T,
+ stable_init=stable_init,
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ method = "Lyapunov" if use_lyapunov else "Vanilla"
+ print(f" {method}: depth={total_depth}, params={num_params:,}")
+
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ criterion = nn.CrossEntropyLoss()
+
+ history = []
+ best_acc = 0.0
+
+ for epoch in range(1, epochs + 1):
+ t0 = time.time()
+
+ # Warmup: gradually increase lambda_reg
+ if warmup_epochs > 0 and epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, lambda_target, 1e-4, progress,
+ reg_type=reg_type, current_lambda_reg=current_lambda_reg,
+ lyap_threshold=lyap_threshold
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_acc = max(best_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == epochs:
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv else ""
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}")
+
+ if np.isnan(train_loss):
+ print(f" DIVERGED at epoch {epoch}")
+ break
+
+ print(f" Best test acc: {best_acc:.3f}")
+ return history
+
+
+def run_depth_scaling_experiment(
+ dataset_name: str,
+ depth_configs: List[Tuple[int, int]],
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool,
+ reg_type: str = "squared",
+ warmup_epochs: int = 0,
+ stable_init: bool = False,
+ lyap_threshold: float = 2.0,
+) -> Dict:
+ """Run full depth scaling experiment."""
+
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ print(f"Regularization type: {reg_type}")
+ print(f"Warmup epochs: {warmup_epochs}")
+ print(f"Stable init: {stable_init}")
+ print(f"Lyapunov threshold: {lyap_threshold}")
+
+ for depth_config in depth_configs:
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ print(f"\n{'='*60}")
+ print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+
+ history = run_single_config(
+ dataset_name=dataset_name,
+ depth_config=depth_config,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=T,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ device=device,
+ seed=seed,
+ progress=progress,
+ reg_type=reg_type,
+ warmup_epochs=warmup_epochs,
+ stable_init=stable_init,
+ lyap_threshold=lyap_threshold,
+ )
+
+ results[method][total_depth] = history
+
+ return results
+
+
+def print_summary(results: Dict, dataset_name: str):
+ """Print final summary table."""
+ print("\n" + "=" * 100)
+ print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}")
+ print("=" * 100)
+ print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}")
+ print("-" * 100)
+
+ depths = sorted(results["vanilla"].keys())
+
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.test_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.test_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}"
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+ lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov else "N/A"
+
+ van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A"
+ lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm else "N/A"
+ van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A"
+
+ print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}")
+
+ print("=" * 100)
+
+ # Gradient health analysis
+ print("\nGRADIENT HEALTH ANALYSIS:")
+ for depth in depths:
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_cond = van.grad_condition if van.grad_condition else 0
+ lyap_cond = lyap.grad_condition if lyap.grad_condition else 0
+
+ status = ""
+ if van_cond > 1e6:
+ status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)"
+ elif van_cond > 1e4:
+ status = "~ Vanilla has moderately ill-conditioned gradients"
+
+ if status:
+ print(f" Depth {depth}: {status}")
+
+ print("")
+
+ # Analysis
+ print("\nKEY OBSERVATIONS:")
+ shallow = min(depths)
+ deep = max(depths)
+
+ van_shallow = results["vanilla"][shallow][-1].test_acc
+ van_deep = results["vanilla"][deep][-1].test_acc
+ lyap_shallow = results["lyapunov"][shallow][-1].test_acc
+ lyap_deep = results["lyapunov"][deep][-1].test_acc
+
+ van_gain = van_deep - van_shallow
+ lyap_gain = lyap_deep - lyap_shallow
+
+ print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change")
+ print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change")
+
+ if lyap_gain > van_gain + 0.02:
+ print(f" ✓ Lyapunov regularization enables better depth scaling!")
+ elif lyap_gain > van_gain:
+ print(f" ~ Lyapunov shows slight improvement in depth scaling")
+ else:
+ print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range")
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, history in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in history]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs")
+
+ p.add_argument("--dataset", type=str, default="cifar100",
+ choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
+ p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16],
+ help="Total conv layer depths to test")
+ p.add_argument("--T", type=int, default=4, help="Timesteps")
+ p.add_argument("--epochs", type=int, default=100)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--lambda_reg", type=float, default=0.3)
+ p.add_argument("--lambda_target", type=float, default=-0.1)
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--out_dir", type=str, default="runs/depth_scaling")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--no-progress", action="store_true")
+ p.add_argument("--reg_type", type=str, default="squared",
+ choices=["squared", "hinge", "asymmetric", "extreme"],
+ help="Lyapunov regularization type")
+ p.add_argument("--warmup_epochs", type=int, default=0,
+ help="Epochs to warmup lambda_reg (0 = no warmup)")
+ p.add_argument("--stable_init", action="store_true",
+ help="Use stability-aware weight initialization")
+ p.add_argument("--lyap_threshold", type=float, default=2.0,
+ help="Threshold for extreme reg_type (only penalize λ > threshold)")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("DEPTH SCALING BENCHMARK")
+ print("=" * 80)
+ print(f"Dataset: {args.dataset}")
+ print(f"Depths: {args.depths}")
+ print(f"Timesteps: {args.T}")
+ print(f"Epochs: {args.epochs}")
+ print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
+ print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Load data
+ print(f"\nLoading {args.dataset}...")
+ train_loader, test_loader, num_classes, input_shape = get_dataset(
+ args.dataset, args.data_dir, args.batch_size
+ )
+ in_channels = input_shape[0]
+ print(f"Classes: {num_classes}, Input: {input_shape}")
+ print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+ # Convert depths to (num_stages, blocks_per_stage) configs
+ # We use 4 stages (3 for smaller nets), adjust blocks_per_stage
+ depth_configs = []
+ for d in args.depths:
+ if d <= 4:
+ depth_configs.append((d, 1)) # d stages, 1 block each
+ elif d <= 8:
+ depth_configs.append((4, d // 4)) # 4 stages
+ else:
+ depth_configs.append((4, d // 4)) # 4 stages, more blocks
+
+ print(f"\nDepth configurations: {[(d, f'{s}×{b}') for d, (s, b) in zip(args.depths, depth_configs)]}")
+
+ # Run experiment
+ results = run_depth_scaling_experiment(
+ dataset_name=args.dataset,
+ depth_configs=depth_configs,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=args.T,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ reg_type=args.reg_type,
+ warmup_epochs=args.warmup_epochs,
+ stable_init=args.stable_init,
+ lyap_threshold=args.lyap_threshold,
+ )
+
+ # Summary
+ print_summary(results, args.dataset)
+
+ # Save
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py
new file mode 100644
index 0000000..011387f
--- /dev/null
+++ b/files/experiments/hyperparameter_grid_search.py
@@ -0,0 +1,597 @@
+"""
+Hyperparameter Grid Search for Lyapunov-Regularized SNNs.
+
+Goal: Find optimal (lambda_reg, lambda_target) for each network depth
+ and derive an adaptive curve for automatic hyperparameter selection.
+
+Usage:
+ python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Tuple
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+
+
+@dataclass
+class GridSearchResult:
+ """Result from a single grid search configuration."""
+ depth: int
+ lambda_reg: float
+ lambda_target: float
+ final_train_acc: float
+ final_val_acc: float
+ final_lyapunov: float
+ final_grad_norm: float
+ converged: bool # Did training succeed (not NaN)?
+ epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never)
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+ batch_size: int = 128,
+) -> Tuple[DataLoader, DataLoader, int, int, int]:
+ """Create synthetic spike data."""
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+ for i in range(n_samples):
+ label = y[i].item()
+ base_rate = 0.05 + 0.02 * label
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
+ for t in range(T):
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
+ train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False)
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def train_and_evaluate(
+ depth: int,
+ lambda_reg: float,
+ lambda_target: float,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early
+) -> GridSearchResult:
+ """Train a single configuration and return results."""
+ torch.manual_seed(seed)
+
+ # Create model
+ hidden_dims = [hidden_dim] * depth
+ model = LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=0.9,
+ threshold=1.0,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ best_val_acc = 0.0
+ epochs_to_90 = -1
+ final_lyap = 0.0
+ final_grad = 0.0
+ converged = True
+
+ for epoch in range(1, epochs + 1):
+ # Warmup: gradually increase lambda_reg
+ if epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ # Training
+ model.train()
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False)
+
+ ce = ce_loss(logits, y)
+ if lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + current_lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ converged = False
+ break
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if not converged:
+ break
+
+ train_acc = total_correct / total_samples
+ final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0
+ final_grad = np.mean(grad_norms) if grad_norms else 0.0
+
+ # Track epochs to 90% accuracy
+ if epochs_to_90 < 0 and train_acc >= 0.9:
+ epochs_to_90 = epoch
+
+ # Validation
+ model.eval()
+ val_correct = 0
+ val_total = 0
+ with torch.no_grad():
+ for x, y in val_loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ preds = logits.argmax(dim=1)
+ val_correct += (preds == y).sum().item()
+ val_total += x.size(0)
+ val_acc = val_correct / val_total
+ best_val_acc = max(best_val_acc, val_acc)
+
+ return GridSearchResult(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ final_train_acc=train_acc if converged else 0.0,
+ final_val_acc=best_val_acc if converged else 0.0,
+ final_lyapunov=final_lyap,
+ final_grad_norm=final_grad,
+ converged=converged,
+ epochs_to_90pct=epochs_to_90,
+ )
+
+
+def run_grid_search(
+ depths: List[int],
+ lambda_regs: List[float],
+ lambda_targets: List[float],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ progress: bool = True,
+) -> List[GridSearchResult]:
+ """Run full grid search."""
+ results = []
+
+ # Total configurations
+ configs = list(product(depths, lambda_regs, lambda_targets))
+ total = len(configs)
+
+ iterator = tqdm(configs, desc="Grid Search", disable=not progress)
+
+ for depth, lambda_reg, lambda_target in iterator:
+ if progress:
+ iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target})
+
+ result = train_and_evaluate(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ device=device,
+ seed=seed,
+ )
+ results.append(result)
+
+ if progress:
+ iterator.set_postfix({
+ "d": depth,
+ "λr": lambda_reg,
+ "λt": lambda_target,
+ "acc": f"{result.final_val_acc:.2f}"
+ })
+
+ return results
+
+
+def analyze_results(results: List[GridSearchResult]) -> Dict:
+ """Analyze grid search results and find optimal hyperparameters per depth."""
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ analysis = {
+ "optimal_per_depth": {},
+ "all_results": [asdict(r) for r in results],
+ }
+
+ print("\n" + "=" * 80)
+ print("GRID SEARCH ANALYSIS")
+ print("=" * 80)
+
+ # Find optimal for each depth
+ print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}")
+ print("-" * 80)
+
+ optimal_lambda_regs = []
+ optimal_lambda_targets = []
+ depths_list = []
+
+ for depth in sorted(by_depth.keys()):
+ depth_results = by_depth[depth]
+ # Find best by validation accuracy
+ best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0)
+
+ analysis["optimal_per_depth"][depth] = {
+ "lambda_reg": best.lambda_reg,
+ "lambda_target": best.lambda_target,
+ "val_acc": best.final_val_acc,
+ "lyapunov": best.final_lyapunov,
+ "epochs_to_90": best.epochs_to_90pct,
+ }
+
+ print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} "
+ f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}")
+
+ if best.final_val_acc > 0.5: # Only use successful runs for curve fitting
+ depths_list.append(depth)
+ optimal_lambda_regs.append(best.lambda_reg)
+ optimal_lambda_targets.append(best.lambda_target)
+
+ # Fit adaptive curves
+ print("\n" + "=" * 80)
+ print("ADAPTIVE HYPERPARAMETER CURVES")
+ print("=" * 80)
+
+ if len(depths_list) >= 3:
+ # Fit polynomial curves
+ depths_arr = np.array(depths_list)
+ lambda_regs_arr = np.array(optimal_lambda_regs)
+ lambda_targets_arr = np.array(optimal_lambda_targets)
+
+ # Fit lambda_reg vs depth (expect increasing with depth)
+ try:
+ reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=min(2, len(depths_arr) - 1))
+ reg_poly = np.poly1d(reg_coeffs)
+ print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}"
+ if len(reg_coeffs) == 3 else f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}")
+ analysis["lambda_reg_curve"] = reg_coeffs.tolist()
+ except Exception as e:
+ print(f"Could not fit λ_reg curve: {e}")
+
+ # Fit lambda_target vs depth (expect decreasing / more negative with depth)
+ try:
+ target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=min(2, len(depths_arr) - 1))
+ target_poly = np.poly1d(target_coeffs)
+ print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}"
+ if len(target_coeffs) == 3 else f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}")
+ analysis["lambda_target_curve"] = target_coeffs.tolist()
+ except Exception as e:
+ print(f"Could not fit λ_target curve: {e}")
+
+ # Print recommendations
+ print("\n" + "-" * 80)
+ print("RECOMMENDED HYPERPARAMETERS BY DEPTH:")
+ print("-" * 80)
+ for d in [2, 4, 6, 8, 10, 12, 14, 16]:
+ rec_reg = max(0.01, reg_poly(d))
+ rec_target = min(0.0, target_poly(d))
+ print(f" Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}")
+
+ else:
+ print("Not enough successful runs to fit curves")
+
+ return analysis
+
+
+def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict):
+ """Save grid search results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f:
+ json.dump(analysis, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def plot_grid_search(results: List[GridSearchResult], output_dir: str):
+ """Generate visualization of grid search results."""
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError:
+ print("matplotlib not available, skipping plots")
+ return
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ depths = sorted(by_depth.keys())
+
+ # Get unique lambda values
+ lambda_regs = sorted(set(r.lambda_reg for r in results))
+ lambda_targets = sorted(set(r.lambda_target for r in results))
+
+ # Create heatmaps for each depth
+ n_depths = len(depths)
+ fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10))
+ axes = axes.flatten()
+
+ for idx, depth in enumerate(depths):
+ ax = axes[idx]
+ depth_results = by_depth[depth]
+
+ # Create accuracy matrix
+ acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs)))
+ for r in depth_results:
+ i = lambda_targets.index(r.lambda_target)
+ j = lambda_regs.index(r.lambda_reg)
+ acc_matrix[i, j] = r.final_val_acc
+
+ im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
+ ax.set_xticks(range(len(lambda_regs)))
+ ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45)
+ ax.set_yticks(range(len(lambda_targets)))
+ ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets])
+ ax.set_xlabel("λ_reg")
+ ax.set_ylabel("λ_target")
+ ax.set_title(f"Depth {depth}")
+
+ # Mark best
+ best = max(depth_results, key=lambda r: r.final_val_acc)
+ bi = lambda_targets.index(best.lambda_target)
+ bj = lambda_regs.index(best.lambda_reg)
+ ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2)
+
+ # Add colorbar
+ plt.colorbar(im, ax=ax, label='Val Acc')
+
+ # Hide unused subplots
+ for idx in range(len(depths), len(axes)):
+ axes[idx].axis('off')
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # Plot optimal hyperparameters vs depth
+ fig, axes = plt.subplots(1, 3, figsize=(15, 4))
+
+ optimal_regs = []
+ optimal_targets = []
+ optimal_accs = []
+ for depth in depths:
+ best = max(by_depth[depth], key=lambda r: r.final_val_acc)
+ optimal_regs.append(best.lambda_reg)
+ optimal_targets.append(best.lambda_target)
+ optimal_accs.append(best.final_val_acc)
+
+ axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8)
+ axes[0].set_xlabel("Network Depth")
+ axes[0].set_ylabel("Optimal λ_reg")
+ axes[0].set_title("Optimal Regularization Strength vs Depth")
+ axes[0].grid(True, alpha=0.3)
+
+ axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange')
+ axes[1].set_xlabel("Network Depth")
+ axes[1].set_ylabel("Optimal λ_target")
+ axes[1].set_title("Optimal Target Lyapunov vs Depth")
+ axes[1].grid(True, alpha=0.3)
+
+ axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green')
+ axes[2].set_xlabel("Network Depth")
+ axes[2].set_ylabel("Best Validation Accuracy")
+ axes[2].set_title("Best Achievable Accuracy vs Depth")
+ axes[2].set_ylim(0, 1.05)
+ axes[2].grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ print(f"Plots saved to {output_dir}")
+
+
+def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'):
+ """Get CIFAR-10 dataloaders with rate encoding for SNN."""
+ from torchvision import datasets, transforms
+
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
+ val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
+
+ # Rate encoding: convert images to spike sequences
+ class RateEncodedDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset, T):
+ self.dataset = dataset
+ self.T = T
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ img, label = self.dataset[idx]
+ # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D)
+ flat = img.view(-1) # (3072,)
+ # Rate encoding: probability of spike = pixel intensity
+ spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float()
+ return spikes, label
+
+ train_encoded = RateEncodedDataset(train_ds, T)
+ val_encoded = RateEncodedDataset(val_ds, T)
+
+ train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, T, 3072, 10 # T, D, num_classes
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN")
+
+ # Grid search parameters
+ p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10],
+ help="Network depths to test")
+ p.add_argument("--lambda_regs", type=float, nargs="+",
+ default=[0.01, 0.05, 0.1, 0.2, 0.3],
+ help="Lambda_reg values to test")
+ p.add_argument("--lambda_targets", type=float, nargs="+",
+ default=[0.0, -0.05, -0.1, -0.2],
+ help="Lambda_target values to test")
+
+ # Model parameters
+ p.add_argument("--hidden_dim", type=int, default=256)
+ p.add_argument("--epochs", type=int, default=15)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--seed", type=int, default=42)
+
+ # Data
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)")
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding")
+
+ # Output
+ p.add_argument("--out_dir", type=str, default="runs/grid_search")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("HYPERPARAMETER GRID SEARCH")
+ print("=" * 80)
+ print(f"Depths: {args.depths}")
+ print(f"λ_reg values: {args.lambda_regs}")
+ print(f"λ_target values: {args.lambda_targets}")
+ print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}")
+ print(f"Epochs per config: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing synthetic data")
+ train_loader, val_loader, T, D, C = create_synthetic_data(
+ seed=args.seed, batch_size=args.batch_size
+ )
+ else:
+ print("\nUsing CIFAR-10 with rate encoding")
+ train_loader, val_loader, T, D, C = get_cifar10_loaders(
+ batch_size=args.batch_size,
+ T=args.T,
+ data_dir=args.data_dir
+ )
+
+ print(f"Data: T={T}, D={D}, classes={C}\n")
+
+ # Run grid search
+ results = run_grid_search(
+ depths=args.depths,
+ lambda_regs=args.lambda_regs,
+ lambda_targets=args.lambda_targets,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Analyze results
+ analysis = analyze_results(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, analysis, output_dir, vars(args))
+
+ # Generate plots
+ plot_grid_search(results, output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py
new file mode 100644
index 0000000..05dbcd2
--- /dev/null
+++ b/files/experiments/lyapunov_diffonly_benchmark.py
@@ -0,0 +1,590 @@
+"""
+Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation.
+
+Optimization B: Instead of storing two full membrane trajectories:
+ mems[i][0] = base trajectory
+ mems[i][1] = perturbed trajectory
+
+Store only:
+ base_mems[i] = base trajectory
+ delta_mems[i] = perturbation (perturbed - base)
+
+Benefits:
+ - ~2x less memory for membrane states
+ - Fewer memory reads/writes during renormalization
+ - Better cache utilization
+"""
+
+import os
+import sys
+import time
+import torch
+import torch.nn as nn
+from typing import Tuple, Optional, List
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+class SpikingVGGBlock(nn.Module):
+ """Conv-BN-LIF block."""
+
+ def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+ super().__init__()
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+ def forward(self, x, mem):
+ h = self.bn(self.conv(x))
+ spk, mem = self.lif(h, mem)
+ return spk, mem
+
+
+class SpikingVGG_Original(nn.Module):
+ """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W)."""
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ # Build stages
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32 # CIFAR
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype, P=1):
+ mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ P = 2 if compute_lyapunov else 1
+
+ mems = self._init_mems(B, device, dtype, P=P)
+
+ if compute_lyapunov:
+ for i in range(len(mems)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_mems = []
+ is_first_block = True
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ h_conv = block.bn(block.conv(x))
+ h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ is_first_block = False
+ else:
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ h_conv = block.bn(block.conv(h_flat))
+ spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ mem_idx += 1
+
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ h_pooled = pool(h_flat)
+ h = h_pooled.view(P, B, *h_pooled.shape[1:])
+
+ mems = new_mems
+
+ h_orig = h[0].view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_orig
+ else:
+ spike_sum = spike_sum + h_orig
+
+ if compute_lyapunov:
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ scale = (lyap_eps / delta).view(B, 1, 1, 1)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ mems[i] = torch.stack([
+ new_mems[i][0],
+ new_mems[i][0] + diff * scale
+ ], dim=0)
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+class SpikingVGG_DiffOnly(nn.Module):
+ """
+ Optimized implementation: stores base + diff instead of 2 full trajectories.
+
+ Memory layout:
+ base_mems[i]: (B, C, H, W) - base trajectory membrane
+ delta_mems[i]: (B, C, H, W) - perturbation vector
+
+ Perturbed trajectory is materialized as (base + delta) only when needed.
+ """
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype):
+ """Initialize base membrane states (B, C, H, W)."""
+ base_mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return base_mems
+
+ def _init_deltas(self, base_mems, lyap_eps):
+ """Initialize perturbation vectors δ with ||δ||_global = eps."""
+ delta_mems = []
+ for base in base_mems:
+ delta_mems.append(lyap_eps * torch.randn_like(base))
+ return delta_mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+
+ # Initialize base membrane states
+ base_mems = self._init_mems(B, device, dtype)
+
+ # Initialize perturbations if computing Lyapunov
+ if compute_lyapunov:
+ delta_mems = self._init_deltas(base_mems, lyap_eps)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+ else:
+ delta_mems = None
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_base_mems = []
+ new_delta_mems = [] if compute_lyapunov else None
+
+ # Track spikes for base and perturbed (if computing Lyapunov)
+ h_base = None
+ h_delta = None # Will store (h_perturbed - h_base)
+ is_first_block = True
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ # First block: input x is same for both trajectories
+ h_conv = block.bn(block.conv(x)) # (B, C, H, W)
+
+ # Base trajectory
+ spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+ h_base = spk_base
+
+ if compute_lyapunov:
+ # Perturbed trajectory: mem = base + delta
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed)
+ # Store delta for new membrane
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ # Store spike difference for propagation
+ h_delta = spk_perturbed - spk_base
+
+ is_first_block = False
+ else:
+ # Subsequent blocks: inputs differ
+ # Base trajectory
+ h_conv_base = block.bn(block.conv(h_base))
+ spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+
+ if compute_lyapunov:
+ # Perturbed trajectory: h_perturbed = h_base + h_delta
+ h_perturbed = h_base + h_delta
+ h_conv_perturbed = block.bn(block.conv(h_perturbed))
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed)
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ h_delta = spk_perturbed - spk_base
+
+ h_base = spk_base
+
+ mem_idx += 1
+
+ # Pooling
+ h_base = pool(h_base)
+ if compute_lyapunov:
+ # Pool both and compute new delta
+ h_perturbed = h_base + pool(h_delta) # Note: pool(base+delta) ≠ pool(base) + pool(delta) in general
+ # But for AvgPool, it's linear so this is fine
+ h_delta = h_perturbed - h_base # This simplifies to pool(h_delta) for AvgPool
+ h_delta = pool(h_delta) # Actually just pool the delta directly (AvgPool is linear)
+
+ # Update membrane states
+ base_mems = new_base_mems
+
+ # Accumulate spikes from base trajectory
+ h_flat = h_base.view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_flat
+ else:
+ spike_sum = spike_sum + h_flat
+
+ # Lyapunov: compute global divergence and renormalize
+ if compute_lyapunov:
+ # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||²
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for delta in new_delta_mems:
+ delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3))
+
+ delta_norm = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)
+
+ # Renormalize: scale all deltas so ||δ||_global = eps
+ scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
+ delta_mems = [delta * scale for delta in new_delta_mems]
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20):
+ """Benchmark forward pass time."""
+ device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times, lyap
+
+
+def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov,
+ lambda_reg=0.3, num_warmup=5, num_runs=20):
+ """Benchmark forward + backward pass time."""
+ device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ model.zero_grad()
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+ torch.cuda.synchronize()
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+ model.zero_grad()
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+ torch.cuda.synchronize()
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times
+
+
+def measure_memory(model, x, compute_lyapunov):
+ """Measure peak GPU memory during forward pass."""
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.synchronize()
+
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+ peak_mem = torch.cuda.max_memory_allocated() / 1024**2 # MB
+ return peak_mem
+
+
+def run_benchmark():
+ print("=" * 70)
+ print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage")
+ print("=" * 70)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+
+ # Test configurations
+ configs = [
+ {"depth": 4, "blocks_per_stage": 1, "batch_size": 64},
+ {"depth": 8, "blocks_per_stage": 2, "batch_size": 64},
+ {"depth": 12, "blocks_per_stage": 3, "batch_size": 32},
+ ]
+
+ print("\n" + "=" * 70)
+
+ for cfg in configs:
+ depth = cfg["depth"]
+ blocks = cfg["blocks_per_stage"]
+ batch_size = cfg["batch_size"]
+
+ print(f"\n{'='*70}")
+ print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}")
+ print(f"{'='*70}")
+
+ # Create models
+ model_orig = SpikingVGG_Original(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ model_diff = SpikingVGG_DiffOnly(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ # Copy weights from original to diff-only
+ model_diff.load_state_dict(model_orig.state_dict())
+
+ print(f"Parameters: {count_parameters(model_orig):,}")
+
+ # Create input
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 100, (batch_size,), device=device)
+ criterion = nn.CrossEntropyLoss()
+
+ # ============================================================
+ # Test 1: Verify outputs match
+ # ============================================================
+ print("\n--- Output Verification ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5)
+ lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1 # Allow some difference due to different implementations
+
+ print(f"Logits match: {logits_match}")
+ print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}")
+ print(f"Lyapunov close (within 0.1): {lyap_close}")
+
+ # ============================================================
+ # Test 2: Forward-only speed (no grad)
+ # ============================================================
+ print("\n--- Forward Speed (no_grad) ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ # Without Lyapunov
+ times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False)
+ times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False)
+
+ mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000
+ mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mean_orig:.2f} ms")
+ print(f" Diff-only: {mean_diff:.2f} ms")
+
+ # With Lyapunov
+ times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True)
+ times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True)
+
+ mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000
+ mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000
+ speedup = mean_orig_ly / mean_diff_ly
+
+ print(f" With Lyapunov:")
+ print(f" Original: {mean_orig_ly:.2f} ms")
+ print(f" Diff-only: {mean_diff_ly:.2f} ms")
+ print(f" Speedup: {speedup:.2f}x")
+
+ # ============================================================
+ # Test 3: Forward + Backward speed (training mode)
+ # ============================================================
+ print("\n--- Forward+Backward Speed (training) ---")
+ model_orig.train()
+ model_diff.train()
+
+ times_orig_train = benchmark_forward_backward(
+ model_orig, x, y, criterion, compute_lyapunov=True
+ )
+ times_diff_train = benchmark_forward_backward(
+ model_diff, x, y, criterion, compute_lyapunov=True
+ )
+
+ mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000
+ mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000
+ speedup_train = mean_orig_train / mean_diff_train
+
+ print(f" With Lyapunov + backward:")
+ print(f" Original: {mean_orig_train:.2f} ms")
+ print(f" Diff-only: {mean_diff_train:.2f} ms")
+ print(f" Speedup: {speedup_train:.2f}x")
+
+ # ============================================================
+ # Test 4: Memory usage
+ # ============================================================
+ if device.type == "cuda":
+ print("\n--- Peak GPU Memory ---")
+
+ mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False)
+ mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False)
+
+ mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True)
+ mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True)
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mem_orig_noly:.1f} MB")
+ print(f" Diff-only: {mem_diff_noly:.1f} MB")
+ print(f" With Lyapunov:")
+ print(f" Original: {mem_orig_ly:.1f} MB")
+ print(f" Diff-only: {mem_diff_ly:.1f} MB")
+ print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)")
+
+ # Cleanup
+ del model_orig, model_diff, x, y
+ torch.cuda.empty_cache()
+
+ print("\n" + "=" * 70)
+ print("BENCHMARK COMPLETE")
+ print("=" * 70)
+
+
+if __name__ == "__main__":
+ run_benchmark()
diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py
new file mode 100644
index 0000000..117009b
--- /dev/null
+++ b/files/experiments/lyapunov_speedup_benchmark.py
@@ -0,0 +1,638 @@
+"""
+Lyapunov Computation Speedup Benchmark
+
+Tests different optimization approaches for computing Lyapunov exponents
+during SNN training. All approaches should produce equivalent results
+(within numerical precision) but with different performance characteristics.
+
+Approaches tested:
+- Baseline: Current sequential implementation
+- Approach A: Trajectory-as-batch (P=2), share first Linear
+- Approach B: Global-norm divergence + single-scale renorm
+- Approach C: torch.compile the time loop
+- Combined: A + B + C together
+"""
+
+import os
+import sys
+import time
+from typing import Tuple, Optional, List
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+# Ensure we can import from project
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+
+# =============================================================================
+# Baseline Implementation (Current)
+# =============================================================================
+
+class BaselineSNN(nn.Module):
+ """Current implementation: sequential perturbed trajectory."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ # Simple feedforward for benchmarking (not full VGG)
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims # Flattened input
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1) # Flatten
+
+ # Init membrane potentials
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory (SEPARATE PASS - this is slow)
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # Divergence (per-layer norms, then sum)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize (per-layer - SLOW)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A: Trajectory-as-batch (P=2), share first Linear
+# =============================================================================
+
+class ApproachA_SNN(nn.Module):
+ """Batch both trajectories together, share Linear_1."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State layout: (P, B, H) where P=2 for [original, perturbed]
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ # Initialize perturbed state
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: compute Linear ONCE, expand to (P, B, H1)
+ h1 = self.linears[0](x) # (B, H1) - computed ONCE
+
+ if compute_lyapunov:
+ h = h1.unsqueeze(0).expand(P, -1, -1) # (P, B, H1) - zero-copy view
+ else:
+ h = h1.unsqueeze(0) # (1, B, H1)
+
+ # LIF layer 1
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+: different inputs for each trajectory
+ for i in range(1, len(self.hidden_dims)):
+ # Reshape to (P*B, H) for batched Linear
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ # Accumulate spikes from original trajectory only
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence across all layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0] # (B, H_i)
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # Renormalize with global scale (per-layer still, but simpler)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ mems[i][1] = mems[i][0] + lyap_eps * diff / norm
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach B: Global-norm divergence + single-scale renorm
+# =============================================================================
+
+class ApproachB_SNN(nn.Module):
+ """Global norm for divergence, single scale factor for renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # GLOBAL divergence (one delta per batch element)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # SINGLE SCALE renormalization (key optimization)
+ scale = (lyap_eps / delta).unsqueeze(-1) # (B, 1)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ new_mems_p[i] = new_mems[i] + diff * scale
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A+B Combined: Batched trajectories + global renorm
+# =============================================================================
+
+class ApproachAB_SNN(nn.Module):
+ """Combined: trajectory-as-batch + global-norm renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State: (P, B, H)
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: Linear computed ONCE
+ h1 = self.linears[0](x)
+ h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)
+
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+
+ for i in range(1, len(self.hidden_dims)):
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # Global scale renorm
+ scale = (lyap_eps / delta).unsqueeze(-1)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ mems[i][1] = mems[i][0] + diff * scale
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach C: torch.compile wrapper
+# =============================================================================
+
+def make_compiled_model(model_class, *args, **kwargs):
+ """Create a model and compile its forward pass."""
+ model = model_class(*args, **kwargs)
+ # Compile the forward method
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ return model
+
+
+# =============================================================================
+# Benchmarking
+# =============================================================================
+
+@dataclass
+class BenchmarkResult:
+ name: str
+ forward_time_ms: float
+ backward_time_ms: float
+ total_time_ms: float
+ lyap_value: float
+ memory_mb: float
+
+ def __str__(self):
+ return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
+ f"Bwd: {self.backward_time_ms:7.2f}ms | "
+ f"Total: {self.total_time_ms:7.2f}ms | "
+ f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")
+
+
+def benchmark_model(
+ model: nn.Module,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ name: str,
+ warmup_iters: int = 5,
+ bench_iters: int = 20,
+) -> BenchmarkResult:
+ """Benchmark a single model configuration."""
+
+ device = x.device
+ criterion = nn.CrossEntropyLoss()
+
+ # Warmup
+ for _ in range(warmup_iters):
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+ loss.backward()
+ model.zero_grad()
+
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ fwd_times = []
+ bwd_times = []
+ lyap_vals = []
+
+ for _ in range(bench_iters):
+ # Forward
+ torch.cuda.synchronize()
+ t0 = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+
+ torch.cuda.synchronize()
+ t1 = time.perf_counter()
+
+ # Backward
+ loss.backward()
+
+ torch.cuda.synchronize()
+ t2 = time.perf_counter()
+
+ fwd_times.append((t1 - t0) * 1000)
+ bwd_times.append((t2 - t1) * 1000)
+ if lyap is not None:
+ lyap_vals.append(lyap.item())
+
+ model.zero_grad()
+
+ peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
+
+ return BenchmarkResult(
+ name=name,
+ forward_time_ms=sum(fwd_times) / len(fwd_times),
+ backward_time_ms=sum(bwd_times) / len(bwd_times),
+ total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times),
+ lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
+ memory_mb=peak_mem,
+ )
+
+
+def run_benchmarks(
+ batch_size: int = 64,
+ T: int = 4,
+ hidden_dims: List[int] = [64, 128, 256],
+ device: str = "cuda",
+):
+ """Run all benchmarks and compare."""
+
+ print("=" * 80)
+ print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK")
+ print("=" * 80)
+ print(f"Batch size: {batch_size}")
+ print(f"Timesteps: {T}")
+ print(f"Hidden dims: {hidden_dims}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Create dummy data
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (batch_size,), device=device)
+
+ results = []
+
+ # 1. Baseline
+ print("\n[1/6] Benchmarking Baseline...")
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "Baseline"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 2. Approach A (batched trajectories)
+ print("[2/6] Benchmarking Approach A (batched)...")
+ model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A: Batched trajectories"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 3. Approach B (global renorm)
+ print("[3/6] Benchmarking Approach B (global renorm)...")
+ model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "B: Global renorm"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 4. Approach A+B combined
+ print("[4/6] Benchmarking Approach A+B (combined)...")
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A+B: Combined"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 5. Approach C (torch.compile on baseline)
+ print("[5/6] Benchmarking Approach C (compiled baseline)...")
+ try:
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0))
+
+ # 6. A+B+C (all combined)
+ print("[6/6] Benchmarking A+B+C (all optimizations)...")
+ try:
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0))
+
+ # Print results
+ print("\n" + "=" * 80)
+ print("RESULTS")
+ print("=" * 80)
+
+ baseline_time = results[0].total_time_ms
+
+ for r in results:
+ print(r)
+
+ print("\n" + "-" * 80)
+ print("SPEEDUP vs BASELINE:")
+ print("-" * 80)
+
+ for r in results[1:]:
+ if r.total_time_ms > 0:
+ speedup = baseline_time / r.total_time_ms
+ print(f" {r.name:<25}: {speedup:.2f}x")
+
+ # Verify Lyapunov values are consistent
+ print("\n" + "-" * 80)
+ print("LYAPUNOV VALUE CONSISTENCY CHECK:")
+ print("-" * 80)
+
+ base_lyap = results[0].lyap_value
+ for r in results[1:]:
+ if r.lyap_value != 0:
+ diff = abs(r.lyap_value - base_lyap)
+ status = "✓" if diff < 0.1 else "✗"
+ print(f" {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}")
+
+ return results
+
+
+def run_scaling_test(device: str = "cuda"):
+ """Test how approaches scale with batch size and timesteps."""
+
+ print("\n" + "=" * 80)
+ print("SCALING TESTS")
+ print("=" * 80)
+
+ configs = [
+ {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]}, # Larger model
+ ]
+
+ print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
+ print("-" * 80)
+
+ for cfg in configs:
+ x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (cfg["batch_size"],), device=device)
+
+ # Baseline
+ model_base = BaselineSNN(**cfg).to(device)
+ r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
+ del model_base
+
+ # A+B
+ model_ab = ApproachAB_SNN(**cfg).to(device)
+ r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
+ del model_ab
+
+ torch.cuda.empty_cache()
+
+ speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0
+
+ cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
+ print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--T", type=int, default=4)
+ parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256])
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--scaling", action="store_true", help="Run scaling tests")
+ args = parser.parse_args()
+
+ if not torch.cuda.is_available():
+ print("CUDA not available, using CPU (results will not be representative)")
+ args.device = "cpu"
+
+ # Main benchmark
+ results = run_benchmarks(
+ batch_size=args.batch_size,
+ T=args.T,
+ hidden_dims=args.hidden_dims,
+ device=args.device,
+ )
+
+ # Scaling tests
+ if args.scaling:
+ run_scaling_test(args.device)
diff --git a/files/experiments/plot_depth_comparison.py b/files/experiments/plot_depth_comparison.py
new file mode 100644
index 0000000..2222b7b
--- /dev/null
+++ b/files/experiments/plot_depth_comparison.py
@@ -0,0 +1,305 @@
+"""
+Visualization for depth comparison experiments.
+
+Usage:
+ python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP
+"""
+
+import os
+import sys
+import json
+import argparse
+from typing import Dict, List
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+
+def load_results(results_dir: str) -> Dict:
+ """Load results from JSON file."""
+ with open(os.path.join(results_dir, "results.json"), "r") as f:
+ return json.load(f)
+
+
+def load_config(results_dir: str) -> Dict:
+ """Load config from JSON file."""
+ config_path = os.path.join(results_dir, "config.json")
+ if os.path.exists(config_path):
+ with open(config_path, "r") as f:
+ return json.load(f)
+ return {}
+
+
+def plot_training_curves(results: Dict, output_path: str):
+ """
+ Plot training curves for each depth.
+
+ Creates a figure with subplots for each depth showing:
+ - Training loss
+ - Validation accuracy
+ - Lyapunov exponent (if available)
+ - Gradient norm
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+ n_depths = len(depths)
+
+ fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths))
+ if n_depths == 1:
+ axes = axes.reshape(1, -1)
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+ labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"}
+
+ for i, depth in enumerate(depths):
+ for method in ["vanilla", "lyapunov"]:
+ metrics = results[method][str(depth)]
+ epochs = [m["epoch"] for m in metrics]
+
+ # Training Loss
+ train_loss = [m["train_loss"] for m in metrics]
+ axes[i, 0].plot(epochs, train_loss, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 0].set_ylabel("Train Loss")
+ axes[i, 0].set_title(f"Depth={depth}: Training Loss")
+ axes[i, 0].set_yscale("log")
+ axes[i, 0].grid(True, alpha=0.3)
+
+ # Validation Accuracy
+ val_acc = [m["val_acc"] for m in metrics]
+ axes[i, 1].plot(epochs, val_acc, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 1].set_ylabel("Val Accuracy")
+ axes[i, 1].set_title(f"Depth={depth}: Validation Accuracy")
+ axes[i, 1].set_ylim(0, 1)
+ axes[i, 1].grid(True, alpha=0.3)
+
+ # Lyapunov Exponent
+ lyap = [m["lyapunov"] for m in metrics if m["lyapunov"] is not None]
+ lyap_epochs = [m["epoch"] for m in metrics if m["lyapunov"] is not None]
+ if lyap:
+ axes[i, 2].plot(lyap_epochs, lyap, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ axes[i, 2].set_ylabel("Lyapunov λ")
+ axes[i, 2].set_title(f"Depth={depth}: Lyapunov Exponent")
+ axes[i, 2].grid(True, alpha=0.3)
+
+ # Gradient Norm
+ grad_norm = [m["grad_norm"] for m in metrics]
+ axes[i, 3].plot(epochs, grad_norm, color=colors[method],
+ label=labels[method], linewidth=2)
+ axes[i, 3].set_ylabel("Gradient Norm")
+ axes[i, 3].set_title(f"Depth={depth}: Gradient Norm")
+ axes[i, 3].set_yscale("log")
+ axes[i, 3].grid(True, alpha=0.3)
+
+ # Add legend to first row
+ if i == 0:
+ for ax in axes[i]:
+ ax.legend(loc="upper right")
+
+ # Set x-labels on bottom row
+ for ax in axes[-1]:
+ ax.set_xlabel("Epoch")
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved training curves to {output_path}")
+
+
+def plot_depth_summary(results: Dict, output_path: str):
+ """
+ Plot summary comparing methods across depths.
+
+ Creates a figure showing:
+ - Final validation accuracy vs depth
+ - Final gradient norm vs depth
+ - Final Lyapunov exponent vs depth
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+
+ fig, axes = plt.subplots(1, 3, figsize=(14, 4))
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+ markers = {"vanilla": "o", "lyapunov": "s"}
+
+ # Collect final metrics
+ van_acc = []
+ lyap_acc = []
+ van_grad = []
+ lyap_grad = []
+ lyap_lambda = []
+
+ for depth in depths:
+ van_metrics = results["vanilla"][str(depth)][-1]
+ lyap_metrics = results["lyapunov"][str(depth)][-1]
+
+ van_acc.append(van_metrics["val_acc"] if not np.isnan(van_metrics["val_acc"]) else 0)
+ lyap_acc.append(lyap_metrics["val_acc"] if not np.isnan(lyap_metrics["val_acc"]) else 0)
+
+ van_grad.append(van_metrics["grad_norm"] if not np.isnan(van_metrics["grad_norm"]) else 0)
+ lyap_grad.append(lyap_metrics["grad_norm"] if not np.isnan(lyap_metrics["grad_norm"]) else 0)
+
+ if lyap_metrics["lyapunov"] is not None:
+ lyap_lambda.append(lyap_metrics["lyapunov"])
+ else:
+ lyap_lambda.append(0)
+
+ # Plot 1: Validation Accuracy vs Depth
+ ax = axes[0]
+ ax.plot(depths, van_acc, 'o-', color=colors["vanilla"],
+ label="Vanilla", linewidth=2, markersize=8)
+ ax.plot(depths, lyap_acc, 's-', color=colors["lyapunov"],
+ label="Lyapunov", linewidth=2, markersize=8)
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Validation Accuracy")
+ ax.set_title("Accuracy vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05)
+
+ # Plot 2: Gradient Norm vs Depth
+ ax = axes[1]
+ ax.plot(depths, van_grad, 'o-', color=colors["vanilla"],
+ label="Vanilla", linewidth=2, markersize=8)
+ ax.plot(depths, lyap_grad, 's-', color=colors["lyapunov"],
+ label="Lyapunov", linewidth=2, markersize=8)
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Gradient Norm")
+ ax.set_title("Gradient Stability vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_yscale("log")
+
+ # Plot 3: Lyapunov Exponent vs Depth
+ ax = axes[2]
+ ax.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"],
+ linewidth=2, markersize=8)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)")
+ ax.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region")
+ ax.set_xlabel("Network Depth (# layers)")
+ ax.set_ylabel("Final Lyapunov Exponent")
+ ax.set_title("Lyapunov Exponent vs Depth")
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved depth summary to {output_path}")
+
+
+def plot_stability_comparison(results: Dict, output_path: str):
+ """
+ Plot stability metrics comparison.
+ """
+ depths = sorted([int(d) for d in results["vanilla"].keys()])
+
+ fig, axes = plt.subplots(2, 2, figsize=(12, 10))
+
+ colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
+
+ # Collect metrics over training
+ for depth in depths:
+ van_metrics = results["vanilla"][str(depth)]
+ lyap_metrics = results["lyapunov"][str(depth)]
+
+ van_epochs = [m["epoch"] for m in van_metrics]
+ lyap_epochs = [m["epoch"] for m in lyap_metrics]
+
+ # Firing rate
+ van_fr = [m["firing_rate"] for m in van_metrics]
+ lyap_fr = [m["firing_rate"] for m in lyap_metrics]
+ axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+ axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+
+ # Dead neurons
+ van_dead = [m["dead_neurons"] for m in van_metrics]
+ lyap_dead = [m["dead_neurons"] for m in lyap_metrics]
+ axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+ axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"],
+ alpha=0.3 + 0.1 * depths.index(depth))
+
+ axes[0, 0].set_xlabel("Epoch")
+ axes[0, 0].set_ylabel("Firing Rate")
+ axes[0, 0].set_title("Firing Rate Over Training")
+ axes[0, 0].grid(True, alpha=0.3)
+
+ axes[0, 1].set_xlabel("Epoch")
+ axes[0, 1].set_ylabel("Dead Neuron Fraction")
+ axes[0, 1].set_title("Dead Neurons Over Training")
+ axes[0, 1].grid(True, alpha=0.3)
+
+ # Final metrics bar chart
+ van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths]
+ lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths]
+
+ x = np.arange(len(depths))
+ width = 0.35
+
+ axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"])
+ axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"])
+ axes[1, 0].set_xlabel("Network Depth")
+ axes[1, 0].set_ylabel("Final Validation Accuracy")
+ axes[1, 0].set_title("Final Accuracy Comparison")
+ axes[1, 0].set_xticks(x)
+ axes[1, 0].set_xticklabels(depths)
+ axes[1, 0].legend()
+ axes[1, 0].grid(True, alpha=0.3, axis='y')
+
+ # Improvement percentage
+ improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)]
+ colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]
+
+ axes[1, 1].bar(x, improvements, color=colors_bar)
+ axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5)
+ axes[1, 1].set_xlabel("Network Depth")
+ axes[1, 1].set_ylabel("Accuracy Improvement")
+ axes[1, 1].set_title("Lyapunov Improvement over Vanilla")
+ axes[1, 1].set_xticks(x)
+ axes[1, 1].set_xticklabels(depths)
+ axes[1, 1].grid(True, alpha=0.3, axis='y')
+
+ # Add legend for line plots
+ custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2),
+ Line2D([0], [0], color=colors["lyapunov"], lw=2)]
+ axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov'])
+ axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov'])
+
+ plt.tight_layout()
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ plt.close()
+ print(f"Saved stability comparison to {output_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--results_dir", type=str, required=True,
+ help="Directory containing results.json")
+ parser.add_argument("--output_dir", type=str, default=None,
+ help="Output directory for plots (default: same as results_dir)")
+ args = parser.parse_args()
+
+ output_dir = args.output_dir or args.results_dir
+
+ print(f"Loading results from {args.results_dir}")
+ results = load_results(args.results_dir)
+ config = load_config(args.results_dir)
+
+ print(f"Config: {config}")
+
+ # Generate plots
+ plot_training_curves(results, os.path.join(output_dir, "training_curves.png"))
+ plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png"))
+ plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png"))
+
+ print(f"\nAll plots saved to {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/posthoc_finetune.py b/files/experiments/posthoc_finetune.py
new file mode 100644
index 0000000..3f3bf6c
--- /dev/null
+++ b/files/experiments/posthoc_finetune.py
@@ -0,0 +1,323 @@
+"""
+Post-hoc Lyapunov Fine-tuning Experiment
+
+Strategy:
+1. Train network with vanilla (no Lyapunov) for N epochs
+2. Then fine-tune with Lyapunov regularization for M epochs
+
+This allows the network to learn task-relevant features first,
+then stabilize dynamics without starting from chaotic initialization.
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.experiments.depth_scaling_benchmark import (
+ SpikingVGG,
+ get_dataset,
+ train_epoch,
+ evaluate,
+ TrainingMetrics,
+ compute_lyap_reg_loss,
+)
+
+
+def run_posthoc_experiment(
+ dataset_name: str,
+ depth_config: Tuple[int, int],
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ num_classes: int,
+ in_channels: int,
+ T: int,
+ pretrain_epochs: int,
+ finetune_epochs: int,
+ lr: float,
+ finetune_lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ reg_type: str = "extreme",
+ lyap_threshold: float = 2.0,
+ progress: bool = True,
+) -> Dict:
+ """Run post-hoc fine-tuning experiment."""
+ torch.manual_seed(seed)
+
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ print(f"\n{'='*60}")
+ print(f"POST-HOC FINE-TUNING: Depth = {total_depth}")
+ print(f"Pretrain: {pretrain_epochs} epochs (vanilla)")
+ print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})")
+ print(f"{'='*60}")
+
+ model = SpikingVGG(
+ in_channels=in_channels,
+ num_classes=num_classes,
+ base_channels=64,
+ num_stages=num_stages,
+ blocks_per_stage=blocks_per_stage,
+ T=T,
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Parameters: {num_params:,}")
+
+ criterion = nn.CrossEntropyLoss()
+
+ # Phase 1: Vanilla pre-training
+ print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---")
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)
+
+ pretrain_history = []
+ best_pretrain_acc = 0.0
+
+ for epoch in range(1, pretrain_epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov=False, # No Lyapunov during pre-training
+ lambda_reg=0, lambda_target=0, lyap_eps=1e-4,
+ progress=progress,
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_pretrain_acc = max(best_pretrain_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ pretrain_history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == pretrain_epochs:
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}")
+
+ print(f" Best pretrain acc: {best_pretrain_acc:.3f}")
+
+ # Phase 2: Lyapunov fine-tuning
+ print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---")
+ print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}")
+
+ # Reset optimizer with lower learning rate for fine-tuning
+ optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)
+
+ finetune_history = []
+ best_finetune_acc = 0.0
+
+ for epoch in range(1, finetune_epochs + 1):
+ t0 = time.time()
+
+ # Warmup lambda_reg over first 10 epochs of fine-tuning
+ warmup_epochs = 10
+ if epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov=True,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=1e-4,
+ progress=progress,
+ reg_type=reg_type,
+ current_lambda_reg=current_lambda_reg,
+ lyap_threshold=lyap_threshold,
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_finetune_acc = max(best_finetune_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=pretrain_epochs + epoch, # Continue epoch numbering
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ finetune_history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == finetune_epochs:
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ print(f" Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")
+
+ if np.isnan(train_loss):
+ print(f" DIVERGED at epoch {epoch}")
+ break
+
+ print(f" Best finetune acc: {best_finetune_acc:.3f}")
+ print(f" Final λ: {finetune_history[-1].lyapunov:.3f}" if finetune_history[-1].lyapunov else "")
+
+ return {
+ "depth": total_depth,
+ "pretrain_history": pretrain_history,
+ "finetune_history": finetune_history,
+ "best_pretrain_acc": best_pretrain_acc,
+ "best_finetune_acc": best_finetune_acc,
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning")
+ parser.add_argument("--dataset", type=str, default="cifar100",
+ choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
+ parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16])
+ parser.add_argument("--T", type=int, default=4)
+ parser.add_argument("--pretrain_epochs", type=int, default=100)
+ parser.add_argument("--finetune_epochs", type=int, default=50)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--lr", type=float, default=1e-3)
+ parser.add_argument("--finetune_lr", type=float, default=1e-4)
+ parser.add_argument("--lambda_reg", type=float, default=0.1)
+ parser.add_argument("--lambda_target", type=float, default=-0.1)
+ parser.add_argument("--reg_type", type=str, default="extreme")
+ parser.add_argument("--lyap_threshold", type=float, default=2.0)
+ parser.add_argument("--data_dir", type=str, default="./data")
+ parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune")
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--no-progress", action="store_true")
+
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT")
+ print("=" * 80)
+ print(f"Dataset: {args.dataset}")
+ print(f"Depths: {args.depths}")
+ print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})")
+ print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})")
+ print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}")
+ print("=" * 80)
+
+ # Load data
+ train_loader, test_loader, num_classes, input_shape = get_dataset(
+ args.dataset, args.data_dir, args.batch_size
+ )
+ in_channels = input_shape[0]
+
+ # Convert depths to configs
+ depth_configs = []
+ for d in args.depths:
+ if d <= 4:
+ depth_configs.append((d, 1))
+ else:
+ depth_configs.append((4, d // 4))
+
+ # Run experiments
+ all_results = []
+ for depth_config in depth_configs:
+ result = run_posthoc_experiment(
+ dataset_name=args.dataset,
+ depth_config=depth_config,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ num_classes=num_classes,
+ in_channels=in_channels,
+ T=args.T,
+ pretrain_epochs=args.pretrain_epochs,
+ finetune_epochs=args.finetune_epochs,
+ lr=args.lr,
+ finetune_lr=args.finetune_lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ reg_type=args.reg_type,
+ lyap_threshold=args.lyap_threshold,
+ progress=not args.no_progress,
+ )
+ all_results.append(result)
+
+ # Summary
+ print("\n" + "=" * 80)
+ print("SUMMARY")
+ print("=" * 80)
+ print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}")
+ print("-" * 80)
+
+ for r in all_results:
+ pre_acc = r["best_pretrain_acc"]
+ fine_acc = r["best_finetune_acc"]
+ change = fine_acc - pre_acc
+ final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None
+ lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A"
+ change_str = f"{change:+.3f}"
+
+ print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}")
+
+ print("=" * 80)
+
+ # Save results
+ os.makedirs(args.out_dir, exist_ok=True)
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json")
+
+ serializable_results = []
+ for r in all_results:
+ sr = {
+ "depth": r["depth"],
+ "best_pretrain_acc": r["best_pretrain_acc"],
+ "best_finetune_acc": r["best_finetune_acc"],
+ "pretrain_history": [asdict(m) for m in r["pretrain_history"]],
+ "finetune_history": [asdict(m) for m in r["finetune_history"]],
+ }
+ serializable_results.append(sr)
+
+ with open(output_file, "w") as f:
+ json.dump({"config": vars(args), "results": serializable_results}, f, indent=2)
+
+ print(f"\nResults saved to {output_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/experiments/scaled_reg_grid_search.py b/files/experiments/scaled_reg_grid_search.py
new file mode 100644
index 0000000..928caff
--- /dev/null
+++ b/files/experiments/scaled_reg_grid_search.py
@@ -0,0 +1,301 @@
+"""
+Grid Search: Multiplier-Scaled Regularization Experiments
+
+Tests the new multiplier-scaled regularization approach:
+ loss = (λ_reg × g(relu(λ))) × relu(λ)
+
+Where g(x) is the multiplier scaling function:
+ - mult_linear: g(x) = x → loss = λ_reg × relu(λ)²
+ - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³
+ - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ)
+
+Grid:
+ - λ_reg: 0.01, 0.05, 0.1, 0.3
+ - reg_type: mult_linear, mult_squared, mult_log
+ - depth: specified via command line
+
+Usage:
+ python scaled_reg_grid_search.py --depth 4
+ python scaled_reg_grid_search.py --depth 8
+ python scaled_reg_grid_search.py --depth 12
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+# Import from main benchmark
+from depth_scaling_benchmark import (
+ SpikingVGG,
+ compute_lyap_reg_loss,
+)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+@dataclass
+class ExperimentResult:
+ depth: int
+ reg_type: str
+ lambda_reg: float
+ vanilla_acc: float
+ lyapunov_acc: float
+ final_lyap: Optional[float]
+ delta: float
+
+
+def train_epoch(model, loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, reg_type, progress=False):
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4)
+ loss = criterion(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ # Target is implicitly 0 for scaled reg types
+ lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0)
+ loss = loss + lambda_reg * lyap_reg
+ lyap_vals.append(lyap_est.item())
+
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return total_loss / total, correct / total, avg_lyap
+
+
+def evaluate(model, loader, device):
+ """Evaluate model."""
+ model.eval()
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ return correct / total
+
+
+def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader,
+ device, epochs=100, lr=0.001):
+ """Run a single experiment configuration."""
+
+ # Determine blocks per stage based on depth
+ # depth = num_stages * blocks_per_stage, with num_stages=4
+ blocks_per_stage = depth // 4
+
+ print(f"\n{'='*60}")
+ print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}")
+ print(f"{'='*60}")
+
+ # --- Run Vanilla baseline ---
+ print(f" Training Vanilla...")
+ model_v = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_v = optim.Adam(model_v.parameters(), lr=lr)
+ criterion = nn.CrossEntropyLoss()
+ scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs)
+
+ best_vanilla = 0.0
+ for epoch in range(epochs):
+ train_epoch(model_v, train_loader, optimizer_v, criterion, device,
+ use_lyapunov=False, lambda_reg=0, reg_type="squared")
+ scheduler_v.step()
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_v, test_loader, device)
+ best_vanilla = max(best_vanilla, acc)
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f}")
+
+ del model_v, optimizer_v, scheduler_v
+ torch.cuda.empty_cache()
+
+ # --- Run Lyapunov version ---
+ print(f" Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...")
+ model_l = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_l = optim.Adam(model_l.parameters(), lr=lr)
+ scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs)
+
+ best_lyap_acc = 0.0
+ final_lyap = None
+
+ for epoch in range(epochs):
+ _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device,
+ use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type)
+ scheduler_l.step()
+ final_lyap = lyap
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_l, test_loader, device)
+ best_lyap_acc = max(best_lyap_acc, acc)
+ lyap_str = f"λ={lyap:.3f}" if lyap else "λ=N/A"
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}")
+
+ del model_l, optimizer_l, scheduler_l
+ torch.cuda.empty_cache()
+
+ delta = best_lyap_acc - best_vanilla
+
+ result = ExperimentResult(
+ depth=depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ vanilla_acc=best_vanilla,
+ lyapunov_acc=best_lyap_acc,
+ final_lyap=final_lyap,
+ delta=delta,
+ )
+
+ print(f" Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}")
+
+ return result
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12])
+ parser.add_argument("--epochs", type=int, default=100)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--lr", type=float, default=0.001)
+ parser.add_argument("--data_dir", type=str, default="./data")
+ parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid")
+ args = parser.parse_args()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ print("=" * 70)
+ print("SCALED REGULARIZATION GRID SEARCH")
+ print("=" * 70)
+ print(f"Depth: {args.depth}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+ print("=" * 70)
+
+ # Grid parameters
+ lambda_regs = [0.0005, 0.001, 0.002, 0.005] # smaller values for deeper networks
+ reg_types = ["mult_linear", "mult_log"] # mult_squared too aggressive, kills learning
+
+ print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments")
+ print(f"λ_reg values: {lambda_regs}")
+ print(f"reg_types: {reg_types}")
+
+ # Load data
+ print(f"\nLoading CIFAR-100...")
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+
+ train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train)
+ test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test)
+
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=4, pin_memory=True)
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=4, pin_memory=True)
+
+ print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
+
+ # Run grid search
+ results = []
+
+ for lambda_reg, reg_type in product(lambda_regs, reg_types):
+ result = run_single_experiment(
+ depth=args.depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ device=device,
+ epochs=args.epochs,
+ lr=args.lr,
+ )
+ results.append(result)
+
+ # Print summary table
+ print("\n" + "=" * 70)
+ print(f"SUMMARY: DEPTH = {args.depth}")
+ print("=" * 70)
+ print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
+ print("-" * 70)
+
+ for r in results:
+ lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap else "N/A"
+ delta_str = f"{r.delta:+.3f}"
+ print(f"{r.reg_type:<16} {r.lambda_reg:>8.2f} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")
+
+ # Find best configuration
+ best = max(results, key=lambda x: x.lyapunov_acc)
+ print("-" * 70)
+ print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")
+
+ # Save results
+ os.makedirs(args.out_dir, exist_ok=True)
+ out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
+ with open(out_file, "w") as f:
+ json.dump([asdict(r) for r in results], f, indent=2)
+ print(f"\nResults saved to: {out_file}")
+
+ print("\n" + "=" * 70)
+ print("GRID SEARCH COMPLETE")
+ print("=" * 70)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/files/models/conv_snn.py b/files/models/conv_snn.py
new file mode 100644
index 0000000..69f77d6
--- /dev/null
+++ b/files/models/conv_snn.py
@@ -0,0 +1,483 @@
+"""
+Convolutional SNN with Lyapunov regularization for image classification.
+
+Properly handles spatial structure:
+- Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor
+- Uses Conv-LIF layers to preserve spatial hierarchy
+- Rate encoding converts images to spike trains
+
+Based on standard SNN vision practices:
+- Rate/Poisson encoding for input
+- Conv → BatchNorm → LIF → Pool architecture
+- Time comes from encoding + LIF dynamics, not flattening
+"""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import snntorch as snn
+from snntorch import surrogate
+
+
+class RateEncoder(nn.Module):
+ """
+ Rate (Poisson/Bernoulli) encoder for static images.
+
+ Converts intensity x ∈ [0,1] to spike probability per timestep.
+ Each pixel independently fires with P(spike) = x * gain.
+
+ Args:
+ T: Number of timesteps
+ gain: Scaling factor for firing probability (default 1.0)
+
+ Input: (B, C, H, W) normalized image in [0, 1]
+ Output: (B, T, C, H, W) binary spike tensor
+ """
+
+ def __init__(self, T: int = 25, gain: float = 1.0):
+ super().__init__()
+ self.T = T
+ self.gain = gain
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: (B, C, H, W) image tensor, values in [0, 1]
+ Returns:
+ spikes: (B, T, C, H, W) binary spike tensor
+ """
+ # Clamp to valid probability range
+ prob = (x * self.gain).clamp(0, 1)
+
+ # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W)
+ prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+ # Sample spikes
+ spikes = torch.bernoulli(prob)
+
+ return spikes
+
+
+class DirectEncoder(nn.Module):
+ """
+ Direct encoding - feed static image as constant current.
+
+ Common in surrogate gradient papers: no spike encoding at input,
+ let spiking emerge from the network dynamics.
+
+ Input: (B, C, H, W) image
+ Output: (B, T, C, H, W) repeated image (as analog current)
+ """
+
+ def __init__(self, T: int = 25):
+ super().__init__()
+ self.T = T
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Simply repeat across time
+ return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+
+class ConvLIFBlock(nn.Module):
+ """
+ Conv → BatchNorm → LIF block.
+
+ Maintains spatial structure while adding spiking dynamics.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad=None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(
+ in_channels, out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.lif = snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mem: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x: (B, C_in, H, W) input (spikes or current)
+ mem: (B, C_out, H', W') membrane potential
+ Returns:
+ spk: (B, C_out, H', W') output spikes
+ mem: (B, C_out, H', W') updated membrane
+ """
+ cur = self.bn(self.conv(x))
+ spk, mem = self.lif(cur, mem)
+ return spk, mem
+
+
+class ConvLyapunovSNN(nn.Module):
+ """
+ Convolutional SNN with Lyapunov exponent regularization.
+
+ Architecture for CIFAR-10 (32x32x3):
+ Input → Encoder → [Conv-LIF-Pool] × N → FC → Output
+
+ Properly preserves spatial structure for hierarchical feature learning.
+
+ Args:
+ in_channels: Input channels (3 for RGB)
+ num_classes: Output classes
+ channels: List of channel sizes for conv layers
+ T: Number of timesteps
+ beta: LIF membrane decay
+ threshold: LIF firing threshold
+ encoding: 'rate', 'direct', or 'none' (pre-encoded input)
+ encoding_gain: Gain for rate encoding
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ channels: List[int] = [64, 128, 256],
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ encoding_gain: float = 1.0,
+ dropout: float = 0.2,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+ self.channels = channels
+ self.num_layers = len(channels)
+
+ # Input encoder
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T, gain=encoding_gain)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None # Expect pre-encoded (B, T, C, H, W) input
+
+ # Build conv-LIF layers
+ self.blocks = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ ch_in = in_channels
+ for ch_out in channels:
+ self.blocks.append(
+ ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold)
+ )
+ self.pools.append(nn.AvgPool2d(2))
+ ch_in = ch_out
+
+ # Calculate output spatial size after pooling
+ # CIFAR: 32 -> 16 -> 8 -> 4 (for 3 layers)
+ spatial_size = 32 // (2 ** len(channels))
+ fc_input = channels[-1] * spatial_size * spatial_size
+
+ # Fully connected readout
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(fc_input, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+ """Initialize membrane potentials for all layers."""
+ mems = []
+ H, W = 32, 32
+ for i, ch in enumerate(self.channels):
+ H, W = H // 2, W // 2 # After pooling
+ # Actually we need size BEFORE pooling for LIF
+ H_pre, W_pre = H * 2, W * 2
+ mems.append(torch.zeros(batch_size, ch, H_pre, W_pre, device=device, dtype=dtype))
+ return mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+ """
+ Forward pass with optional Lyapunov computation.
+
+ Args:
+ x: Input tensor
+ - If encoder: (B, C, H, W) static image
+ - If no encoder: (B, T, C, H, W) pre-encoded spikes
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude
+
+ Returns:
+ logits: (B, num_classes)
+ lyap_est: Scalar Lyapunov estimate or None
+ recordings: Optional dict with spike recordings
+ """
+ # Encode input if needed
+ if self.encoder is not None:
+ x = self.encoder(x) # (B, C, H, W) -> (B, T, C, H, W)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize membrane potentials
+ mems = self._init_mem(B, device, dtype)
+
+ # For accumulating output spikes
+ spike_sum = None
+
+ # Lyapunov setup
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t] # (B, C, H, W)
+
+ # Forward through conv-LIF blocks
+ h = x_t
+ new_mems = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h, mem = block(h, mems[i])
+ new_mems.append(mem)
+ h = pool(h) # Spatial downsampling
+
+ mems = new_mems
+
+ # Accumulate final layer spikes
+ if spike_sum is None:
+ spike_sum = h.view(B, -1)
+ else:
+ spike_sum = spike_sum + h.view(B, -1)
+
+ # Lyapunov computation
+ if compute_lyapunov:
+ h_p = x_t
+ new_mems_p = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h_p, mem_p = block(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = pool(h_p)
+
+ # Compute divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize perturbation
+ for i in range(self.num_layers):
+ diff = new_mems_p[i] - new_mems[i]
+ norm = torch.sqrt((diff ** 2).sum(dim=(1, 2, 3), keepdim=True) + 1e-12)
+ # Broadcast norm to spatial dimensions
+ norm = norm.view(B, 1, 1, 1)
+ new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+ mems_p = new_mems_p
+
+ # Readout
+ out = self.dropout(spike_sum)
+ logits = self.fc(out)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est, None
+
+
+class VGGLyapunovSNN(nn.Module):
+ """
+ VGG-style deep Conv-SNN with Lyapunov regularization.
+
+ Deeper architecture for more challenging benchmarks.
+ Uses multiple conv layers between pooling to increase depth.
+
+ Architecture (VGG-9 style):
+ [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ dropout: float = 0.3,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None
+
+ # VGG-style blocks: (in_ch, out_ch, num_convs)
+ block_configs = [
+ (in_channels, 64, 2), # 32x32 -> 16x16
+ (64, 128, 2), # 16x16 -> 8x8
+ (128, 256, 2), # 8x8 -> 4x4
+ ]
+
+ self.blocks = nn.ModuleList()
+ for in_ch, out_ch, n_convs in block_configs:
+ layers = []
+ for i in range(n_convs):
+ ch_in = in_ch if i == 0 else out_ch
+ layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False))
+ layers.append(nn.BatchNorm2d(out_ch))
+ self.blocks.append(nn.ModuleList(layers))
+
+ # LIF neurons for each conv layer
+ self.lifs = nn.ModuleList([
+ snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ for _ in range(6) # 2 convs × 3 blocks
+ ])
+
+ self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)])
+
+ # FC layers
+ self.fc1 = nn.Linear(256 * 4 * 4, 512)
+ self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ self.dropout = nn.Dropout(dropout)
+ self.fc2 = nn.Linear(512, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+
+ if self.encoder is not None:
+ x = self.encoder(x)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize all membrane potentials
+ # For each conv layer output
+ mem_shapes = [
+ (B, 64, 32, 32), (B, 64, 32, 32), # Block 1
+ (B, 128, 16, 16), (B, 128, 16, 16), # Block 2
+ (B, 256, 8, 8), (B, 256, 8, 8), # Block 3
+ (B, 512), # FC
+ ]
+ mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes]
+
+ spike_sum = torch.zeros(B, 512, device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ h = x[:, t]
+
+ lif_idx = 0
+ for block_idx, (block_layers, pool) in enumerate(zip(self.blocks, self.pools)):
+ for i in range(0, len(block_layers), 2): # Conv, BN pairs
+ conv, bn = block_layers[i], block_layers[i + 1]
+ h = bn(conv(h))
+ h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx])
+ lif_idx += 1
+ h = pool(h)
+
+ # FC layers
+ h = h.view(B, -1)
+ h = self.fc1(h)
+ h, mems[6] = self.lif_fc(h, mems[6])
+ spike_sum = spike_sum + h
+
+ # Lyapunov (simplified - just on last layer)
+ if compute_lyapunov:
+ diff = mems[6] - mems_p[6] if t > 0 else torch.zeros_like(mems[6])
+ delta = torch.norm(diff.view(B, -1), dim=1) + 1e-12
+ if t > 0:
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+ mems_p[6] = mems[6] + lyap_eps * torch.randn_like(mems[6])
+
+ out = self.dropout(spike_sum)
+ logits = self.fc2(out)
+
+ lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est, None
+
+
+def create_conv_snn(
+ model_type: str = 'simple',
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function for Conv-SNN models.
+
+ Args:
+ model_type: 'simple' (3-layer) or 'vgg' (6-layer VGG-style)
+ **kwargs: Arguments passed to model constructor
+ """
+ if model_type == 'simple':
+ return ConvLyapunovSNN(**kwargs)
+ elif model_type == 'vgg':
+ return VGGLyapunovSNN(**kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}")
diff --git a/files/models/snn.py b/files/models/snn.py
new file mode 100644
index 0000000..b1cf633
--- /dev/null
+++ b/files/models/snn.py
@@ -0,0 +1,141 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SurrogateStep(torch.autograd.Function):
+ """
+ Heaviside with a smooth surrogate gradient (fast sigmoid).
+ """
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, alpha: float):
+ ctx.save_for_backward(x)
+ ctx.alpha = alpha
+ return (x > 0).to(x.dtype)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ (x,) = ctx.saved_tensors
+ alpha = ctx.alpha
+ # d/dx sigmoid(alpha*x) ~ alpha * sigmoid * (1 - sigmoid)
+ # Use fast sigmoid: s = 1 / (1 + |alpha*x|)
+ s = 1.0 / (1.0 + (alpha * x).abs())
+ grad = grad_output * s * s
+ return grad, None
+
+
+def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
+ return SurrogateStep.apply(x, alpha)
+
+
+class LIFLayer(nn.Module):
+ """
+ Single LIF layer without recurrent synapses between neurons.
+ Dynamics per neuron i:
+ v_t = decay * v_{t-1} + W x_t - v_th * s_{t-1}
+ s_t = H( v_t - v_th ) with surrogate gradient
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
+ self.v_threshold = float(v_threshold)
+ self.decay = float(decay)
+ self.spike_alpha = float(spike_alpha)
+ self.rec_strength = float(rec_strength)
+ self.rec = None
+ if self.rec_strength != 0.0:
+ self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+ nn.init.xavier_uniform_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+ if self.rec is not None:
+ nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)
+
+ def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
+ I_t = self.linear(x_t) # (B, H)
+ R_t = 0.0
+ if self.rec is not None:
+ R_t = self.rec_strength * self.rec(s_prev)
+ v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
+ s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
+ return v_t, s_t
+
+
+class SimpleSNN(nn.Module):
+ """
+ Minimal SNN for SHD-like input (B,T,D):
+ - One LIF hidden layer
+ - Readout linear on time-summed spikes
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.lif = LIFLayer(input_dim, hidden_dim, v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha, rec_strength=rec_strength, rec_init_scale=rec_init_scale)
+ self.readout = nn.Linear(hidden_dim, num_classes)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ @torch.no_grad()
+ def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
+ v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+ s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+ return v0, s0
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-3,
+ lyap_safe_eps: float = 1e-8,
+ lyap_measure: str = "v", # "v" or "s"
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ x: (B, T, D)
+ Returns:
+ logits: (B, C)
+ lyap_est: scalar tensor if compute_lyapunov else None
+ """
+ assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+ H = self.readout.in_features
+
+ v, s = self._init_states(B, H, device, dtype)
+ spike_sum = torch.zeros(B, H, device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ v_p = v + lyap_eps * torch.randn_like(v)
+ s_p = s.clone()
+ delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+ lyap_terms = []
+
+ for t in range(T):
+ x_t = x[:, t, :]
+ v, s = self.lif(x_t, v, s)
+ spike_sum = spike_sum + s
+
+ if compute_lyapunov:
+ # run perturbed trajectory through same ops
+ v_p, s_p = self.lif(x_t, v_p, s_p)
+ if lyap_measure == "s":
+ delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
+ else:
+ delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+ ratio = delta_t / delta_prev
+ lyap_terms.append(torch.log(ratio + lyap_safe_eps))
+ delta_prev = delta_t
+
+ logits = self.readout(spike_sum) # (B, C)
+
+ if compute_lyapunov:
+ lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0) # (B,)
+ lyap_est = lyap_batch.mean() # scalar
+ else:
+ lyap_est = None
+
+ return logits, lyap_est
+
+
diff --git a/files/models/snn_snntorch.py b/files/models/snn_snntorch.py
new file mode 100644
index 0000000..71c1c18
--- /dev/null
+++ b/files/models/snn_snntorch.py
@@ -0,0 +1,398 @@
+"""
+snnTorch-based SNN with Lyapunov exponent regularization.
+
+This module provides deep SNN architectures using snnTorch with proper
+finite-time Lyapunov exponent computation for training stabilization.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+
+class LyapunovSNN(nn.Module):
+ """
+ Multi-layer SNN using snnTorch with Lyapunov exponent computation.
+
+ Architecture:
+ Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers)
+ num_classes: Number of output classes
+ beta: Membrane potential decay factor (0 < beta < 1)
+ threshold: Firing threshold
+ spike_grad: Surrogate gradient function (default: fast_sigmoid)
+ dropout: Dropout probability between layers (0 = no dropout)
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.beta = beta
+ self.threshold = threshold
+
+ # Build layers
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+ self.dropouts = nn.ModuleList() if dropout > 0 else None
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.lifs.append(
+ snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ )
+ )
+ if dropout > 0:
+ self.dropouts.append(nn.Dropout(p=dropout))
+
+ # Readout layer
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+ """Initialize membrane potentials for all layers."""
+ mems = []
+ for dim in self.hidden_dims:
+ mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ return mems
+
+ def _step(
+ self,
+ x_t: torch.Tensor,
+ mems: List[torch.Tensor],
+ training: bool = True,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Single timestep forward pass.
+
+ Returns:
+ spike_out: Output spikes from last layer (B, H_last)
+ new_mems: Updated membrane potentials
+ all_mems: Membrane potentials from all layers (for Lyapunov)
+ """
+ new_mems = []
+ all_mems = []
+
+ h = x_t
+ for i in range(self.num_layers):
+ h = self.linears[i](h)
+ spk, mem = self.lifs[i](h, mems[i])
+ new_mems.append(mem)
+ all_mems.append(mem)
+ h = spk
+ if self.dropouts is not None and training:
+ h = self.dropouts[i](h)
+
+ return h, new_mems, all_mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ lyap_layers: Optional[List[int]] = None,
+ record_states: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
+ """
+ Forward pass with optional Lyapunov exponent computation.
+
+ Args:
+ x: Input tensor (B, T, D)
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude for Lyapunov computation
+ lyap_layers: Which layers to measure (default: all).
+ e.g., [0] for first layer only, [-1] for last layer
+ record_states: Whether to record spikes and membrane potentials
+
+ Returns:
+ logits: Classification logits (B, num_classes)
+ lyap_est: Estimated Lyapunov exponent (scalar) or None
+ recordings: Dict with 'spikes' (B,T,H) and 'membrane' (B,T,H) or None
+ """
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize states
+ mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ # Recording setup
+ if record_states:
+ spike_rec = []
+ mem_rec = []
+
+ # Lyapunov setup
+ if compute_lyapunov:
+ if lyap_layers is None:
+ lyap_layers = list(range(self.num_layers))
+
+ # Perturbed trajectory - perturb all membrane potentials
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t, :]
+
+ # Nominal trajectory
+ spk, mems, all_mems = self._step(x_t, mems, training=self.training)
+ spike_sum = spike_sum + spk
+
+ if record_states:
+ spike_rec.append(spk.detach())
+ mem_rec.append(all_mems[-1].detach()) # Last layer membrane
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False)
+
+ # Compute divergence across selected layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ delta_p_sq = torch.zeros(B, device=device, dtype=dtype)
+
+ for layer_idx in lyap_layers:
+ diff = all_mems_p[layer_idx] - all_mems[layer_idx]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+
+ # Renormalization step (key for numerical stability)
+ # Rescale perturbation back to fixed magnitude
+ for layer_idx in lyap_layers:
+ diff = mems_p[layer_idx] - mems[layer_idx]
+ # Normalize to maintain fixed perturbation magnitude
+ norm = torch.norm(diff.reshape(B, -1), dim=1, keepdim=True) + 1e-12
+ diff_normalized = diff / norm.unsqueeze(-1) if diff.ndim > 2 else diff / norm
+ mems_p[layer_idx] = mems[layer_idx] + lyap_eps * diff_normalized
+
+ # Accumulate log-divergence
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ # Average over time and batch
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ if record_states:
+ recordings = {
+ "spikes": torch.stack(spike_rec, dim=1), # (B, T, H)
+ "membrane": torch.stack(mem_rec, dim=1), # (B, T, H)
+ }
+ else:
+ recordings = None
+
+ return logits, lyap_est, recordings
+
+
+class RecurrentLyapunovSNN(nn.Module):
+ """
+ Recurrent SNN with Lyapunov exponent computation.
+
+ Uses snnTorch's RSynaptic (recurrent synaptic) neurons for
+ richer temporal dynamics.
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ alpha: Synaptic current decay rate
+ beta: Membrane potential decay rate
+ threshold: Firing threshold
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ alpha: float = 0.9,
+ beta: float = 0.85,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.alpha = alpha
+ self.beta = beta
+
+ # Build layers with recurrent synaptic neurons
+ self.linears = nn.ModuleList()
+ self.neurons = nn.ModuleList()
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.neurons.append(
+ snn.RSynaptic(
+ alpha=alpha,
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ all_to_all=True,
+ linear_features=dims[i + 1],
+ )
+ )
+
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ def _init_states(self, batch_size: int, device, dtype):
+ """Initialize synaptic currents and membrane potentials."""
+ syns = []
+ mems = []
+ for dim in self.hidden_dims:
+ syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ return syns, mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Forward pass with optional Lyapunov computation."""
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+ syns, mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ # Perturb both synaptic currents and membrane potentials
+ syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns]
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ x_t = x[:, t, :]
+
+ # Nominal trajectory
+ h = x_t
+ new_syns, new_mems = [], []
+ for i in range(self.num_layers):
+ h = self.linears[i](h)
+ spk, syn, mem = self.neurons[i](h, syns[i], mems[i])
+ new_syns.append(syn)
+ new_mems.append(mem)
+ h = spk
+ syns, mems = new_syns, new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ h_p = x_t
+ new_syns_p, new_mems_p = [], []
+ for i in range(self.num_layers):
+ h_p = self.linears[i](h_p)
+ spk_p, syn_p, mem_p = self.neurons[i](h_p, syns_p[i], mems_p[i])
+ new_syns_p.append(syn_p)
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # Compute divergence (on membrane potentials)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff_m = new_mems_p[i] - new_mems[i]
+ diff_s = new_syns_p[i] - new_syns[i]
+ delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize perturbation
+ total_dim = sum(2 * d for d in self.hidden_dims) # syn + mem
+ scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12)
+
+ syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i])
+ for i in range(self.num_layers)]
+ mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
+ for i in range(self.num_layers)]
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est
+
+
+def create_snn(
+ model_type: str,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function to create SNN models.
+
+ Args:
+ model_type: "feedforward" or "recurrent"
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ **kwargs: Additional arguments passed to model constructor
+
+ Returns:
+ SNN model instance
+ """
+ if model_type == "feedforward":
+ return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ elif model_type == "recurrent":
+ return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'")
diff --git a/files/tests/conftest.py b/files/tests/conftest.py
new file mode 100644
index 0000000..a63c259
--- /dev/null
+++ b/files/tests/conftest.py
@@ -0,0 +1,11 @@
+import os
+import sys
+
+# Ensure project root is importable so `import files...` works when tests are under files/tests
+_HERE = os.path.dirname(__file__)
+_FILES_DIR = os.path.dirname(_HERE)
+_ROOT = os.path.dirname(_FILES_DIR)
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+
diff --git a/files/tests/test_data_io.py b/files/tests/test_data_io.py
new file mode 100644
index 0000000..1f2ccd8
--- /dev/null
+++ b/files/tests/test_data_io.py
@@ -0,0 +1,11 @@
+import torch
+from files.data_io.dataset_loader import get_dataloader
+
+
+def test_dataloader_shape():
+ """Smoke test: verify dataloader output shape."""
+ train_loader, _ = get_dataloader("data_io/configs/shd.yaml")
+ x, y = next(iter(train_loader))
+ assert isinstance(x, torch.Tensor)
+ assert x.ndim == 3
+ assert y.ndim == 1
diff --git a/files/tests/test_shd_loader_properties.py b/files/tests/test_shd_loader_properties.py
new file mode 100644
index 0000000..237740c
--- /dev/null
+++ b/files/tests/test_shd_loader_properties.py
@@ -0,0 +1,40 @@
+import torch
+from files.data_io.dataset_loader import get_dataloader, SHDDataset
+
+
+def test_shd_dataset_global_T_D_consistency():
+ # Ensure SHDDataset computes global T and adaptive D once and applies consistently
+ ds = SHDDataset(
+ data_dir="/u/yurenh2/ml-projects/snn-training/files/data",
+ split="train",
+ dt_ms=1.0,
+ default_D=700,
+ )
+ assert ds.T >= 1
+ assert ds.D >= 700 # by construction at least default_D
+ x0, y0 = ds[0]
+ x1, y1 = ds[1]
+ assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor)
+ assert x0.shape == (ds.T, ds.D)
+ assert x1.shape == (ds.T, ds.D)
+ assert isinstance(y0, int) and isinstance(y1, int)
+ # values should be finite
+ assert torch.isfinite(x0).all()
+ assert torch.isfinite(x1).all()
+
+
+def test_dataloader_batch_shapes_and_finiteness():
+ train_loader, _ = get_dataloader("data_io/configs/shd.yaml")
+ xb, yb = next(iter(train_loader))
+ assert isinstance(xb, torch.Tensor)
+ assert xb.ndim == 3 # (B, T, D)
+ assert yb.ndim == 1
+ B, T, D = xb.shape
+ assert B >= 1 and T >= 1 and D >= 1
+ # finiteness and no NaNs
+ assert torch.isfinite(xb).all()
+ # After normalization (enabled in config), per-sample per-channel mean over time should be ~0
+ mean_t = xb.mean(dim=1) # (B, D)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3)
+
+
diff --git a/files/tests/test_train_smoke.py b/files/tests/test_train_smoke.py
new file mode 100644
index 0000000..e04e40b
--- /dev/null
+++ b/files/tests/test_train_smoke.py
@@ -0,0 +1,33 @@
+import torch
+from files.models.snn import SimpleSNN
+from files.data_io.dataset_loader import get_dataloader
+import torch.nn as nn
+import torch.optim as optim
+
+
+def _one_step(model, xb, yb, lyapunov=False):
+ model.train()
+ opt = optim.Adam(model.parameters(), lr=1e-3)
+ ce = nn.CrossEntropyLoss()
+ opt.zero_grad(set_to_none=True)
+ logits, lyap = model(xb, compute_lyapunov=lyapunov)
+ loss = ce(logits, yb)
+ if lyapunov and lyap is not None:
+ loss = loss + 0.1 * (lyap - 0.0) ** 2
+ loss.backward()
+ opt.step()
+ assert torch.isfinite(loss).all()
+
+
+def test_train_step_baseline_and_lyapunov():
+ train_loader, _ = get_dataloader("data_io/configs/shd.yaml")
+ xb, yb = next(iter(train_loader))
+ B, T, D = xb.shape
+ C = 20
+ model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C)
+ # baseline
+ _one_step(model, xb, yb, lyapunov=False)
+ # lyapunov-regularized
+ _one_step(model, xb, yb, lyapunov=True)
+
+
diff --git a/files/tests/test_transforms_normalize.py b/files/tests/test_transforms_normalize.py
new file mode 100644
index 0000000..cc5a3d4
--- /dev/null
+++ b/files/tests/test_transforms_normalize.py
@@ -0,0 +1,53 @@
+import torch
+from files.data_io.transforms.normalization import Normalize
+
+
+def test_normalize_zscore_2d_zero_mean_unitish_var():
+ # (T, D) toy example
+ T, D = 5, 3
+ x = torch.arange(T * D, dtype=torch.float32).reshape(T, D)
+ norm = Normalize(mode="zscore")
+ xz = norm(x)
+ # mean over time per channel should be ~0
+ mean_t = xz.mean(dim=0)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
+ # std over time per channel should be ~1
+ std_t = xz.std(dim=0, unbiased=False)
+ assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5)
+
+
+def test_normalize_zscore_3d_per_sample():
+ # (B, T, D)
+ B, T, D = 2, 6, 4
+ x = torch.randn(B, T, D)
+ norm = Normalize(mode="zscore")
+ xz = norm(x)
+ mean_t = xz.mean(dim=1) # (B, D)
+ std_t = xz.std(dim=1, unbiased=False) # (B, D)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
+ assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4)
+
+
+def test_normalize_minmax_range_01_2d():
+ T, D = 7, 2
+ x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D)
+ norm = Normalize(mode="minmax")
+ xm = norm(x)
+ assert xm.min().item() >= -1e-6
+ assert xm.max().item() <= 1 + 1e-6
+ # Check endpoints map to 0 and 1
+ assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6)
+ assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6)
+
+
+def test_normalize_rejects_bad_ndim():
+ x = torch.ones(1, 2, 3, 4)
+ norm = Normalize()
+ try:
+ _ = norm(x)
+ except ValueError:
+ pass
+ else:
+ raise AssertionError("Normalize should raise on x.ndim not in {2,3}")
+
+
diff --git a/files/train_mvp.py b/files/train_mvp.py
new file mode 100644
index 0000000..b89ddc6
--- /dev/null
+++ b/files/train_mvp.py
@@ -0,0 +1,202 @@
+import os
+import sys
+import json
+import csv
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(_HERE)
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+import argparse
+import time
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from files.data_io.dataset_loader import get_dataloader, SHDDataset
+from files.models.snn import SimpleSNN
+from tqdm.auto import tqdm
+
+
+def _prepare_run_dir(base_dir: str):
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ run_dir = os.path.join(base_dir, ts)
+ os.makedirs(run_dir, exist_ok=True)
+ return run_dir
+
+
+def _append_metrics(csv_path: str, row: dict):
+ write_header = not os.path.exists(csv_path)
+ with open(csv_path, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=row.keys())
+ if write_header:
+ writer.writeheader()
+ writer.writerow(row)
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="MVP training: baseline vs Lyapunov-regularized")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="YAML config for dataloader")
+ p.add_argument("--epochs", type=int, default=2)
+ p.add_argument("--hidden", type=int, default=256)
+ p.add_argument("--classes", type=int, default=20, help="Number of classes")
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regulation")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Weight for Lyapunov penalty")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target average log growth (≈0 for neutral)")
+ p.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar")
+ p.add_argument("--out_dir", type=str, default="runs/mvp", help="Directory to save metrics/checkpoints")
+ p.add_argument("--log_batches", action="store_true", help="Also log per-batch metrics to CSV")
+ # Model dynamics and recurrence controls
+ p.add_argument("--spike_alpha", type=float, default=5.0, help="Surrogate spike sharpness")
+ p.add_argument("--decay", type=float, default=0.95, help="Membrane decay")
+ p.add_argument("--v_threshold", type=float, default=1.0, help="Firing threshold")
+ p.add_argument("--rec_strength", type=float, default=0.0, help="Recurrent coupling strength on spikes")
+ p.add_argument("--rec_init_scale", type=float, default=1.0, help="Gain for recurrent weight init")
+ # Lyapunov measurement controls
+ p.add_argument("--lyap_measure", type=str, default="v", choices=["v", "s"], help="Measure divergence on 'v' or 's'")
+ p.add_argument("--lyap_eps", type=float, default=1e-3, help="Initial perturbation magnitude")
+ return p.parse_args()
+
+
+def train_one_epoch(
+ model: SimpleSNN,
+ loader,
+ optimizer,
+ device,
+ ce_loss: nn.Module,
+ lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ progress: bool,
+ run_dir: str | None = None,
+ epoch_idx: int | None = None,
+ log_batches: bool = False,
+ lyap_measure: str = "v",
+ lyap_eps: float = 1e-3,
+):
+ model.train()
+ total = 0
+ correct = 0
+ running_loss = 0.0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
+ for bidx, (x, y) in enumerate(iterator):
+ x = x.to(device) # (B, T, D)
+ y = y.to(device)
+ optimizer.zero_grad(set_to_none=True)
+ logits, lyap_est = model(
+ x,
+ compute_lyapunov=lyapunov,
+ lyap_eps=lyap_eps,
+ lyap_measure=lyap_measure,
+ )
+ ce = ce_loss(logits, y)
+ if lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.detach().item())
+ else:
+ loss = ce
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ batch_correct = (preds == y).sum().item()
+ correct += batch_correct
+ total += x.size(0)
+ if log_batches and run_dir is not None and epoch_idx is not None:
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "batch",
+ "epoch": int(epoch_idx),
+ "batch": int(bidx),
+ "loss": float(loss.item()),
+ "acc": float(batch_correct / max(x.size(0), 1)),
+ "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
+ "time_sec": float(0.0),
+ },
+ )
+ if progress:
+ avg_loss = running_loss / max(total, 1)
+ avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
+ postfix = {"loss": f"{avg_loss:.4f}"}
+ if avg_lyap is not None:
+ postfix["lyap"] = f"{avg_lyap:.4f}"
+ iterator.set_postfix(postfix)
+
+ avg_loss = running_loss / max(total, 1)
+ acc = correct / max(total, 1)
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return avg_loss, acc, avg_lyap
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ # Prepare output directory and save run config
+ run_dir = _prepare_run_dir(args.out_dir)
+ with open(os.path.join(run_dir, "args.json"), "w") as f:
+ json.dump(vars(args), f, indent=2)
+
+ train_loader, val_loader = get_dataloader(args.cfg)
+
+ # Infer input dim and classes from a sample and args
+ xb, yb = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = args.classes
+
+ model = SimpleSNN(
+ input_dim=D,
+ hidden_dim=args.hidden,
+ num_classes=C,
+ v_threshold=args.v_threshold,
+ decay=args.decay,
+ spike_alpha=args.spike_alpha,
+ rec_strength=args.rec_strength,
+ rec_init_scale=args.rec_init_scale,
+ ).to(device)
+ optimizer = optim.Adam(model.parameters(), lr=args.lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ print(f"Starting training on {device} | lyapunov={args.lyapunov} lambda={args.lambda_reg} target={args.lambda_target}")
+ print(f"Saving run to: {run_dir}")
+ for epoch in range(1, args.epochs + 1):
+ t0 = time.time()
+ tr_loss, tr_acc, tr_lyap = train_one_epoch(
+ model, train_loader, optimizer, device, ce_loss,
+ lyapunov=args.lyapunov, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target,
+ progress=(not args.no_progress),
+ run_dir=run_dir,
+ epoch_idx=epoch,
+ log_batches=args.log_batches,
+ lyap_measure=args.lyap_measure,
+ lyap_eps=args.lyap_eps,
+ )
+ dt = time.time() - t0
+ lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
+ print(f"[Epoch {epoch}] loss={tr_loss:.4f} acc={tr_acc:.3f}{lyap_str} ({dt:.1f}s)")
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "epoch",
+ "epoch": int(epoch),
+ "batch": int(-1),
+ "loss": float(tr_loss),
+ "acc": float(tr_acc),
+ "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
+ "time_sec": float(dt),
+ },
+ )
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/files/train_snntorch.py b/files/train_snntorch.py
new file mode 100644
index 0000000..dbdce37
--- /dev/null
+++ b/files/train_snntorch.py
@@ -0,0 +1,345 @@
+"""
+Training script for snnTorch-based deep SNNs with Lyapunov regularization.
+
+Usage:
+ # Baseline (no Lyapunov)
+ python files/train_snntorch.py --hidden 256 128 --epochs 10
+
+ # With Lyapunov regularization
+ python files/train_snntorch.py --hidden 256 128 --epochs 10 --lyapunov --lambda_reg 0.1
+
+ # Recurrent model
+ python files/train_snntorch.py --model recurrent --hidden 256 --epochs 10 --lyapunov
+"""
+
+import os
+import sys
+import json
+import csv
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(_HERE)
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import time
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from tqdm.auto import tqdm
+
+from files.data_io.dataset_loader import get_dataloader
+from files.models.snn_snntorch import create_snn
+
+
+def _prepare_run_dir(base_dir: str) -> str:
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ run_dir = os.path.join(base_dir, ts)
+ os.makedirs(run_dir, exist_ok=True)
+ return run_dir
+
+
+def _append_metrics(csv_path: str, row: dict):
+ write_header = not os.path.exists(csv_path)
+ with open(csv_path, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=row.keys())
+ if write_header:
+ writer.writeheader()
+ writer.writerow(row)
+
+
+def parse_args():
+ p = argparse.ArgumentParser(
+ description="Train deep SNN with snnTorch and optional Lyapunov regularization"
+ )
+
+ # Model architecture
+ p.add_argument(
+ "--model", type=str, default="feedforward", choices=["feedforward", "recurrent"],
+ help="Model type: 'feedforward' (LIF) or 'recurrent' (RSynaptic)"
+ )
+ p.add_argument(
+ "--hidden", type=int, nargs="+", default=[256],
+ help="Hidden layer sizes (e.g., --hidden 256 128 for 2 layers)"
+ )
+ p.add_argument("--classes", type=int, default=20, help="Number of output classes")
+ p.add_argument("--beta", type=float, default=0.9, help="Membrane decay (beta)")
+ p.add_argument("--threshold", type=float, default=1.0, help="Firing threshold")
+ p.add_argument("--dropout", type=float, default=0.0, help="Dropout between layers")
+ p.add_argument(
+ "--surrogate_slope", type=float, default=25.0,
+ help="Slope for fast_sigmoid surrogate gradient"
+ )
+
+ # Recurrent-specific (only for --model recurrent)
+ p.add_argument("--alpha", type=float, default=0.9, help="Synaptic current decay (recurrent only)")
+
+ # Training
+ p.add_argument("--epochs", type=int, default=10)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--weight_decay", type=float, default=0.0, help="L2 regularization")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
+
+ # Lyapunov regularization
+ p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov penalty weight")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
+ p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation magnitude")
+ p.add_argument(
+ "--lyap_layers", type=int, nargs="*", default=None,
+ help="Which layers to measure (default: all). E.g., --lyap_layers 0 1"
+ )
+
+ # Output
+ p.add_argument("--out_dir", type=str, default="runs/snntorch", help="Output directory")
+ p.add_argument("--log_batches", action="store_true", help="Log per-batch metrics")
+ p.add_argument("--no-progress", action="store_true", help="Disable progress bar")
+ p.add_argument("--save_model", action="store_true", help="Save model checkpoint")
+
+ return p.parse_args()
+
+
+def train_one_epoch(
+ model: nn.Module,
+ loader,
+ optimizer: optim.Optimizer,
+ device: torch.device,
+ ce_loss: nn.Module,
+ lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ lyap_layers: Optional[List[int]],
+ progress: bool,
+ run_dir: Optional[str] = None,
+ epoch_idx: Optional[int] = None,
+ log_batches: bool = False,
+):
+ model.train()
+ total = 0
+ correct = 0
+ running_loss = 0.0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader
+
+ for bidx, (x, y) in enumerate(iterator):
+ x = x.to(device) # (B, T, D)
+ y = y.to(device)
+
+ optimizer.zero_grad(set_to_none=True)
+
+ logits, lyap_est = model(
+ x,
+ compute_lyapunov=lyapunov,
+ lyap_eps=lyap_eps,
+ lyap_layers=lyap_layers,
+ )
+
+ ce = ce_loss(logits, y)
+
+ if lyapunov and lyap_est is not None:
+ # Penalize deviation from target Lyapunov exponent
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.detach().item())
+ else:
+ loss = ce
+
+ loss.backward()
+ optimizer.step()
+
+ running_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ batch_correct = (preds == y).sum().item()
+ correct += batch_correct
+ total += x.size(0)
+
+ if log_batches and run_dir is not None and epoch_idx is not None:
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "batch",
+ "epoch": int(epoch_idx),
+ "batch": int(bidx),
+ "loss": float(loss.item()),
+ "acc": float(batch_correct / max(x.size(0), 1)),
+ "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"),
+ "time_sec": 0.0,
+ },
+ )
+
+ if progress:
+ avg_loss = running_loss / max(total, 1)
+ avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None
+ postfix = {"loss": f"{avg_loss:.4f}", "acc": f"{correct / total:.3f}"}
+ if avg_lyap is not None:
+ postfix["lyap"] = f"{avg_lyap:.4f}"
+ iterator.set_postfix(postfix)
+
+ avg_loss = running_loss / max(total, 1)
+ acc = correct / max(total, 1)
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return avg_loss, acc, avg_lyap
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader,
+ device: torch.device,
+ ce_loss: nn.Module,
+ progress: bool,
+):
+ model.eval()
+ total = 0
+ correct = 0
+ running_loss = 0.0
+
+ iterator = tqdm(loader, desc="eval", leave=False, dynamic_ncols=True) if progress else loader
+
+ for x, y in iterator:
+ x = x.to(device)
+ y = y.to(device)
+
+ logits, _ = model(x, compute_lyapunov=False)
+ loss = ce_loss(logits, y)
+
+ running_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ correct += (preds == y).sum().item()
+ total += x.size(0)
+
+ avg_loss = running_loss / max(total, 1)
+ acc = correct / max(total, 1)
+ return avg_loss, acc
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ # Prepare output directory
+ run_dir = _prepare_run_dir(args.out_dir)
+ with open(os.path.join(run_dir, "args.json"), "w") as f:
+ json.dump(vars(args), f, indent=2)
+
+ # Load data
+ train_loader, val_loader = get_dataloader(args.cfg)
+
+ # Infer dimensions from data
+ xb, yb = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = args.classes
+
+ print(f"Data: T={T}, D={D}, classes={C}")
+ print(f"Model: {args.model}, hidden={args.hidden}")
+
+ # Create model
+ from snntorch import surrogate
+ spike_grad = surrogate.fast_sigmoid(slope=args.surrogate_slope)
+
+ if args.model == "feedforward":
+ model = create_snn(
+ model_type="feedforward",
+ input_dim=D,
+ hidden_dims=args.hidden,
+ num_classes=C,
+ beta=args.beta,
+ threshold=args.threshold,
+ spike_grad=spike_grad,
+ dropout=args.dropout,
+ )
+ else: # recurrent
+ model = create_snn(
+ model_type="recurrent",
+ input_dim=D,
+ hidden_dims=args.hidden,
+ num_classes=C,
+ alpha=args.alpha,
+ beta=args.beta,
+ threshold=args.threshold,
+ spike_grad=spike_grad,
+ )
+
+ model = model.to(device)
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Parameters: {num_params:,}")
+
+ optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ ce_loss = nn.CrossEntropyLoss()
+
+ print(f"\nTraining on {device} | lyapunov={args.lyapunov} λ_reg={args.lambda_reg} λ_target={args.lambda_target}")
+ print(f"Output: {run_dir}\n")
+
+ best_val_acc = 0.0
+
+ for epoch in range(1, args.epochs + 1):
+ t0 = time.time()
+
+ tr_loss, tr_acc, tr_lyap = train_one_epoch(
+ model=model,
+ loader=train_loader,
+ optimizer=optimizer,
+ device=device,
+ ce_loss=ce_loss,
+ lyapunov=args.lyapunov,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ lyap_layers=args.lyap_layers,
+ progress=(not args.no_progress),
+ run_dir=run_dir,
+ epoch_idx=epoch,
+ log_batches=args.log_batches,
+ )
+
+ val_loss, val_acc = evaluate(
+ model=model,
+ loader=val_loader,
+ device=device,
+ ce_loss=ce_loss,
+ progress=(not args.no_progress),
+ )
+
+ dt = time.time() - t0
+ lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else ""
+
+ print(
+ f"[Epoch {epoch:3d}] "
+ f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f}{lyap_str} | "
+ f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} ({dt:.1f}s)"
+ )
+
+ _append_metrics(
+ os.path.join(run_dir, "metrics.csv"),
+ {
+ "step": "epoch",
+ "epoch": int(epoch),
+ "batch": -1,
+ "loss": float(tr_loss),
+ "acc": float(tr_acc),
+ "val_loss": float(val_loss),
+ "val_acc": float(val_acc),
+ "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"),
+ "time_sec": float(dt),
+ },
+ )
+
+ # Save best model
+ if args.save_model and val_acc > best_val_acc:
+ best_val_acc = val_acc
+ torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt"))
+
+ print(f"\nTraining complete. Best val_acc: {best_val_acc:.3f}")
+ if args.save_model:
+ torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt"))
+ print(f"Model saved to {run_dir}")
+
+
+if __name__ == "__main__":
+ main()