diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files')
42 files changed, 8436 insertions, 0 deletions
diff --git a/files/__init__.py b/files/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/files/__init__.py diff --git a/files/analysis/plot_mvp.py b/files/analysis/plot_mvp.py new file mode 100644 index 0000000..495a9c2 --- /dev/null +++ b/files/analysis/plot_mvp.py @@ -0,0 +1,139 @@ +import argparse +import csv +import json +import os +from glob import glob +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt + + +def load_run(run_dir: str) -> Dict: + args_path = os.path.join(run_dir, "args.json") + metrics_path = os.path.join(run_dir, "metrics.csv") + if not (os.path.exists(args_path) and os.path.exists(metrics_path)): + return {} + with open(args_path, "r") as f: + args = json.load(f) + epochs = [] + loss = [] + acc = [] + lyap = [] + with open(metrics_path, "r") as f: + reader = csv.DictReader(f) + for row in reader: + if row.get("step", "") != "epoch": + continue + try: + e = int(row.get("epoch", "0")) + l = float(row.get("loss", "nan")) + a = float(row.get("acc", "nan")) + y = row.get("lyap", "nan") + y = float("nan") if (y is None or y == "" or str(y).lower() == "nan") else float(y) + except Exception: + continue + epochs.append(e) + loss.append(l) + acc.append(a) + lyap.append(y) + if not epochs: + return {} + return { + "args": args, + "epochs": epochs, + "loss": loss, + "acc": acc, + "lyap": lyap, + "run_dir": run_dir, + } + + +def label_for_run(run: Dict) -> str: + args = run["args"] + if args.get("lyapunov", False): + lam = args.get("lambda_reg", None) + tgt = args.get("lambda_target", None) + hid = args.get("hidden", None) + return f"Lyap λ={lam}, tgt={tgt}, H={hid}" + else: + hid = args.get("hidden", None) + return f"Baseline H={hid}" + + +def gather_runs(base_dir: str) -> List[Dict]: + cand = sorted(glob(os.path.join(base_dir, "*"))) + runs = [] + for rd in cand: + data = load_run(rd) + if data: + runs.append(data) + return runs + + +def plot_runs(runs: List[Dict], out_path: str): + if not runs: + raise 
SystemExit("No valid runs found (expected args.json and metrics.csv under base_dir/*)") + # Split baseline vs lyapunov + base = [r for r in runs if not r["args"].get("lyapunov", False)] + lyap = [r for r in runs if r["args"].get("lyapunov", False)] + + fig, axes = plt.subplots(1, 3, figsize=(14, 4.5)) + + # Loss + ax = axes[0] + for r in base: + ax.plot(r["epochs"], r["loss"], label=label_for_run(r), alpha=0.9) + for r in lyap: + ax.plot(r["epochs"], r["loss"], label=label_for_run(r), linestyle="--", alpha=0.9) + ax.set_title("Training loss") + ax.set_xlabel("Epoch") + ax.set_ylabel("Loss") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + + # Accuracy + ax = axes[1] + for r in base: + ax.plot(r["epochs"], r["acc"], label=label_for_run(r), alpha=0.9) + for r in lyap: + ax.plot(r["epochs"], r["acc"], label=label_for_run(r), linestyle="--", alpha=0.9) + ax.set_title("Training accuracy") + ax.set_xlabel("Epoch") + ax.set_ylabel("Accuracy") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + + # Lyapunov estimate (only lyap runs) + ax = axes[2] + if lyap: + for r in lyap: + ax.plot(r["epochs"], r["lyap"], label=label_for_run(r), alpha=0.9) + ax.set_title("Surrogate Lyapunov estimate") + ax.set_xlabel("Epoch") + ax.set_ylabel("Avg log growth") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=8) + else: + ax.text(0.5, 0.5, "No Lyapunov runs found", ha="center", va="center", transform=ax.transAxes) + ax.set_axis_off() + + fig.tight_layout() + os.makedirs(os.path.dirname(out_path), exist_ok=True) + fig.savefig(out_path, dpi=150) + print(f"Saved figure to {out_path}") + + +def main(): + ap = argparse.ArgumentParser(description="Plot MVP runs: baseline vs Lyapunov comparison") + ap.add_argument("--base_dir", type=str, default="runs/mvp", help="Directory containing run subfolders") + ap.add_argument("--out", type=str, default="runs/mvp_summary.png", help="Output figure path") + args = ap.parse_args() + + runs = gather_runs(args.base_dir) + plot_runs(runs, args.out) + + +if 
__name__ == "__main__": + main() + + diff --git a/files/analysis/stability_monitor.py b/files/analysis/stability_monitor.py new file mode 100644 index 0000000..18cad0f --- /dev/null +++ b/files/analysis/stability_monitor.py @@ -0,0 +1,395 @@ +""" +Stability monitoring utilities for SNN training. + +Provides metrics to diagnose training stability: +- Lyapunov exponent (trajectory divergence) +- Gradient norms (vanishing/exploding) +- Firing rates (dead/saturated neurons) +- Membrane potential statistics +""" + +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, field + +import torch +import torch.nn as nn +import numpy as np + + +@dataclass +class StabilityMetrics: + """Container for stability measurements.""" + lyapunov: Optional[float] = None + grad_norm: Optional[float] = None + grad_max_sv: Optional[float] = None # Max singular value of gradients + grad_min_sv: Optional[float] = None # Min singular value of gradients + grad_condition: Optional[float] = None # Condition number (max_sv / min_sv) + firing_rate_mean: Optional[float] = None + firing_rate_std: Optional[float] = None + dead_neuron_frac: Optional[float] = None + saturated_neuron_frac: Optional[float] = None + membrane_mean: Optional[float] = None + membrane_std: Optional[float] = None + + def to_dict(self) -> Dict[str, float]: + return {k: v for k, v in self.__dict__.items() if v is not None} + + def __str__(self) -> str: + parts = [] + if self.lyapunov is not None: + parts.append(f"λ={self.lyapunov:.4f}") + if self.grad_norm is not None: + parts.append(f"∇={self.grad_norm:.4f}") + if self.firing_rate_mean is not None: + parts.append(f"fr={self.firing_rate_mean:.3f}±{self.firing_rate_std:.3f}") + if self.dead_neuron_frac is not None: + parts.append(f"dead={self.dead_neuron_frac:.1%}") + if self.saturated_neuron_frac is not None: + parts.append(f"sat={self.saturated_neuron_frac:.1%}") + return " | ".join(parts) + + +class StabilityMonitor: + """ + Monitor SNN stability 
during training. + + Usage: + monitor = StabilityMonitor() + + # During training + logits, lyap = model(x, compute_lyapunov=True) + loss.backward() + + metrics = monitor.compute( + model=model, + lyapunov=lyap, + spikes=spike_recordings, # optional + membrane=membrane_recordings, # optional + ) + print(metrics) + """ + + def __init__( + self, + dead_threshold: float = 0.01, + saturated_threshold: float = 0.9, + history_size: int = 100, + ): + """ + Args: + dead_threshold: Firing rate below this = dead neuron + saturated_threshold: Firing rate above this = saturated neuron + history_size: Number of batches to track for moving averages + """ + self.dead_threshold = dead_threshold + self.saturated_threshold = saturated_threshold + self.history_size = history_size + + # History tracking + self.lyap_history: List[float] = [] + self.grad_history: List[float] = [] + self.fr_history: List[float] = [] + + def compute_gradient_norm(self, model: nn.Module) -> float: + """Compute total gradient norm across all parameters.""" + total_norm = 0.0 + for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.data.norm(2).item() ** 2 + return total_norm ** 0.5 + + def compute_gradient_singular_values( + self, + model: nn.Module, + top_k: int = 5, + ) -> Dict[str, Tuple[np.ndarray, float]]: + """ + Compute singular values of gradient matrices. + + Per Gradient Flossing paper: singular value spectrum reveals + gradient pathologies (vanishing/exploding/rank collapse). 
+ + Args: + model: The SNN model + top_k: Number of top/bottom singular values to return + + Returns: + Dict mapping layer name to (singular_values, condition_number) + """ + results = {} + for name, param in model.named_parameters(): + if param.grad is not None and param.ndim == 2: + # Only compute for weight matrices (2D) + with torch.no_grad(): + G = param.grad.detach().cpu() + try: + # Full SVD is expensive; use truncated for large matrices + if G.shape[0] * G.shape[1] > 1e6: + # For very large matrices, just compute extremes + U, S, V = torch.svd_lowrank(G, q=min(top_k, min(G.shape))) + sv = S.numpy() + else: + sv = torch.linalg.svdvals(G).numpy() + + max_sv = sv[0] if len(sv) > 0 else 0 + min_sv = sv[-1] if len(sv) > 0 else 0 + condition = max_sv / (min_sv + 1e-12) + + results[name] = (sv[:top_k], condition) + except Exception: + pass # Skip if SVD fails + return results + + def get_aggregate_gradient_sv(self, model: nn.Module) -> Tuple[float, float, float]: + """ + Get aggregate gradient singular value statistics. + + Returns: + (max_sv, min_sv, avg_condition_number) across all layers + """ + sv_results = self.compute_gradient_singular_values(model) + if not sv_results: + return 0.0, 0.0, 1.0 + + max_svs = [] + min_svs = [] + conditions = [] + + for name, (sv, cond) in sv_results.items(): + if len(sv) > 0: + max_svs.append(sv[0]) + min_svs.append(sv[-1]) + conditions.append(cond) + + if not max_svs: + return 0.0, 0.0, 1.0 + + return ( + float(np.max(max_svs)), + float(np.min(min_svs)), + float(np.mean(conditions)) + ) + + def compute_firing_stats( + self, + spikes: torch.Tensor, + ) -> Tuple[float, float, float, float]: + """ + Compute firing rate statistics. + + Args: + spikes: Spike tensor, shape (B, T, H) or (T, H) or (B, H) + Values should be 0/1. 
+ + Returns: + (mean_rate, std_rate, dead_frac, saturated_frac) + """ + with torch.no_grad(): + # Flatten to (num_samples, num_neurons) if needed + if spikes.ndim == 3: + # (B, T, H) -> compute rate per neuron per sample + rates = spikes.float().mean(dim=1) # (B, H) + elif spikes.ndim == 2: + # Could be (T, H) or (B, H) - assume (T, H) for single sample + rates = spikes.float().mean(dim=0, keepdim=True) # (1, H) + else: + rates = spikes.float().unsqueeze(0) + + # Per-neuron average rate across batch + neuron_rates = rates.mean(dim=0) # (H,) + + mean_rate = neuron_rates.mean().item() + std_rate = neuron_rates.std().item() + + dead_frac = (neuron_rates < self.dead_threshold).float().mean().item() + saturated_frac = (neuron_rates > self.saturated_threshold).float().mean().item() + + return mean_rate, std_rate, dead_frac, saturated_frac + + def compute_membrane_stats( + self, + membrane: torch.Tensor, + ) -> Tuple[float, float]: + """ + Compute membrane potential statistics. + + Args: + membrane: Membrane potential tensor, any shape + + Returns: + (mean, std) + """ + with torch.no_grad(): + return membrane.mean().item(), membrane.std().item() + + def compute( + self, + model: nn.Module, + lyapunov: Optional[torch.Tensor] = None, + spikes: Optional[torch.Tensor] = None, + membrane: Optional[torch.Tensor] = None, + compute_sv: bool = False, + ) -> StabilityMetrics: + """ + Compute all available stability metrics. 
+ + Args: + model: The SNN model (for gradient norms) + lyapunov: Lyapunov exponent from forward pass + spikes: Recorded spikes (optional) + membrane: Recorded membrane potentials (optional) + compute_sv: Whether to compute gradient singular values (expensive) + + Returns: + StabilityMetrics object + """ + metrics = StabilityMetrics() + + # Lyapunov exponent + if lyapunov is not None: + if isinstance(lyapunov, torch.Tensor): + lyapunov = lyapunov.item() + metrics.lyapunov = lyapunov + self.lyap_history.append(lyapunov) + if len(self.lyap_history) > self.history_size: + self.lyap_history.pop(0) + + # Gradient norm + grad_norm = self.compute_gradient_norm(model) + metrics.grad_norm = grad_norm + self.grad_history.append(grad_norm) + if len(self.grad_history) > self.history_size: + self.grad_history.pop(0) + + # Gradient singular values (optional, expensive) + if compute_sv: + max_sv, min_sv, avg_cond = self.get_aggregate_gradient_sv(model) + metrics.grad_max_sv = max_sv + metrics.grad_min_sv = min_sv + metrics.grad_condition = avg_cond + + # Firing rate statistics + if spikes is not None: + fr_mean, fr_std, dead_frac, sat_frac = self.compute_firing_stats(spikes) + metrics.firing_rate_mean = fr_mean + metrics.firing_rate_std = fr_std + metrics.dead_neuron_frac = dead_frac + metrics.saturated_neuron_frac = sat_frac + self.fr_history.append(fr_mean) + if len(self.fr_history) > self.history_size: + self.fr_history.pop(0) + + # Membrane potential statistics + if membrane is not None: + mem_mean, mem_std = self.compute_membrane_stats(membrane) + metrics.membrane_mean = mem_mean + metrics.membrane_std = mem_std + + return metrics + + def get_trends(self) -> Dict[str, str]: + """ + Analyze trends in stability metrics. + + Returns: + Dictionary with trend analysis for each metric. 
+ """ + trends = {} + + if len(self.lyap_history) >= 10: + recent = np.mean(self.lyap_history[-10:]) + older = np.mean(self.lyap_history[:10]) + if recent > older + 0.1: + trends["lyapunov"] = "⚠️ INCREASING (becoming unstable)" + elif recent < older - 0.1: + trends["lyapunov"] = "✓ DECREASING (stabilizing)" + else: + trends["lyapunov"] = "→ STABLE" + + if len(self.grad_history) >= 10: + recent = np.mean(self.grad_history[-10:]) + older = np.mean(self.grad_history[:10]) + ratio = recent / (older + 1e-8) + if ratio > 10: + trends["gradients"] = "⚠️ EXPLODING" + elif ratio < 0.1: + trends["gradients"] = "⚠️ VANISHING" + else: + trends["gradients"] = "✓ STABLE" + + if len(self.fr_history) >= 10: + recent = np.mean(self.fr_history[-10:]) + if recent < 0.01: + trends["firing"] = "⚠️ DEAD NETWORK" + elif recent > 0.8: + trends["firing"] = "⚠️ SATURATED" + else: + trends["firing"] = "✓ HEALTHY" + + return trends + + def diagnose(self) -> str: + """Generate a diagnostic summary.""" + trends = self.get_trends() + + lines = ["=== Stability Diagnosis ==="] + + if self.lyap_history: + avg_lyap = np.mean(self.lyap_history[-20:]) + lines.append(f"Lyapunov exponent: {avg_lyap:.4f}") + if avg_lyap > 0.5: + lines.append(" → Network is CHAOTIC (trajectories diverge quickly)") + lines.append(" → Suggestion: Increase lambda_reg or decrease learning rate") + elif avg_lyap < -0.5: + lines.append(" → Network is OVER-STABLE (trajectories collapse)") + lines.append(" → Suggestion: May lose expressiveness, consider reducing regularization") + else: + lines.append(" → Network is at EDGE OF CHAOS (good for learning)") + + if self.grad_history: + avg_grad = np.mean(self.grad_history[-20:]) + max_grad = max(self.grad_history[-20:]) + lines.append(f"Gradient norm: avg={avg_grad:.4f}, max={max_grad:.4f}") + if "gradients" in trends: + lines.append(f" → {trends['gradients']}") + + if self.fr_history: + avg_fr = np.mean(self.fr_history[-20:]) + lines.append(f"Firing rate: {avg_fr:.4f}") + if 
"firing" in trends: + lines.append(f" → {trends['firing']}") + + return "\n".join(lines) + + +def compute_spectral_radius(weight_matrix: torch.Tensor) -> float: + """ + Compute spectral radius of a weight matrix. + + For recurrent networks, spectral radius > 1 indicates potential instability. + + Args: + weight_matrix: 2D weight tensor + + Returns: + Spectral radius (largest absolute eigenvalue) + """ + with torch.no_grad(): + W = weight_matrix.detach().cpu().numpy() + eigenvalues = np.linalg.eigvals(W) + return float(np.max(np.abs(eigenvalues))) + + +def analyze_weight_spectrum(model: nn.Module) -> Dict[str, float]: + """ + Analyze spectral properties of all weight matrices. + + Returns: + Dictionary mapping layer names to spectral radii. + """ + results = {} + for name, param in model.named_parameters(): + if "weight" in name and param.ndim == 2: + if param.shape[0] == param.shape[1]: # Square matrix (recurrent) + results[name] = compute_spectral_radius(param) + return results diff --git a/files/data_io/__init__.py b/files/data_io/__init__.py new file mode 100644 index 0000000..de423db --- /dev/null +++ b/files/data_io/__init__.py @@ -0,0 +1,3 @@ +""" +data_io package: unified data loading, encoding, and preprocessing for spiking neural networks. +""" diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py new file mode 100644 index 0000000..302ed51 --- /dev/null +++ b/files/data_io/benchmark_datasets.py @@ -0,0 +1,360 @@ +""" +Challenging benchmark datasets for deep SNN evaluation. + +Datasets: +1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps +2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order +3. CIFAR-10 with rate coding +4. DVS-CIFAR10 (requires tonic library) + +These benchmarks are harder than SHD and benefit from deeper networks. 
+""" + +import os +from typing import Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader + + +class SequentialMNIST(Dataset): + """ + Sequential MNIST - feed pixels one at a time. + + Each 28x28 image becomes a sequence of 784 timesteps, + each with a single pixel intensity converted to spike probability. + + This is MUCH harder than standard MNIST because: + - Network must remember information across 784 timesteps + - Tests long-range temporal dependencies + - Shallow networks fail due to vanishing gradients + + Args: + root: Data directory + train: Train or test split + permute: If True, use fixed random permutation (psMNIST) + spike_encoding: 'rate' (Poisson) or 'latency' or 'direct' (raw intensity) + max_rate: Maximum firing rate for rate coding + seed: Random seed for permutation + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + permute: bool = False, + spike_encoding: str = "rate", + max_rate: float = 100.0, + n_repeat: int = 1, # Repeat each pixel n times for more spikes + seed: int = 42, + download: bool = True, + ): + try: + from torchvision import datasets, transforms + except ImportError: + raise ImportError("torchvision required: pip install torchvision") + + self.train = train + self.permute = permute + self.spike_encoding = spike_encoding + self.max_rate = max_rate + self.n_repeat = n_repeat + + # Load MNIST + self.mnist = datasets.MNIST( + root=root, + train=train, + download=download, + transform=transforms.ToTensor(), + ) + + # Create fixed permutation for psMNIST + if permute: + rng = np.random.RandomState(seed) + self.perm = torch.from_numpy(rng.permutation(784)) + else: + self.perm = None + + def __len__(self): + return len(self.mnist) + + def __getitem__(self, idx): + img, label = self.mnist[idx] + + # Flatten to (784,) + pixels = img.view(-1) + + # Apply permutation + if self.perm is not None: + pixels = pixels[self.perm] + + # Convert to spike sequence (T, 1) 
where T = 784 * n_repeat + T = 784 * self.n_repeat + + if self.spike_encoding == "direct": + # Direct intensity: repeat each pixel n_repeat times + spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1) + + elif self.spike_encoding == "rate": + # Rate coding: Poisson spikes based on intensity + probs = pixels * (self.max_rate / 1000.0) # Assuming 1ms bins + probs = probs.clamp(0, 1) + # Repeat and sample + probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1) + spikes = torch.bernoulli(probs_expanded) + + elif self.spike_encoding == "latency": + # Latency coding: spike time proportional to intensity + # High intensity = early spike, low = late spike + spikes = torch.zeros(T, 1) + for i, p in enumerate(pixels): + if p > 0.1: # Threshold for spiking + # Spike time: higher intensity = earlier + spike_time = int((1 - p) * (self.n_repeat - 1)) + t = i * self.n_repeat + spike_time + spikes[t, 0] = 1.0 + + return spikes, label + + +class RateCodingCIFAR10(Dataset): + """ + CIFAR-10 with rate coding for SNNs. 
+ + Converts 32x32x3 images to spike trains: + - Each pixel channel becomes a Poisson spike train + - Total input dimension: 32*32*3 = 3072 + - Sequence length: T timesteps + + Args: + root: Data directory + train: Train or test split + T: Number of timesteps + max_rate: Maximum firing rate (Hz) + flatten: If True, flatten spatial dimensions + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + T: int = 100, + max_rate: float = 200.0, + flatten: bool = True, + download: bool = True, + ): + try: + from torchvision import datasets, transforms + except ImportError: + raise ImportError("torchvision required: pip install torchvision") + + self.T = T + self.max_rate = max_rate + self.flatten = flatten + + # Normalize to [0, 1] + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + self.cifar = datasets.CIFAR10( + root=root, + train=train, + download=download, + transform=transform, + ) + + def __len__(self): + return len(self.cifar) + + def __getitem__(self, idx): + img, label = self.cifar[idx] # (3, 32, 32) + + if self.flatten: + img = img.view(-1) # (3072,) + + # Rate coding + prob_per_step = img * (self.max_rate / 1000.0) # Assuming 1ms steps + prob_per_step = prob_per_step.clamp(0, 1) + + # Generate spikes for T timesteps + if self.flatten: + probs = prob_per_step.unsqueeze(0).expand(self.T, -1) # (T, 3072) + else: + probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1) # (T, 3, 32, 32) + + spikes = torch.bernoulli(probs) + + return spikes, label + + +class DVSCIFAR10(Dataset): + """ + DVS-CIFAR10 dataset wrapper. + + Requires the 'tonic' library for neuromorphic datasets: + pip install tonic + + DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching + CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark. 
+ + Args: + root: Data directory + train: Train or test split + T: Number of time bins + spatial_downsample: Downsample spatial resolution + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + T: int = 100, + dt_ms: float = 10.0, + download: bool = True, + ): + try: + import tonic + from tonic import transforms as tonic_transforms + except ImportError: + raise ImportError( + "tonic library required for DVS datasets: pip install tonic" + ) + + self.T = T + + # Time binning transform + sensor_size = tonic.datasets.CIFAR10DVS.sensor_size + frame_transform = tonic_transforms.ToFrame( + sensor_size=sensor_size, + time_window=dt_ms * 1000, # Convert to microseconds + ) + + self.dataset = tonic.datasets.CIFAR10DVS( + save_to=root, + train=train, + transform=frame_transform, + ) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + frames, label = self.dataset[idx] # (T', 2, H, W) - 2 polarities + + # Convert to tensor and flatten spatial dims + frames = torch.from_numpy(frames).float() + + # Adjust to target T + T_actual = frames.shape[0] + if T_actual > self.T: + # Subsample + indices = torch.linspace(0, T_actual - 1, self.T).long() + frames = frames[indices] + elif T_actual < self.T: + # Pad with zeros + pad = torch.zeros(self.T - T_actual, *frames.shape[1:]) + frames = torch.cat([frames, pad], dim=0) + + # Flatten: (T, 2, H, W) -> (T, 2*H*W) + frames = frames.view(self.T, -1) + + return frames, label + + +def get_benchmark_dataloader( + dataset_name: str, + batch_size: int = 64, + root: str = "./data", + **kwargs, +) -> Tuple[DataLoader, DataLoader, dict]: + """ + Get train and validation dataloaders for a benchmark dataset. 
+ + Args: + dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10' + batch_size: Batch size + root: Data directory + **kwargs: Additional arguments passed to dataset + + Returns: + train_loader, val_loader, info_dict + """ + + if dataset_name == "smnist": + train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs) + val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs) + info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10, + "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"} + + elif dataset_name == "psmnist": + train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs) + val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs) + info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10, + "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"} + + elif dataset_name == "cifar10": + T = kwargs.pop("T", 100) + train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs) + val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs) + info = {"T": T, "D": 3072, "classes": 10, + "description": "CIFAR-10 with rate coding"} + + elif dataset_name == "dvs_cifar10": + train_ds = DVSCIFAR10(root, train=True, **kwargs) + val_ds = DVSCIFAR10(root, train=False, **kwargs) + info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10, + "description": "DVS-CIFAR10 neuromorphic dataset"} + + else: + raise ValueError(f"Unknown dataset: {dataset_name}. " + f"Options: smnist, psmnist, cifar10, dvs_cifar10") + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_loader, val_loader, info + + +# Quick test +if __name__ == "__main__": + print("Testing benchmark datasets...\n") + + # Test sMNIST + print("1. 
Sequential MNIST") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "smnist", batch_size=32, n_repeat=1, spike_encoding="direct" + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + # Test psMNIST + print("\n2. Permuted Sequential MNIST") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct" + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + # Test CIFAR-10 + print("\n3. CIFAR-10 (rate coded)") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "cifar10", batch_size=32, T=50 + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + print("\nDone!") diff --git a/files/data_io/configs/__init__.py b/files/data_io/configs/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/files/data_io/configs/__init__.py diff --git a/files/data_io/configs/dvs.yaml b/files/data_io/configs/dvs.yaml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/files/data_io/configs/dvs.yaml diff --git a/files/data_io/configs/shd.yaml b/files/data_io/configs/shd.yaml new file mode 100644 index 0000000..c786934 --- /dev/null +++ b/files/data_io/configs/shd.yaml @@ -0,0 +1,11 @@ +dataset: SHD +data_dir: /u/yurenh2/ml-projects/snn-training/files/data +shuffle: true +encoder: + type: poisson + max_rate: 50 + dt_ms: 1.0 + seed: 42 # 也可不写,默认 42 +transforms: + normalize: true + spike_jitter: 0.01 diff --git a/files/data_io/configs/ssc.yaml b/files/data_io/configs/ssc.yaml new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/files/data_io/configs/ssc.yaml diff --git a/files/data_io/dataset_loader.py 
b/files/data_io/dataset_loader.py new file mode 100644 index 0000000..983b67b --- /dev/null +++ b/files/data_io/dataset_loader.py @@ -0,0 +1,281 @@ +from typing import Optional +import yaml +import torch +from torch.utils.data import Dataset, DataLoader +import os +import h5py +import numpy as np +from .encoders import poisson_encoder, latency_encoder, rank_order_encoder +from .transforms import normalization, spike_augmentation +# from .utils import file_utils # 目前未使用可先注释 + + + +# ----- +# Base synthetic dataset placeholder +# ----- +class BaseDataset(Dataset): + """ + Abstract base class for datasets. + Each subclass must implement __getitem__ and __len__. + + This base implementation uses synthetic placeholders for quick smoke tests. + """ + + def __init__(self, data_dir, encoder, transforms=None): + self.data_dir = data_dir + self.encoder = encoder + self.transforms = transforms or [] + self.samples = list(range(200)) # placeholder synthetic samples + + def __getitem__(self, idx): + # synthetic static input -> encode to (T, D) via encoder + raw_data = torch.rand(128) # 128-dim intensity + label = torch.randint(0, 10, (1,)).item() + encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data + x = torch.as_tensor(encoded) + for t in self.transforms: + x = t(x) + return x, label + + def __len__(self): + return len(self.samples) + + +# ----- +# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D +# ----- +import h5py # noqa: E402 + + +class SHDDataset(Dataset): + """ + SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T. 
+ + H5 structure (Zenke Lab convention): + - f["labels"][i] -> scalar label for sample i + - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i + - f["spikes"]["units"][i] -> 1D array of channel ids for sample i + + We: + (1) scan once to determine global T = ceil(max_time / dt_ms) + (2) decide D from max unit id (fallback to default_D=700) + (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D) + """ + + def __init__( + self, + data_dir: str, + encoder=None, # ignored for SHD (already spiking events) + transforms=None, + split: str = "train", + dt_ms: float = 1.0, + seed: Optional[int] = None, + default_D: int = 700 + ): + super().__init__() + self.data_dir = data_dir + self.transforms = transforms or [] + self.dt_ms = float(dt_ms) + self.seed = 42 if seed is None else int(seed) + self.encoder = None # IMPORTANT: do not apply intensity encoders to event data + self.default_D = int(default_D) + + fname = f"shd_{split}.h5" + self.path = os.path.join(self.data_dir, fname) + if not os.path.exists(self.path): + raise FileNotFoundError(f"SHD file not found: {self.path}") + + with h5py.File(self.path, "r") as f: + # labels is dense array + self.labels = np.array(f["labels"], dtype=np.int64) + self.N = int(self.labels.shape[0]) + + # ragged datasets for events + times_ds = f["spikes"]["times"] + units_ds = f["spikes"]["units"] + + # scan once to compute global T and adaptive D + t_max_global = 0.0 + max_unit = -1 + for i in range(self.N): + ti = times_ds[i] + ui = units_ds[i] + if ti.size > 0: + last_t = float(ti[-1]) # ms + if last_t > t_max_global: + t_max_global = last_t + if ui.size > 0: + uimax = int(ui.max()) + if uimax > max_unit: + max_unit = uimax + + # decide D + if max_unit >= 0: + self.D = max(max_unit + 1, self.default_D) + else: + self.D = self.default_D + + # decide T from global max time + self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1 + + # rng in case transforms need it + 
self._rng = np.random.default_rng(self.seed) + + def __len__(self): + return self.N + + def __getitem__(self, idx: int): + # open file per-sample for worker safety + with h5py.File(self.path, "r") as f: + ti = f["spikes"]["times"][idx][:] + ui = f["spikes"]["units"][idx][:] + y = int(f["labels"][idx]) + + # bin events to (T, D) + spikes = np.zeros((self.T, self.D), dtype=np.float32) + if ti.size > 0: + bins = (ti / self.dt_ms).astype(np.int64) + bins = np.clip(bins, 0, self.T - 1) + ui = np.clip(ui.astype(np.int64), 0, self.D - 1) + spikes[bins, ui] = 1.0 # presence; if you prefer counts: +=1 then clip to 1 + + x = torch.from_numpy(spikes) + + # apply transforms (on torch tensor) + for tr in self.transforms: + x = tr(x) + + return x, y + + +# ----- +# SSC / DVS placeholders (still synthetic; implement real readers later) +# ----- +class SSCDataset(BaseDataset): + """Placeholder SSC dataset (synthetic).""" + pass + + +class DVSDataset(BaseDataset): + """Placeholder DVS dataset (synthetic).""" + pass + + +# ----- +# Helpers: encoders / transforms / cfg path resolution +# ----- +def build_encoder(cfg): + """ + Build encoder from config dict. 
def build_encoder(cfg):
    """
    Build an encoder from a config dict.

    Expected schema (the numbers are example values, not defaults):
      encoder:
        type: poisson | latency | rank_order
        # Poisson-only optional fields:
        max_rate: 50
        T: 64
        dt_ms: 1.0
        seed: 123

    For ``poisson`` the fallbacks used here are ``max_rate`` <- ``rate`` <- 20,
    ``T`` = 50, ``dt_ms`` = 1.0 and ``seed`` = None.

    Raises
    ------
    KeyError
        If ``cfg`` has no ``"type"`` entry.
    ValueError
        If the (lower-cased) type is not one of the three known names.
    """
    etype = cfg["type"].lower()
    if etype == "poisson":
        return poisson_encoder.PoissonEncoder(
            max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
            T=cfg.get("T", 50),
            dt_ms=cfg.get("dt_ms", 1.0),
            seed=cfg.get("seed", None),
        )
    if etype == "latency":
        return latency_encoder.LatencyEncoder()
    if etype == "rank_order":
        return rank_order_encoder.RankOrderEncoder()
    raise ValueError(f"Unknown encoder type: {etype}")


def build_transforms(cfg):
    """
    Build the list of transform callables described by ``cfg``.

    ``cfg`` may be ``None`` or empty (a YAML ``transforms:`` key with no body
    parses to ``None``); both yield an empty list.  ``normalize: true`` appends
    a default ``Normalize()``; a non-null ``spike_jitter`` value appends
    ``SpikeJitter(std=<value>)``.
    """
    cfg = cfg or {}
    tlist = []
    if cfg.get("normalize", False):
        tlist.append(normalization.Normalize())
    if cfg.get("spike_jitter", None) is not None:
        tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
    return tlist


def _resolve_cfg_path(cfg_path: str) -> str:
    """
    Resolve cfg_path against, in order:
      1) as-is (absolute or CWD-relative)
      2) relative to this package directory
      3) <pkg_dir>/configs/<basename>

    Returns the first candidate that exists; raises FileNotFoundError listing
    every candidate tried otherwise.
    """
    # os.path.exists covers absolute and CWD-relative paths alike, so no
    # separate isabs() branch is needed.
    if os.path.exists(cfg_path):
        return cfg_path
    pkg_dir = os.path.dirname(__file__)
    cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
    if os.path.exists(cand2):
        return cand2
    cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
    if os.path.exists(cand3):
        return cand3
    raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")


# -----
# Entry: get_dataloader
# -----
def get_dataloader(cfg_path):
    """
    Create train/val DataLoader from YAML config.
    Handles SHD as true event dataset (encoder=None), others as synthetic placeholders.

    Config keys read here: ``dataset`` and ``data_dir`` (required),
    ``transforms``, ``encoder``, ``batch_size`` (16), ``shuffle`` (True, train
    loader only), ``num_workers`` (0), ``pin_memory`` (False).

    Returns
    -------
    (train_loader, val_loader)

    Raises
    ------
    FileNotFoundError
        If the config path cannot be resolved.
    ValueError
        If the config is not a mapping or names an unknown dataset.
    """
    cfg_path_resolved = _resolve_cfg_path(cfg_path)
    with open(cfg_path_resolved, "r") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        # An empty or scalar-only YAML file parses to None / a scalar.
        raise ValueError(f"Config must be a YAML mapping: {cfg_path_resolved}")

    dataset_name = cfg["dataset"].lower()
    data_dir = cfg["data_dir"]
    # `key:` with no body parses to None, so `.get(key, {})` is not enough.
    transforms = build_transforms(cfg.get("transforms") or {})

    if dataset_name == "shd":
        # event dataset: do NOT use intensity encoders here
        enc_cfg = cfg.get("encoder") or {}
        dt_ms = enc_cfg.get("dt_ms", 1.0)
        seed = enc_cfg.get("seed", 42)
        ds_train = SHDDataset(
            data_dir, encoder=None, transforms=transforms, split="train",
            dt_ms=dt_ms, seed=seed
        )
        # the "test" split doubles as the validation set
        ds_val = SHDDataset(
            data_dir, encoder=None, transforms=transforms, split="test",
            dt_ms=dt_ms, seed=seed
        )
    elif dataset_name == "ssc":
        # placeholder path; later implement true SSC reader
        encoder = build_encoder(cfg["encoder"])
        ds_train = SSCDataset(data_dir, encoder, transforms)
        ds_val = SSCDataset(data_dir, encoder, transforms)
    elif dataset_name == "dvs":
        encoder = build_encoder(cfg["encoder"])
        ds_train = DVSDataset(data_dir, encoder, transforms)
        ds_val = DVSDataset(data_dir, encoder, transforms)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Settings shared by both loaders; only `shuffle` differs.
    loader_kwargs = dict(
        batch_size=cfg.get("batch_size", 16),
        num_workers=cfg.get("num_workers", 0),
        pin_memory=cfg.get("pin_memory", False),
    )
    train_loader = DataLoader(ds_train, shuffle=cfg.get("shuffle", True), **loader_kwargs)
    val_loader = DataLoader(ds_val, shuffle=False, **loader_kwargs)
    return train_loader, val_loader
# ---- files/data_io/encoders/__init__.py ----
"""Encoder submodule: provides various spike encoding strategies."""

# ---- files/data_io/encoders/base_encoder.py ----
import numpy as np
import torch


class BaseEncoder:
    """Abstract base class for all encoders."""

    def encode(self, data: np.ndarray) -> torch.Tensor:
        """
        Convert static data (e.g., image, waveform) into spike tensor (T, input_dim).
        """
        raise NotImplementedError


# ---- files/data_io/encoders/latency_encoder.py ----
# (in the package layout BaseEncoder is imported from .base_encoder)
class LatencyEncoder(BaseEncoder):
    """Encode input intensity into spike latency (placeholder implementation)."""

    def __init__(self):
        super().__init__()

    def encode(self, data: np.ndarray) -> torch.Tensor:
        """Return an all-zero (10, n) tensor, n being the element count of `data`."""
        # TODO: map value→time delay
        # `Tensor.size` is a method, not an int, so count elements explicitly
        # for tensors; arrays keep using their `.size`.
        n = data.numel() if isinstance(data, torch.Tensor) else int(np.asarray(data).size)
        spikes = torch.zeros(10, n)  # placeholder
        return spikes


# ---- files/data_io/encoders/poisson_encoder.py ----
from typing import Optional, Union


class PoissonEncoder(BaseEncoder):
    """
    PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.

    Given a static input vector x in [0,1]^D, we produce a spike tensor S in {0,1}^(T x D)
    by sampling at each time step t and dimension d:
        S[t, d] ~ Bernoulli(p[d]),  p[d] = clip(x[d] * max_rate * (dt_ms / 1000), 0, 1).

    Parameters
    ----------
    max_rate : float
        Maximum firing rate under unit intensity (Hz). Default: 20.
    T : int
        Number of discrete time steps in the encoded spike train. Default: 50.
    dt_ms : float
        Time resolution per step in milliseconds. Default: 1.0 ms.
        The factor (dt_ms/1000) converts Hz to a per-step probability.
    seed : int or None
        Optional RNG seed for reproducibility. If None, sampling draws from
        torch's global RNG state.

    Notes
    -----
    - Input `data` is a 1D array (shape [D]) or 2D array ([B, D]);
      the result has shape (T, D) or (T, B, D) respectively.
    - Intensities outside [0,1] are clipped to [0,1].
    - The returned tensor is created on the CPU.
    """

    def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
        super().__init__()
        self.max_rate = float(max_rate)
        self.T = int(T)
        self.dt_ms = float(dt_ms)
        self.seed = seed
        # A fresh torch.Generator() always starts from the same built-in seed,
        # which would make every unseeded encoder emit an identical stream.
        # Keep a private generator only when a seed is given; None makes
        # torch.bernoulli fall back to the global RNG.
        self._g: Optional[torch.Generator] = None
        if seed is not None:
            self._g = torch.Generator()
            self._g.manual_seed(int(seed))

    def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Return `data` as a NumPy array (tensors are detached and moved to CPU)."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Convert input intensities to Poisson spike trains.

        Parameters
        ----------
        data : np.ndarray or torch.Tensor
            Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).

        Returns
        -------
        spikes : torch.Tensor
            Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.

        Raises
        ------
        ValueError
            If `data` is not 1- or 2-dimensional.
        """
        x = np.clip(self._ensure_numpy(data), 0.0, 1.0)

        # probability of a spike per step for unit intensity
        p_step = self.max_rate * (self.dt_ms / 1000.0)
        probs = np.clip(x * p_step, 0.0, 1.0)

        if probs.ndim not in (1, 2):
            raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}")

        # Prepend the time axis: (D,) -> (T, D), (B, D) -> (T, B, D).
        # astype() materialises the read-only broadcast view into a real array.
        probs_t = np.broadcast_to(probs, (self.T,) + probs.shape).astype(np.float32)
        return torch.bernoulli(torch.from_numpy(probs_t), generator=self._g)


# ---- files/data_io/encoders/rank_order_encoder.py ----
class RankOrderEncoder(BaseEncoder):
    """Encode by rank order of input features (placeholder implementation)."""

    def encode(self, data: np.ndarray) -> torch.Tensor:
        """Return an all-zero (10, n) tensor, n being the element count of `data`."""
        # TODO: implement rank order conversion
        n = data.numel() if isinstance(data, torch.Tensor) else int(np.asarray(data).size)
        spikes = torch.zeros(10, n)  # placeholder
        return spikes


# ---- files/data_io/transforms/__init__.py ----
"""Transforms for spike data (normalization, augmentation, etc.)."""
# ---- files/data_io/transforms/normalization.py ----
import torch


class Normalize:
    """Normalize spike input (z-score or min-max)."""

    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
        """
        Parameters
        ----------
        mode : str
            One of {"zscore", "minmax"} (case-insensitive).
            - "zscore": per-sample, per-channel standardization across time.
            - "minmax": per-sample, per-channel min-max scaling across time.
        eps : float
            Small constant to avoid division by zero.

        Raises
        ------
        ValueError
            If `mode` is neither "zscore" nor "minmax".
        """
        mode_l = str(mode).lower()
        if mode_l not in {"zscore", "minmax"}:
            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
        self.mode = mode_l
        self.eps = float(eps)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply normalization.
        Accepts tensors of shape (T, D) or (B, T, D).
        Normalization is computed per-sample and per-channel over the time axis T.

        Raises TypeError for non-tensor input and ValueError for any other rank.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Normalize expects a torch.Tensor")
        if x.ndim == 2:
            time_dim = 0  # (T, D)
        elif x.ndim == 3:
            time_dim = 1  # (B, T, D)
        else:
            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")

        if self.mode == "zscore":
            mean_t = x.mean(dim=time_dim, keepdim=True)
            # population std (unbiased=False) to avoid NaNs for small T
            std_t = x.std(dim=time_dim, keepdim=True, unbiased=False)
            return (x - mean_t) / (std_t + self.eps)

        # "minmax"
        min_t = x.amin(dim=time_dim, keepdim=True)
        max_t = x.amax(dim=time_dim, keepdim=True)
        denom = (max_t - min_t).clamp_min(self.eps)
        return (x - min_t) / denom


# ---- files/data_io/transforms/spike_augmentation.py ----
class SpikeJitter:
    """Add temporal jitter noise to spikes (placeholder: currently a no-op)."""

    def __init__(self, std=0.01):
        self.std = std

    def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
        # TODO: add random jitter to spike timings
        return spikes


# ---- files/data_io/utils/__init__.py ----
"""Utility functions for file management, spike tools, and visualization."""

# ---- files/data_io/utils/file_utils.py ----
import os


def ensure_dir(path: str):
    """
    Ensure that a directory exists, creating missing parents as needed.

    `exist_ok=True` makes this safe when another process creates the directory
    between a check and the creation.  If `path` exists but is not a
    directory, FileExistsError is raised.
    """
    os.makedirs(path, exist_ok=True)


def list_files(root: str, suffix: str):
    """Recursively list files under `root` whose name ends with `suffix`."""
    return [
        os.path.join(dirpath, fname)
        for dirpath, _, filenames in os.walk(root)
        for fname in filenames
        if fname.endswith(suffix)
    ]


# ---- files/data_io/utils/spike_tools.py ----
import numpy as np


def to_raster(spikes: torch.Tensor) -> np.ndarray:
    """Convert spike tensor (T,B,N) to raster array (T,N) by averaging over axis 1."""
    return spikes.detach().cpu().numpy().mean(axis=1)


def firing_rate(spikes: torch.Tensor, dt=1.0):
    """Compute firing rate per neuron: spikes summed over dim 0, divided by (T * dt)."""
    return spikes.sum(dim=0) / (spikes.shape[0] * dt)


# ---- files/data_io/utils/visualize.py ----
def plot_raster(spikes: torch.Tensor, title=None):
    """
    Plot raster diagram of spike activity (T,B,N) or (T,N).

    For 3-D input only the first batch element is drawn.  Every non-zero entry
    becomes one black dot at (time step, neuron index).  The figure is shown
    with `plt.show()`.

    Raises
    ------
    ValueError
        If the input is not 2-D after the optional batch selection.
    """
    # Imported lazily so the non-plotting utilities stay importable on
    # machines without matplotlib.
    import matplotlib.pyplot as plt

    s = spikes.detach().cpu()
    if s.ndim == 3:
        s = s[:, 0, :]  # take first batch
    if s.ndim != 2:
        raise ValueError(f"plot_raster expects (T,N) or (T,B,N) spikes, got shape {tuple(s.shape)}")

    # One vectorised lookup of all spike coordinates replaces the per-neuron
    # loop; that loop called np.ones_like without numpy being imported and
    # produced a 0-d array for neurons with exactly one spike.
    t_idx, n_idx = torch.nonzero(s, as_tuple=True)
    plt.scatter(t_idx.numpy(), n_idx.numpy(), s=2, c='black')
    plt.xlabel("Time step")
    plt.ylabel("Neuron index")
    if title:
        plt.title(title)
    plt.show()


# ---- files/experiments/benchmark_experiment.py (new file) starts here in the patch.
# Its module docstring opens with:
#   Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.
# ---- files/experiments/benchmark_experiment.py ----
"""
Benchmark Experiment: Compare Vanilla vs Lyapunov-Regularized SNN on real datasets.

Datasets:
- Sequential MNIST (sMNIST): 784 timesteps, very hard for deep networks
- Permuted Sequential MNIST (psMNIST): Even harder, tests long-range memory
- CIFAR-10: Rate-coded images, requires hierarchical features

Usage:
    python files/experiments/benchmark_experiment.py --dataset smnist --depths 2 4 6 8
    python files/experiments/benchmark_experiment.py --dataset cifar10 --depths 4 6 8 10
"""

import os
import sys
import json
import time
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple

# Make the repository root importable so `files.*` resolves when this file is
# run directly as a script.
_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from files.models.snn_snntorch import LyapunovSNN
from files.data_io.benchmark_datasets import get_benchmark_dataloader


@dataclass
class EpochMetrics:
    """Per-epoch record; Optional fields are None when the value was not measured."""
    epoch: int
    train_loss: float
    train_acc: float
    val_loss: float
    val_acc: float
    lyapunov: Optional[float]
    grad_norm: float
    grad_max_sv: Optional[float]
    grad_min_sv: Optional[float]
    grad_condition: Optional[float]
    time_sec: float


def compute_gradient_svs(model):
    """
    Compute gradient singular value statistics.

    Looks at the gradient of every 2-D parameter and returns
    ``(max_sv, min_sv, condition)``: the largest singular value seen in any
    gradient, the smallest one seen in any gradient, and their ratio (with
    1e-12 added to the denominator).  Returns ``(None, None, None)`` when no
    2-D parameter has a usable gradient.
    """
    max_svs = []
    min_svs = []

    for _name, param in model.named_parameters():
        if param.grad is None or param.ndim != 2:
            continue
        with torch.no_grad():
            try:
                sv = torch.linalg.svdvals(param.grad.detach())
            except RuntimeError:
                # Best-effort diagnostic: skip a matrix whose SVD fails (torch
                # reports linear-algebra failures as RuntimeError subclasses)
                # rather than abort training.
                continue
        if len(sv) > 0:
            max_svs.append(sv[0].item())
            min_svs.append(sv[-1].item())

    if not max_svs:
        return None, None, None

    max_sv = max(max_svs)
    min_sv = min(min_svs)
    condition = max_sv / (min_sv + 1e-12)

    return max_sv, min_sv, condition


def create_model(
    input_dim: int,
    num_classes: int,
    depth: int,
    hidden_dim: int = 128,
    beta: float = 0.9,
) -> LyapunovSNN:
    """Create SNN with `depth` hidden layers of `hidden_dim` units each (threshold 1.0)."""
    hidden_dims = [hidden_dim] * depth
    return LyapunovSNN(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        num_classes=num_classes,
        beta=beta,
        threshold=1.0,
    )


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    ce_loss: nn.Module,
    device: torch.device,
    use_lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    compute_sv_every: int = 10,
) -> Tuple[float, float, Optional[float], float, Optional[float], Optional[float], Optional[float]]:
    """
    Train one epoch.

    Returns ``(avg_loss, avg_acc, avg_lyap, avg_grad_norm, avg_max_sv,
    avg_min_sv, avg_condition)``.  The Lyapunov and singular-value entries are
    None when nothing was recorded.  If the loss becomes NaN the epoch stops
    immediately and ``(nan, 0.0, None, nan, None, None, None)`` is returned.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    lyap_vals = []
    grad_norms = []
    grad_max_svs = []
    grad_min_svs = []
    grad_conditions = []

    for batch_idx, (x, y) in enumerate(loader):
        x, y = x.to(device), y.to(device)

        # Handle different input shapes
        if x.ndim == 2:
            x = x.unsqueeze(-1)  # (B, T) -> (B, T, 1)

        optimizer.zero_grad()

        logits, lyap_est, _ = model(
            x,
            compute_lyapunov=use_lyapunov,
            lyap_eps=lyap_eps,
            record_states=False,
        )

        ce = ce_loss(logits, y)

        if use_lyapunov and lyap_est is not None:
            # squared distance of the estimate from the target exponent
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.item())
        else:
            loss = ce

        if torch.isnan(loss):
            return float('nan'), 0.0, None, float('nan'), None, None, None

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        # NOTE: measured after clipping, so the recorded norm is the clipped one.
        grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
        grad_norms.append(grad_norm)

        # Compute gradient SVs periodically
        if batch_idx % compute_sv_every == 0:
            max_sv, min_sv, cond = compute_gradient_svs(model)
            if max_sv is not None:
                grad_max_svs.append(max_sv)
                grad_min_svs.append(min_sv)
                grad_conditions.append(cond)

        optimizer.step()

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    avg_lyap = np.mean(lyap_vals) if lyap_vals else None
    avg_grad = np.mean(grad_norms)
    avg_max_sv = np.mean(grad_max_svs) if grad_max_svs else None
    avg_min_sv = np.mean(grad_min_svs) if grad_min_svs else None
    avg_cond = np.mean(grad_conditions) if grad_conditions else None

    return avg_loss, avg_acc, avg_lyap, avg_grad, avg_max_sv, avg_min_sv, avg_cond


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    ce_loss: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """Evaluate on validation set; returns (mean loss, accuracy), or (nan, 0.0) on a NaN loss."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        if x.ndim == 2:
            x = x.unsqueeze(-1)

        logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
        loss = ce_loss(logits, y)

        if torch.isnan(loss):
            return float('nan'), 0.0

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

    return total_loss / total_samples, total_correct / total_samples


def run_experiment(
    depth: int,
    use_lyapunov: bool,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    hidden_dim: int,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    device: torch.device,
    seed: int,
    progress: bool = True,
) -> List[EpochMetrics]:
    """
    Run single experiment configuration.

    Seeds torch, builds a fresh model and Adam optimizer, then trains for up to
    `epochs` epochs, stopping early when the training loss becomes NaN.
    Returns one EpochMetrics per epoch that was run.
    """
    torch.manual_seed(seed)

    model = create_model(
        input_dim=input_dim,
        num_classes=num_classes,
        depth=depth,
        hidden_dim=hidden_dim,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    ce_loss = nn.CrossEntropyLoss()

    method = "Lyapunov" if use_lyapunov else "Vanilla"
    metrics_history = []

    iterator = range(1, epochs + 1)
    if progress:
        iterator = tqdm(iterator, desc=f"D={depth} {method}", leave=False)

    for epoch in iterator:
        t0 = time.time()

        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model, train_loader, optimizer, ce_loss, device,
            use_lyapunov, lambda_reg, lambda_target, lyap_eps,
        )

        val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
        dt = time.time() - t0

        metrics = EpochMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            val_loss=val_loss,
            val_acc=val_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            time_sec=dt,
        )
        metrics_history.append(metrics)

        if progress:
            # `is not None`, not truthiness: an exponent of exactly 0.0 is a
            # real measurement and must still be displayed.
            lyap_str = f"λ={lyap:.2f}" if lyap is not None else ""
            iterator.set_postfix({"acc": f"{val_acc:.3f}", "loss": f"{train_loss:.3f}", "lyap": lyap_str})

        if np.isnan(train_loss):
            print(f" Training diverged at epoch {epoch}")
            break

    return metrics_history


def run_depth_comparison(
    dataset_name: str,
    depths: List[int],
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    hidden_dim: int,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    device: torch.device,
    seed: int,
    progress: bool = True,
) -> Dict[str, Dict[int, List[EpochMetrics]]]:
    """
    Run comparison across depths.

    For every depth, trains the vanilla model first and the Lyapunov-regularized
    one second, both with the same seed.  Returns
    ``{"vanilla": {depth: [EpochMetrics, ...]}, "lyapunov": {...}}``.
    """
    results = {"vanilla": {}, "lyapunov": {}}

    for depth in depths:
        print(f"\n{'='*60}")
        print(f"Depth = {depth} layers")
        print(f"{'='*60}")

        for use_lyap in [False, True]:
            method = "lyapunov" if use_lyap else "vanilla"
            print(f"\n Training {method.upper()}...")

            metrics = run_experiment(
                depth=depth,
                use_lyapunov=use_lyap,
                train_loader=train_loader,
                val_loader=val_loader,
                input_dim=input_dim,
                num_classes=num_classes,
                hidden_dim=hidden_dim,
                epochs=epochs,
                lr=lr,
                lambda_reg=lambda_reg,
                lambda_target=lambda_target,
                lyap_eps=lyap_eps,
                device=device,
                seed=seed,
                progress=progress,
            )

            results[method][depth] = metrics

            final = metrics[-1]
            lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov is not None else "λ=N/A"
            print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
                  f"val_acc={final.val_acc:.3f} {lyap_str}")

    return results


def print_summary(results: Dict, dataset_name: str):
    """
    Print summary table.

    One row per depth with the final-epoch validation accuracy of both methods
    (shown as FAILED when the run ended with a NaN training loss or zero
    accuracy), followed by a conditioning warning for vanilla runs whose
    gradient condition number exceeds 1e4.
    """
    print("\n" + "=" * 90)
    print(f"SUMMARY: {dataset_name.upper()} - Final Validation Accuracy")
    print("=" * 90)
    print(f"{'Depth':<8} {'Vanilla':<12} {'Lyapunov':<12} {'Δ Acc':<10} {'Van ∇norm':<12} {'Van κ':<12}")
    print("-" * 90)

    depths = sorted(results["vanilla"].keys())
    for depth in depths:
        van = results["vanilla"][depth][-1]
        lyap = results["lyapunov"][depth][-1]

        van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
        lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0

        van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
        lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"

        diff = lyap_acc - van_acc
        diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"

        # Only a missing measurement is "N/A"; 0.0 is a value worth printing.
        van_grad = f"{van.grad_norm:.2e}" if van.grad_norm is not None else "N/A"
        van_cond = f"{van.grad_condition:.1e}" if van.grad_condition is not None else "N/A"

        print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<10} {van_grad:<12} {van_cond:<12}")

    print("=" * 90)

    # Gradient health analysis
    print("\nGRADIENT HEALTH:")
    for depth in depths:
        van = results["vanilla"][depth][-1]
        van_cond = van.grad_condition if van.grad_condition is not None else 0
        if van_cond > 1e6:
            print(f" Depth {depth}: ⚠️ Ill-conditioned gradients (κ={van_cond:.1e})")
        elif van_cond > 1e4:
            print(f" Depth {depth}: ~ Moderate conditioning (κ={van_cond:.1e})")


def save_results(results: Dict, output_dir: str, config: Dict):
    """Save results to `results.json` and the run configuration to `config.json` in `output_dir`."""
    os.makedirs(output_dir, exist_ok=True)

    # JSON keys must be strings and dataclasses must become plain dicts.
    serializable = {
        method: {
            str(depth): [asdict(m) for m in metrics_list]
            for depth, metrics_list in depth_results.items()
        }
        for method, depth_results in results.items()
    }

    with open(os.path.join(output_dir, "results.json"), "w") as f:
        json.dump(serializable, f, indent=2)

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nResults saved to {output_dir}")


def parse_args():
    """Parse the command-line options of the benchmark script."""
    p = argparse.ArgumentParser(description="Benchmark experiment for Lyapunov SNN")

    # Dataset
    p.add_argument("--dataset", type=str, default="smnist",
                   choices=["smnist", "psmnist", "cifar10"],
                   help="Dataset to use")
    p.add_argument("--data_dir", type=str, default="./data")

    # Model
    p.add_argument("--depths", type=int, nargs="+", default=[2, 4, 6, 8],
                   help="Network depths to test")
    p.add_argument("--hidden_dim", type=int, default=128)

    # Training
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-3)

    # Lyapunov
    p.add_argument("--lambda_reg", type=float, default=0.3,
                   help="Lyapunov regularization weight (higher for harder tasks)")
    p.add_argument("--lambda_target", type=float, default=-0.1,
                   help="Target Lyapunov exponent (negative for stability)")
    p.add_argument("--lyap_eps", type=float, default=1e-4)

    # Dataset-specific
    p.add_argument("--T", type=int, default=100,
                   help="Timesteps for CIFAR-10 (sMNIST uses 784)")
    p.add_argument("--n_repeat", type=int, default=1,
                   help="Repeat each pixel n times for sMNIST")

    # Other
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--out_dir", type=str, default="runs/benchmark")
    p.add_argument("--no-progress", action="store_true")

    return p.parse_args()


def main():
    """Script entry point: load the dataset, run the depth comparison, print and save results."""
    args = parse_args()
    device = torch.device(args.device)

    print("=" * 70)
    print(f"BENCHMARK EXPERIMENT: {args.dataset.upper()}")
    print("=" * 70)
    print(f"Depths: {args.depths}")
    print(f"Hidden dim: {args.hidden_dim}")
    print(f"Epochs: {args.epochs}")
    print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
    print(f"Device: {device}")
    print("=" * 70)

    # Load dataset
    print(f"\nLoading {args.dataset} dataset...")

    if args.dataset in ("smnist", "psmnist"):
        # Both sequential-MNIST variants take the same options.
        train_loader, val_loader, info = get_benchmark_dataloader(
            args.dataset,
            batch_size=args.batch_size,
            root=args.data_dir,
            n_repeat=args.n_repeat,
            spike_encoding="direct",
        )
    elif args.dataset == "cifar10":
        train_loader, val_loader, info = get_benchmark_dataloader(
            "cifar10",
            batch_size=args.batch_size,
            root=args.data_dir,
            T=args.T,
        )
    else:
        # Not reachable through argparse `choices`; guards against leaving the
        # loaders unbound if a new choice is added without a branch here.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    print(f"Dataset info: {info}")
    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Run experiments
    results = run_depth_comparison(
        dataset_name=args.dataset,
        depths=args.depths,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=info["D"],
        num_classes=info["classes"],
        hidden_dim=args.hidden_dim,
        epochs=args.epochs,
        lr=args.lr,
        lambda_reg=args.lambda_reg,
        lambda_target=args.lambda_target,
        lyap_eps=args.lyap_eps,
        device=device,
        seed=args.seed,
        progress=not args.no_progress,
    )

    # Print summary
    print_summary(results, args.dataset)

    # Save results
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
    save_results(results, output_dir, vars(args))


if __name__ == "__main__":
    main()
b/files/experiments/cifar10_conv_experiment.py @@ -0,0 +1,448 @@ +""" +CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization. + +Uses proper convolutional architecture that preserves spatial structure. +Tests whether Lyapunov regularization helps train deeper Conv-SNNs. + +Architecture: + Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output + +Usage: + python files/experiments/cifar10_conv_experiment.py --model simple --T 25 + python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm.auto import tqdm + +from files.models.conv_snn import create_conv_snn + + +@dataclass +class EpochMetrics: + epoch: int + train_loss: float + train_acc: float + val_loss: float + val_acc: float + lyapunov: Optional[float] + grad_norm: float + time_sec: float + + +def get_cifar10_loaders( + data_dir: str = './data', + batch_size: int = 128, + num_workers: int = 4, +) -> Tuple[DataLoader, DataLoader]: + """ + Get CIFAR-10 dataloaders with standard normalization. + + Images normalized to [0, 1] for rate encoding. 
+ """ + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + # Note: For rate encoding, we keep values in [0, 1] + # No normalization to negative values + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + ]) + + train_dataset = datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transform_train + ) + test_dataset = datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transform_test + ) + + train_loader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True + ) + test_loader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True + ) + + return train_loader, test_loader + + +def train_epoch( + model: nn.Module, + loader: DataLoader, + optimizer: optim.Optimizer, + ce_loss: nn.Module, + device: torch.device, + use_lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + progress: bool = True, +) -> Tuple[float, float, Optional[float], float]: + """Train one epoch.""" + model.train() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32) + + optimizer.zero_grad() + + logits, lyap_est, _ = model( + x, + compute_lyapunov=use_lyapunov, + lyap_eps=lyap_eps, + ) + + ce = ce_loss(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan') + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + optimizer.step() + + grad_norm = sum(p.grad.norm().item() ** 2 for p in 
model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + if progress: + iterator.set_postfix({ + "loss": f"{loss.item():.3f}", + "acc": f"{total_correct/total_samples:.3f}", + }) + + return ( + total_loss / total_samples, + total_correct / total_samples, + np.mean(lyap_vals) if lyap_vals else None, + np.mean(grad_norms), + ) + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader: DataLoader, + ce_loss: nn.Module, + device: torch.device, + progress: bool = True, +) -> Tuple[float, float]: + """Evaluate on test set.""" + model.eval() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + + iterator = tqdm(loader, desc="eval", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False) + + loss = ce_loss(logits, y) + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + return total_loss / total_samples, total_correct / total_samples + + +def run_experiment( + model_type: str, + channels: List[int], + T: int, + use_lyapunov: bool, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + device: torch.device, + seed: int, + progress: bool = True, +) -> List[EpochMetrics]: + """Run single experiment.""" + torch.manual_seed(seed) + + model = create_conv_snn( + model_type=model_type, + in_channels=3, + num_classes=10, + channels=channels, + T=T, + encoding='rate', + ).to(device) + + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f" Model: {model_type}, params: {num_params:,}") + + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = 
optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + ce_loss = nn.CrossEntropyLoss() + + metrics_history = [] + best_acc = 0.0 + + for epoch in range(1, epochs + 1): + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm = train_epoch( + model, train_loader, optimizer, ce_loss, device, + use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress + ) + + test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_acc = max(best_acc, test_acc) + + metrics = EpochMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + val_loss=test_loss, + val_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + time_sec=dt, + ) + metrics_history.append(metrics) + + lyap_str = f"λ={lyap:.3f}" if lyap else "" + print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)") + + if np.isnan(train_loss): + print(" Training diverged!") + break + + print(f" Best test accuracy: {best_acc:.3f}") + return metrics_history + + +def run_comparison( + model_type: str, + channels_configs: List[List[int]], + T: int, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool, +) -> Dict: + """Compare vanilla vs Lyapunov across different depths.""" + results = {"vanilla": {}, "lyapunov": {}} + + for channels in channels_configs: + depth = len(channels) + print(f"\n{'='*60}") + print(f"Depth = {depth} conv layers, channels = {channels}") + print(f"{'='*60}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + print(f"\n Training {method.upper()}...") + + metrics = run_experiment( + model_type=model_type, + channels=channels, + T=T, + use_lyapunov=use_lyap, + train_loader=train_loader, + test_loader=test_loader, + epochs=epochs, + lr=lr, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=1e-4, + 
device=device, + seed=seed, + progress=progress, + ) + + results[method][depth] = metrics + + return results + + +def print_summary(results: Dict): + """Print comparison summary.""" + print("\n" + "=" * 70) + print("SUMMARY: CIFAR-10 Conv-SNN Results") + print("=" * 70) + print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}") + print("-" * 70) + + depths = sorted(results["vanilla"].keys()) + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0 + lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0 + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" + + van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" + lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" + + print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}") + + print("=" * 70) + + +def save_results(results: Dict, output_dir: str, config: Dict): + """Save results.""" + os.makedirs(output_dir, exist_ok=True) + + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, metrics_list in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in metrics_list] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def parse_args(): + p = argparse.ArgumentParser() + + # Model + p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"]) + p.add_argument("--channels", type=int, nargs="+", default=None, + help="Channel sizes (default: test multiple depths)") + p.add_argument("--T", type=int, default=25, help="Timesteps") + + # Training + p.add_argument("--epochs", type=int, default=50) + p.add_argument("--batch_size", type=int, 
default=128) + p.add_argument("--lr", type=float, default=1e-3) + + # Lyapunov + p.add_argument("--lambda_reg", type=float, default=0.3) + p.add_argument("--lambda_target", type=float, default=-0.1) + + # Other + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--out_dir", type=str, default="runs/cifar10_conv") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--no-progress", action="store_true") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 70) + print("CIFAR-10 Conv-SNN Experiment") + print("=" * 70) + print(f"Model: {args.model}") + print(f"Timesteps: {args.T}") + print(f"Epochs: {args.epochs}") + print(f"Device: {device}") + print("=" * 70) + + # Load data + print("\nLoading CIFAR-10...") + train_loader, test_loader = get_cifar10_loaders( + data_dir=args.data_dir, + batch_size=args.batch_size, + ) + print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") + + # Define depth configurations to test + if args.channels: + channels_configs = [args.channels] + else: + # Test increasing depths + channels_configs = [ + [64, 128], # 2 conv layers (shallow) + [64, 128, 256], # 3 conv layers + [64, 128, 256, 512], # 4 conv layers (deep) + ] + + # Run comparison + results = run_comparison( + model_type=args.model, + channels_configs=channels_configs, + T=args.T, + train_loader=train_loader, + test_loader=test_loader, + epochs=args.epochs, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + progress=not args.no_progress, + ) + + # Summary + print_summary(results) + + # Save + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, ts) + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git 
a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py new file mode 100644 index 0000000..48c62d8 --- /dev/null +++ b/files/experiments/depth_comparison.py @@ -0,0 +1,542 @@ +""" +Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths. + +Hypothesis: +- Shallow networks (1-2 layers): Both methods train successfully +- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds + +Usage: + # Quick test (synthetic data) + python files/experiments/depth_comparison.py --synthetic --epochs 20 + + # Full experiment with SHD data + python files/experiments/depth_comparison.py --epochs 50 + + # Specific depths to test + python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN +from files.analysis.stability_monitor import StabilityMonitor + + +@dataclass +class ExperimentConfig: + """Configuration for a single experiment run.""" + depth: int + hidden_dim: int + use_lyapunov: bool + lambda_reg: float + lambda_target: float + lyap_eps: float + epochs: int + lr: float + batch_size: int + beta: float + threshold: float + seed: int + + +@dataclass +class EpochMetrics: + """Metrics collected per epoch.""" + epoch: int + train_loss: float + train_acc: float + val_loss: float + val_acc: float + lyapunov: Optional[float] + grad_norm: float + firing_rate: float + dead_neurons: float + time_sec: float + + +def create_synthetic_data( + n_train: int = 2000, + n_val: int = 500, + T: 
int = 50, + D: int = 100, + n_classes: int = 10, + seed: int = 42, +) -> Tuple[DataLoader, DataLoader]: + """Create synthetic spike data for testing.""" + torch.manual_seed(seed) + np.random.seed(seed) + + def generate_data(n_samples): + # Generate class-conditional spike patterns + x = torch.zeros(n_samples, T, D) + y = torch.randint(0, n_classes, (n_samples,)) + + for i in range(n_samples): + label = y[i].item() + # Each class has different firing rate pattern + base_rate = 0.05 + 0.02 * label + # Class-specific channels fire more + class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) + for t in range(T): + # Background activity + x[i, t] = (torch.rand(D) < base_rate).float() + # Enhanced activity for class-specific channels + for c in class_channels: + if torch.rand(1) < base_rate * 3: + x[i, t, c] = 1.0 + + return x, y + + x_train, y_train = generate_data(n_train) + x_val, y_val = generate_data(n_val) + + train_loader = DataLoader( + TensorDataset(x_train, y_train), + batch_size=64, + shuffle=True, + ) + val_loader = DataLoader( + TensorDataset(x_val, y_val), + batch_size=64, + shuffle=False, + ) + + return train_loader, val_loader, T, D, n_classes + + +def create_model( + input_dim: int, + num_classes: int, + depth: int, + hidden_dim: int = 128, + beta: float = 0.9, + threshold: float = 1.0, +) -> LyapunovSNN: + """Create SNN with specified depth.""" + # Create hidden dims list based on depth + # Gradually decrease size for deeper networks to keep param count reasonable + hidden_dims = [] + current_dim = hidden_dim + for i in range(depth): + hidden_dims.append(current_dim) + # Optionally decrease dim in deeper layers + # current_dim = max(64, current_dim // 2) + + return LyapunovSNN( + input_dim=input_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + beta=beta, + threshold=threshold, + ) + + +def train_epoch( + model: nn.Module, + loader: DataLoader, + optimizer: optim.Optimizer, + ce_loss: nn.Module, + device: 
torch.device, + use_lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + monitor: StabilityMonitor, +) -> Tuple[float, float, float, float, float, float]: + """Train for one epoch, return metrics.""" + model.train() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + firing_rates = [] + dead_fracs = [] + + for x, y in loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, recordings = model( + x, + compute_lyapunov=use_lyapunov, + lyap_eps=lyap_eps, + record_states=True, + ) + + ce = ce_loss(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + # Check for NaN + if torch.isnan(loss): + return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0 + + loss.backward() + + # Gradient clipping for stability comparison fairness + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + optimizer.step() + + # Collect metrics + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + # Stability metrics + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + if recordings is not None: + spikes = recordings['spikes'] + fr = spikes.mean().item() + dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item() + firing_rates.append(fr) + dead_fracs.append(dead) + + avg_loss = total_loss / total_samples + avg_acc = total_correct / total_samples + avg_lyap = np.mean(lyap_vals) if lyap_vals else None + avg_grad = np.mean(grad_norms) + avg_fr = np.mean(firing_rates) if firing_rates else 0.0 + avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0 + + return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead + + +@torch.no_grad() +def evaluate( + 
model: nn.Module, + loader: DataLoader, + ce_loss: nn.Module, + device: torch.device, +) -> Tuple[float, float]: + """Evaluate model on validation set.""" + model.eval() + total_loss = 0.0 + total_correct = 0 + total_samples = 0 + + for x, y in loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + + loss = ce_loss(logits, y) + if torch.isnan(loss): + return float('nan'), 0.0 + + total_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + return total_loss / total_samples, total_correct / total_samples + + +def run_single_experiment( + config: ExperimentConfig, + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + device: torch.device, + progress: bool = True, +) -> List[EpochMetrics]: + """Run a single experiment with given configuration.""" + torch.manual_seed(config.seed) + + model = create_model( + input_dim=input_dim, + num_classes=num_classes, + depth=config.depth, + hidden_dim=config.hidden_dim, + beta=config.beta, + threshold=config.threshold, + ).to(device) + + optimizer = optim.Adam(model.parameters(), lr=config.lr) + ce_loss = nn.CrossEntropyLoss() + monitor = StabilityMonitor() + + metrics_history = [] + method = "Lyapunov" if config.use_lyapunov else "Vanilla" + + iterator = range(1, config.epochs + 1) + if progress: + iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False) + + for epoch in iterator: + t0 = time.time() + + train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch( + model=model, + loader=train_loader, + optimizer=optimizer, + ce_loss=ce_loss, + device=device, + use_lyapunov=config.use_lyapunov, + lambda_reg=config.lambda_reg, + lambda_target=config.lambda_target, + lyap_eps=config.lyap_eps, + monitor=monitor, + ) + + val_loss, val_acc = evaluate(model, val_loader, ce_loss, device) + + dt = time.time() - t0 + + metrics = 
EpochMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + val_loss=val_loss, + val_acc=val_acc, + lyapunov=lyap, + grad_norm=grad_norm, + firing_rate=fr, + dead_neurons=dead, + time_sec=dt, + ) + metrics_history.append(metrics) + + # Early stopping if training diverged + if np.isnan(train_loss): + print(f" Training diverged at epoch {epoch}") + break + + return metrics_history + + +def run_depth_comparison( + depths: List[int], + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + device: torch.device, + epochs: int = 30, + hidden_dim: int = 128, + lr: float = 1e-3, + lambda_reg: float = 0.1, + lambda_target: float = 0.0, + lyap_eps: float = 1e-4, + beta: float = 0.9, + seed: int = 42, + progress: bool = True, +) -> Dict[str, Dict[int, List[EpochMetrics]]]: + """ + Run comparison experiments across depths. + + Returns: + Dictionary with structure: + { + "vanilla": {1: [metrics...], 2: [metrics...], ...}, + "lyapunov": {1: [metrics...], 2: [metrics...], ...} + } + """ + results = {"vanilla": {}, "lyapunov": {}} + + for depth in depths: + print(f"\n{'='*50}") + print(f"Depth = {depth} layers") + print(f"{'='*50}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + print(f"\n Training {method.upper()}...") + + config = ExperimentConfig( + depth=depth, + hidden_dim=hidden_dim, + use_lyapunov=use_lyap, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + lyap_eps=lyap_eps, + epochs=epochs, + lr=lr, + batch_size=64, + beta=beta, + threshold=1.0, + seed=seed, + ) + + metrics = run_single_experiment( + config=config, + train_loader=train_loader, + val_loader=val_loader, + input_dim=input_dim, + num_classes=num_classes, + device=device, + progress=progress, + ) + + results[method][depth] = metrics + + # Print final metrics + final = metrics[-1] + lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov else "λ=N/A" + print(f" Final: loss={final.train_loss:.4f} 
acc={final.train_acc:.3f} " + f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}") + + return results + + +def save_results(results: Dict, output_dir: str, config: dict): + """Save experiment results to JSON.""" + os.makedirs(output_dir, exist_ok=True) + + # Convert metrics to dicts + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, metrics_list in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in metrics_list] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]): + """Print summary comparison table.""" + print("\n" + "=" * 70) + print("SUMMARY: Final Validation Accuracy by Depth") + print("=" * 70) + print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}") + print("-" * 70) + + depths = sorted(results["vanilla"].keys()) + for depth in depths: + van_metrics = results["vanilla"][depth] + lyap_metrics = results["lyapunov"][depth] + + van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0 + lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0 + + van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED" + lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED" + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" + + print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}") + + print("=" * 70) + + # Gradient analysis + print("\nGradient Norm Analysis (final epoch):") + print("-" * 70) + print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}") + print("-" * 70) + for depth in depths: + van_grad = 
results["vanilla"][depth][-1].grad_norm + lyap_grad = results["lyapunov"][depth][-1].grad_norm + print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths") + p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6], + help="Network depths to test") + p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer") + p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment") + p.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight") + p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent") + p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov") + p.add_argument("--beta", type=float, default=0.9, help="Membrane decay") + p.add_argument("--seed", type=int, default=42, help="Random seed") + p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing") + p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config") + p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--no-progress", action="store_true", help="Disable progress bars") + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 70) + print("Experiment: Vanilla vs Lyapunov-Regularized SNN") + print("=" * 70) + print(f"Depths: {args.depths}") + print(f"Hidden dim: {args.hidden_dim}") + print(f"Epochs: {args.epochs}") + print(f"Lambda_reg: {args.lambda_reg}") + print(f"Device: {device}") + + # Load data + if args.synthetic: + print("\nUsing SYNTHETIC data for quick testing") + 
train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed) + else: + print(f"\nLoading data from {args.cfg}") + from files.data_io.dataset_loader import get_dataloader + train_loader, val_loader = get_dataloader(args.cfg) + xb, _ = next(iter(train_loader)) + _, T, D = xb.shape + C = 20 # SHD has 20 classes + + print(f"Data: T={T}, D={D}, classes={C}") + + # Run experiments + results = run_depth_comparison( + depths=args.depths, + train_loader=train_loader, + val_loader=val_loader, + input_dim=D, + num_classes=C, + device=device, + epochs=args.epochs, + hidden_dim=args.hidden_dim, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + lyap_eps=args.lyap_eps, + beta=args.beta, + seed=args.seed, + progress=not args.no_progress, + ) + + # Print summary + print_summary(results) + + # Save results + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, ts) + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py new file mode 100644 index 0000000..efab140 --- /dev/null +++ b/files/experiments/depth_scaling_benchmark.py @@ -0,0 +1,1035 @@ +""" +Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs. + +Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve. 
+ +Key hypothesis (from literature): +- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet) +- Deep SNNs without regularization fail to train (gradient issues) +- Deep SNNs WITH Lyapunov regularization achieve higher accuracy + +Reference results: +- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI) +- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS) +- Spikformer on ImageNet: ~74.8% top-1 (arXiv) + +Usage: + python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Optional, Tuple + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm.auto import tqdm + +import snntorch as snn +from snntorch import surrogate + + +# ============================================================================= +# VGG-style Spiking Network (scalable depth) +# ============================================================================= + +class SpikingVGGBlock(nn.Module): + """Conv-BN-LIF block for VGG-style architecture.""" + + def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None): + super().__init__() + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_ch) + self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + + def forward(self, x, mem): + h = self.bn(self.conv(x)) + spk, mem = self.lif(h, mem) + return spk, mem + + +class SpikingVGG(nn.Module): + """ + Scalable VGG-style Spiking Neural Network. 
+ + Architecture follows VGG pattern: + - Multiple conv blocks between pooling layers + - Depth controlled by num_blocks_per_stage + + Args: + in_channels: Input channels (3 for RGB) + num_classes: Output classes + base_channels: Starting channel count (doubled each stage) + num_stages: Number of pooling stages (3-4 typical) + blocks_per_stage: Conv blocks per stage (controls depth) + T: Number of timesteps + beta: LIF membrane decay + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + base_channels: int = 64, + num_stages: int = 3, + blocks_per_stage: int = 2, + T: int = 4, + beta: float = 0.9, + threshold: float = 1.0, + dropout: float = 0.25, + stable_init: bool = False, + ): + super().__init__() + + self.T = T + self.num_stages = num_stages + self.blocks_per_stage = blocks_per_stage + self.total_conv_layers = num_stages * blocks_per_stage + self.stable_init = stable_init + + spike_grad = surrogate.fast_sigmoid(slope=25) + + # Build stages + self.stages = nn.ModuleList() + self.pools = nn.ModuleList() + + ch_in = in_channels + ch_out = base_channels + + for stage in range(num_stages): + stage_blocks = nn.ModuleList() + for b in range(blocks_per_stage): + block_in = ch_in if b == 0 else ch_out + stage_blocks.append( + SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad) + ) + self.stages.append(stage_blocks) + self.pools.append(nn.AvgPool2d(2)) + ch_in = ch_out + ch_out = min(ch_out * 2, 512) # Cap at 512 + + # Calculate spatial size after pooling + # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages) + final_spatial = 32 // (2 ** num_stages) + final_channels = min(base_channels * (2 ** (num_stages - 1)), 512) + fc_input = final_channels * final_spatial * final_spatial + + # Classifier + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(fc_input, num_classes) + + if stable_init: + self._init_weights_stable() + else: + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, 
nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_weights_stable(self): + """ + Stability-aware initialization for SNNs. + + Uses smaller weight magnitudes to produce less chaotic initial dynamics. + The key insight: Lyapunov exponent depends on weight magnitudes. + Smaller weights → smaller gradients → more stable dynamics. + + Strategy: + - Use orthogonal init (preserves gradient magnitude across layers) + - Scale down by factor of 0.5 to reduce initial chaos + - This should produce λ closer to 0 from the start + """ + scale_factor = 0.5 # Reduce weight magnitudes + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # Orthogonal init for conv (reshape to 2D, init, reshape back) + weight_shape = m.weight.shape + fan_out = weight_shape[0] * weight_shape[2] * weight_shape[3] + fan_in = weight_shape[1] * weight_shape[2] * weight_shape[3] + + # Use smaller gain for stability + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + with torch.no_grad(): + m.weight.mul_(scale_factor) + + elif isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight, gain=scale_factor) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_mems(self, batch_size, device, dtype, P=1): + """Initialize membrane potentials for all LIF layers. 
+ + Args: + batch_size: Batch size B + device: torch device + dtype: torch dtype + P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov) + + Returns: + List of membrane tensors with shape (P, B, C, H, W) + """ + mems = [] + H, W = 32, 32 + ch = 64 + + for stage in range(self.num_stages): + for _ in range(self.blocks_per_stage): + mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) + H, W = H // 2, W // 2 + ch = min(ch * 2, 512) + + return mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + """ + Forward pass with optimized Lyapunov computation (Approach A: trajectory batching). + + When compute_lyapunov=True, both original and perturbed trajectories are + processed together by batching them along a new dimension P=2. This avoids + redundant computation, especially for the first conv layer where inputs are identical. + + Args: + x: (B, C, H, W) static image (will be repeated for T steps) + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude + + Returns: + logits, lyap_est, recordings + """ + B = x.size(0) + device, dtype = x.device, x.dtype + + # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed) + P = 2 if compute_lyapunov else 1 + + # Initialize membrane potentials with shape (P, B, C, H, W) + mems = self._init_mems(B, device, dtype, P=P) + + # Initialize perturbed trajectory + if compute_lyapunov: + for i in range(len(mems)): + mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = None + + # Time loop - repeat static input + for t in range(self.T): + mem_idx = 0 + new_mems = [] + is_first_block = True + + # Process through stages + for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): + for block in stage_blocks: + if 
is_first_block: + # First block: input x is identical for both trajectories + # Compute conv+bn ONCE, then expand to (P, B, C, H, W) + h_conv = block.bn(block.conv(x)) # (B, C, H, W) + h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy + + # LIF with batched membrane states + # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) + + # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W) + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + is_first_block = False + else: + # Subsequent blocks: inputs differ between trajectories + # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + + # Full block forward (conv+bn+lif) + h_conv = block.bn(block.conv(h_flat)) + spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) + + # Reshape back + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + + mem_idx += 1 + + # Pool: apply to batched tensor + h_flat = h.reshape(P * B, *h.shape[2:]) + h_pooled = pool(h_flat) + h = h_pooled.view(P, B, *h_pooled.shape[1:]) + + mems = new_mems + + # Accumulate final spikes from ORIGINAL trajectory only (index 0) + h_orig = h[0].view(B, -1) # (B, C*H*W) + if spike_sum is None: + spike_sum = h_orig + else: + spike_sum = spike_sum + h_orig + + # Lyapunov divergence and renormalization (Option 1: global delta + global renorm) + # This is the textbook Benettin-style Lyapunov exponent estimator where + # the perturbation is treated as one vector in the concatenated state space. 
+ if compute_lyapunov: + # Compute GLOBAL divergence across all layers + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W) + delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # GLOBAL renormalization: same scale factor for all layers + # This ensures ||perturbation||_global = eps after renorm + scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting + + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] + # Update perturbed trajectory: scale the diff to have global norm = eps + mems[i] = torch.stack([ + new_mems[i][0], + new_mems[i][0] + diff * scale + ], dim=0) + + # Readout + out = self.dropout(spike_sum) + logits = self.fc(out) + + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est, None + + @property + def depth(self): + return self.total_conv_layers + + +# ============================================================================= +# Dataset Loading +# ============================================================================= + +def get_dataset( + name: str, + data_dir: str = './data', + batch_size: int = 128, + num_workers: int = 4, +) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]: + """ + Get train/test loaders for various datasets. 
+ + Returns: + train_loader, test_loader, num_classes, input_shape + """ + + if name == 'mnist': + transform = transforms.Compose([ + transforms.Resize(32), # Resize to 32x32 for consistency + transforms.ToTensor(), + ]) + train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'fashion_mnist': + transform = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + ]) + train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'cifar10': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test) + num_classes = 10 + input_shape = (3, 32, 32) + + elif name == 'cifar100': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test) + num_classes = 100 + input_shape = (3, 32, 32) + + else: + raise ValueError(f"Unknown dataset: {name}") + + train_loader = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True + ) + test_loader = DataLoader( + test_ds, batch_size=batch_size, 
def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float,
                          lyap_threshold: float = 2.0) -> torch.Tensor:
    """
    Compute the Lyapunov regularization penalty for the selected penalty shape.

    Notation used below: ``excess = relu(λ - lyap_threshold)`` and
    ``pos = relu(λ)``. The returned value is NOT multiplied by λ_reg; the
    caller is expected to apply that weight.

    Args:
        lyap_est: Estimated Lyapunov exponent λ (tensor).
        reg_type: Penalty shape:
            - "squared":          (λ - lambda_target)²
            - "hinge":            relu(λ)²  (threshold is fixed at 0;
                                  lyap_threshold is ignored)
            - "asymmetric":       relu(λ)² + 0.1 · relu(-λ - 1)²
            - "extreme":          excess²
            - "adaptive_linear":  excess³
            - "adaptive_exp":     (exp(excess) - 1) · excess
            - "adaptive_sigmoid": sigmoid((λ - lyap_threshold) / 0.5)
                                  · (λ - lambda_target)²
            - "mult_linear":      pos²
            - "mult_squared":     pos³
            - "mult_log":         log(1 + pos) · pos
        lambda_target: Target exponent; used by "squared" and "adaptive_sigmoid".
        lyap_threshold: Threshold used by "extreme" and the "adaptive_*" types.

    Returns:
        Penalty tensor computed element-wise from lyap_est.

    Raises:
        ValueError: If reg_type is not one of the names listed above.
    """
    if reg_type == "squared":
        # Penalize any deviation from the target, in either direction.
        return (lyap_est - lambda_target) ** 2

    if reg_type == "hinge":
        # Only positive exponents (chaos) are penalized. The threshold is
        # hard-coded to 0 here; lyap_threshold plays no role for this type.
        threshold = 0.0
        excess = torch.relu(lyap_est - threshold)
        return excess ** 2

    if reg_type == "asymmetric":
        # Strong penalty for λ > 0, weak (×0.1) penalty once λ drops below -1,
        # so the network may be stable without being pushed towards collapse.
        chaos_penalty = torch.relu(lyap_est) ** 2
        collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2
        return chaos_penalty + collapse_penalty

    if reg_type == "extreme":
        # Quadratic penalty that is active only above lyap_threshold.
        excess = torch.relu(lyap_est - lyap_threshold)
        return excess ** 2

    if reg_type == "adaptive_linear":
        # "Linear" refers to the multiplier: excess · excess² = excess³.
        # Gentle just above the threshold, steep far above it.
        excess = torch.relu(lyap_est - lyap_threshold)
        return excess ** 3

    if reg_type == "adaptive_exp":
        # (exp(excess) - 1) is 0 at the threshold and grows exponentially.
        # NOTE: the product uses `excess`, not `excess²`.
        excess = torch.relu(lyap_est - lyap_threshold)
        exp_scale = torch.exp(excess) - 1.0
        return exp_scale * excess

    if reg_type == "adaptive_sigmoid":
        # Squared deviation from the target, gated by a sigmoid centred on the
        # threshold. temperature controls how sharp the gate is.
        temperature = 0.5
        weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature)
        deviation = lyap_est - lambda_target
        return weight * (deviation ** 2)

    # ---------------------------------------------------------------------
    # Scaled-multiplier family: loss = g(relu(λ)) · relu(λ), target = 0.
    # The multiplier g grows with λ, so the penalty is mild for small λ and
    # aggressive for large λ.
    # ---------------------------------------------------------------------
    if reg_type == "mult_linear":
        # g(x) = x      -> relu(λ)²   (λ=0.5 → 0.25, 1 → 1, 2 → 4, 3 → 9)
        pos_lyap = torch.relu(lyap_est)
        return pos_lyap * pos_lyap

    if reg_type == "mult_squared":
        # g(x) = x²     -> relu(λ)³   (λ=0.5 → 0.125, 1 → 1, 2 → 8, 3 → 27)
        pos_lyap = torch.relu(lyap_est)
        return pos_lyap * pos_lyap * pos_lyap

    if reg_type == "mult_log":
        # g(x) = log(1+x)             (λ=0.5 → 0.20, 1 → 0.69, 2 → 2.20, 3 → 4.16)
        pos_lyap = torch.relu(lyap_est)
        return torch.log1p(pos_lyap) * pos_lyap

    raise ValueError(f"Unknown reg_type: {reg_type}")
+ reg_type: "squared", "hinge", "asymmetric", or "extreme" + lyap_threshold: Threshold for extreme reg_type + """ + model.train() + total_loss = 0.0 + correct = 0 + total = 0 + lyap_vals = [] + grad_norms = [] + grad_max_svs = [] + grad_min_svs = [] + grad_conditions = [] + + # Use warmup value if provided + effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for batch_idx, (x, y) in enumerate(iterator): + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps) + + loss = criterion(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold) + loss = loss + effective_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan'), None, None, None + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5 + grad_norms.append(grad_norm) + + # Compute gradient SVs periodically (expensive) + if batch_idx % compute_sv_every == 0: + max_sv, min_sv, cond = compute_gradient_svs(model) + if max_sv is not None: + grad_max_svs.append(max_sv) + grad_min_svs.append(min_sv) + grad_conditions.append(cond) + + optimizer.step() + + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + return ( + total_loss / total, + correct / total, + np.mean(lyap_vals) if lyap_vals else None, + np.mean(grad_norms), + np.mean(grad_max_svs) if grad_max_svs else None, + np.mean(grad_min_svs) if grad_min_svs else None, + np.mean(grad_conditions) if grad_conditions else None, + ) + + +@torch.no_grad() +def evaluate(model, loader, criterion, device, progress=True): + 
def run_single_config(
    dataset_name: str,
    depth_config: Tuple[int, int],  # (num_stages, blocks_per_stage)
    use_lyapunov: bool,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_classes: int,
    in_channels: int,
    T: int,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    progress: bool = True,
    reg_type: str = "squared",
    warmup_epochs: int = 0,
    stable_init: bool = False,
    lyap_threshold: float = 2.0,
) -> List[TrainingMetrics]:
    """
    Train one SpikingVGG configuration and return its per-epoch metrics.

    Args:
        dataset_name: Accepted for signature parity; not read by this function.
        depth_config: (num_stages, blocks_per_stage); their product is the
            depth that gets printed.
        use_lyapunov: Whether train_epoch should compute/apply the Lyapunov term.
        train_loader, test_loader: Data loaders for training and evaluation.
        num_classes, in_channels, T, stable_init: Forwarded to SpikingVGG.
        epochs: Number of epochs; also the cosine schedule's T_max.
        lr: Initial learning rate for AdamW (weight_decay is fixed at 1e-4).
        lambda_reg: Final regularization weight (reached after warmup).
        lambda_target, reg_type, lyap_threshold: Forwarded to train_epoch.
        device: Device the model is moved to.
        seed: Seed passed to torch.manual_seed before the model is built.
        progress: Forwarded to train_epoch / evaluate.
        warmup_epochs: If > 0, the regularization weight ramps linearly from
            lambda_reg / warmup_epochs up to lambda_reg over these epochs.

    Returns:
        One TrainingMetrics entry per completed epoch. If the training loss
        becomes NaN, the NaN epoch is still recorded and training stops.
    """
    torch.manual_seed(seed)

    num_stages, blocks_per_stage = depth_config
    total_depth = num_stages * blocks_per_stage

    model = SpikingVGG(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=64,
        num_stages=num_stages,
        blocks_per_stage=blocks_per_stage,
        T=T,
        stable_init=stable_init,
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    method = "Lyapunov" if use_lyapunov else "Vanilla"
    print(f" {method}: depth={total_depth}, params={num_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        t0 = time.time()

        # Warmup: linearly ramp the regularization weight up to lambda_reg.
        if warmup_epochs > 0 and epoch <= warmup_epochs:
            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
        else:
            current_lambda_reg = lambda_reg

        # The fifth positional argument (1e-4) is lyap_eps.
        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov, lambda_reg, lambda_target, 1e-4, progress,
            reg_type=reg_type, current_lambda_reg=current_lambda_reg,
            lyap_threshold=lyap_threshold
        )

        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        best_acc = max(best_acc, test_acc)

        metrics = TrainingMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            # Read after scheduler.step(): this is the post-step learning rate.
            lr=scheduler.get_last_lr()[0],
            time_sec=dt,
        )
        history.append(metrics)

        if epoch % 10 == 0 or epoch == epochs:
            # Compare against None explicitly: a value of exactly 0.0 is a
            # valid measurement and must still be printed.
            lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
            if grad_max_sv is not None and grad_min_sv is not None:
                sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}"
            else:
                sv_str = ""
            print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}")

        if np.isnan(train_loss):
            print(f" DIVERGED at epoch {epoch}")
            break

    print(f" Best test acc: {best_acc:.3f}")
    return history
def print_summary(results: Dict, dataset_name: str):
    """
    Print the final depth-scaling summary table and a short analysis.

    Only the LAST history entry of each run is used.

    Args:
        results: {"vanilla": {depth: history}, "lyapunov": {depth: history}}
            where each history is a non-empty list of metric objects exposing
            train_loss, test_acc, lyapunov, grad_norm and grad_condition.
        dataset_name: Name shown (upper-cased) in the table title.

    Depths that are not present for BOTH methods are skipped; if none remain
    a notice is printed and the function returns early.
    """
    print("\n" + "=" * 100)
    print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}")
    print("=" * 100)
    print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}")
    print("-" * 100)

    # Only depths that have both a vanilla and a Lyapunov run can be compared.
    depths = sorted(set(results["vanilla"]) & set(results["lyapunov"]))
    if not depths:
        print("No depth has both vanilla and Lyapunov results; nothing to summarize.")
        print("=" * 100)
        return

    for depth in depths:
        van = results["vanilla"][depth][-1]
        lyap = results["lyapunov"][depth][-1]

        # A NaN training loss marks a diverged run; its accuracy counts as 0.
        van_ok = not np.isnan(van.train_loss)
        lyap_ok = not np.isnan(lyap.train_loss)
        van_acc = van.test_acc if van_ok else 0.0
        lyap_acc = lyap.test_acc if lyap_ok else 0.0

        diff = lyap_acc - van_acc
        diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}"

        van_str = f"{van_acc:.3f}" if van_ok else "FAILED"
        lyap_str = f"{lyap_acc:.3f}" if lyap_ok else "FAILED"

        # `is not None` rather than truthiness: 0.0 is a real measurement.
        lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov is not None else "N/A"
        van_grad = f"{van.grad_norm:.2e}" if van.grad_norm is not None else "N/A"
        lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm is not None else "N/A"
        van_cond = f"{van.grad_condition:.1e}" if van.grad_condition is not None else "N/A"

        print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}")

    print("=" * 100)

    # Gradient health analysis (vanilla runs only).
    print("\nGRADIENT HEALTH ANALYSIS:")
    for depth in depths:
        van = results["vanilla"][depth][-1]
        van_cond = van.grad_condition if van.grad_condition is not None else 0

        status = ""
        if van_cond > 1e6:
            status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)"
        elif van_cond > 1e4:
            status = "~ Vanilla has moderately ill-conditioned gradients"

        if status:
            print(f" Depth {depth}: {status}")

    print("")

    # Compare how accuracy changes from the shallowest to the deepest network.
    print("\nKEY OBSERVATIONS:")
    shallow = min(depths)
    deep = max(depths)

    van_shallow = results["vanilla"][shallow][-1].test_acc
    van_deep = results["vanilla"][deep][-1].test_acc
    lyap_shallow = results["lyapunov"][shallow][-1].test_acc
    lyap_deep = results["lyapunov"][deep][-1].test_acc

    van_gain = van_deep - van_shallow
    lyap_gain = lyap_deep - lyap_shallow

    print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change")
    print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change")

    if lyap_gain > van_gain + 0.02:
        print(f" ✓ Lyapunov regularization enables better depth scaling!")
    elif lyap_gain > van_gain:
        print(f" ~ Lyapunov shows slight improvement in depth scaling")
    else:
        print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range")
def parse_args(argv=None):
    """
    Parse command-line options for the depth scaling benchmark.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads sys.argv[1:], which is the previous behaviour.

    Returns:
        argparse.Namespace with the parsed options.
    """
    p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs")

    p.add_argument("--dataset", type=str, default="cifar100",
                   choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
    p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16],
                   help="Total conv layer depths to test")
    p.add_argument("--T", type=int, default=4, help="Timesteps")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--lambda_reg", type=float, default=0.3)
    p.add_argument("--lambda_target", type=float, default=-0.1)
    p.add_argument("--data_dir", type=str, default="./data")
    p.add_argument("--out_dir", type=str, default="runs/depth_scaling")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--seed", type=int, default=42)
    # Stored as args.no_progress (argparse maps '-' to '_').
    p.add_argument("--no-progress", action="store_true")
    # The choices must list every reg_type that compute_lyap_reg_loss in this
    # file handles; anything missing here cannot be selected from the CLI.
    p.add_argument("--reg_type", type=str, default="squared",
                   choices=[
                       "squared", "hinge", "asymmetric", "extreme",
                       "adaptive_linear", "adaptive_exp", "adaptive_sigmoid",
                       "mult_linear", "mult_squared", "mult_log",
                   ],
                   help="Lyapunov regularization type")
    p.add_argument("--warmup_epochs", type=int, default=0,
                   help="Epochs to warmup lambda_reg (0 = no warmup)")
    p.add_argument("--stable_init", action="store_true",
                   help="Use stability-aware weight initialization")
    p.add_argument("--lyap_threshold", type=float, default=2.0,
                   help="Threshold for extreme/adaptive_* reg_types (only penalize λ > threshold)")

    return p.parse_args(argv)
{args.lambda_target}") + print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}") + print(f"Device: {device}") + print("=" * 80) + + # Load data + print(f"\nLoading {args.dataset}...") + train_loader, test_loader, num_classes, input_shape = get_dataset( + args.dataset, args.data_dir, args.batch_size + ) + in_channels = input_shape[0] + print(f"Classes: {num_classes}, Input: {input_shape}") + print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") + + # Convert depths to (num_stages, blocks_per_stage) configs + # We use 4 stages (3 for smaller nets), adjust blocks_per_stage + depth_configs = [] + for d in args.depths: + if d <= 4: + depth_configs.append((d, 1)) # d stages, 1 block each + elif d <= 8: + depth_configs.append((4, d // 4)) # 4 stages + else: + depth_configs.append((4, d // 4)) # 4 stages, more blocks + + print(f"\nDepth configurations: {[(d, f'{s}×{b}') for d, (s, b) in zip(args.depths, depth_configs)]}") + + # Run experiment + results = run_depth_scaling_experiment( + dataset_name=args.dataset, + depth_configs=depth_configs, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=args.T, + epochs=args.epochs, + lr=args.lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + progress=not args.no_progress, + reg_type=args.reg_type, + warmup_epochs=args.warmup_epochs, + stable_init=args.stable_init, + lyap_threshold=args.lyap_threshold, + ) + + # Summary + print_summary(results, args.dataset) + + # Save + ts = time.strftime("%Y%m%d-%H%M%S") + output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}") + save_results(results, output_dir, vars(args)) + + +if __name__ == "__main__": + main() diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py new file mode 100644 index 0000000..011387f --- /dev/null +++ 
b/files/experiments/hyperparameter_grid_search.py @@ -0,0 +1,597 @@ +""" +Hyperparameter Grid Search for Lyapunov-Regularized SNNs. + +Goal: Find optimal (lambda_reg, lambda_target) for each network depth + and derive an adaptive curve for automatic hyperparameter selection. + +Usage: + python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Tuple +from itertools import product + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN + + +@dataclass +class GridSearchResult: + """Result from a single grid search configuration.""" + depth: int + lambda_reg: float + lambda_target: float + final_train_acc: float + final_val_acc: float + final_lyapunov: float + final_grad_norm: float + converged: bool # Did training succeed (not NaN)? 
def create_synthetic_data(
    n_train: int = 2000,
    n_val: int = 500,
    T: int = 50,
    D: int = 100,
    n_classes: int = 10,
    seed: int = 42,
    batch_size: int = 128,
) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """
    Build seeded synthetic spike-train loaders for training and validation.

    Each sample is a (T, D) binary tensor. Every channel fires with a
    label-dependent base rate (0.05 + 0.02 * label); the D // n_classes
    channels belonging to the sample's class get one extra chance to fire at
    three times that rate.

    Returns:
        (train_loader, val_loader, T, D, n_classes). The training loader
        shuffles; the validation loader does not.
    """
    # Seeds the global torch and numpy generators (a process-wide side effect).
    torch.manual_seed(seed)
    np.random.seed(seed)

    channels_per_class = D // n_classes

    def _sample_split(count):
        spikes = torch.zeros(count, T, D)
        labels = torch.randint(0, n_classes, (count,))
        for sample_idx, label in enumerate(labels.tolist()):
            rate = 0.05 + 0.02 * label
            first_channel = label * channels_per_class
            last_channel = (label + 1) * channels_per_class
            for step in range(T):
                # Background activity on all channels.
                spikes[sample_idx, step] = (torch.rand(D) < rate).float()
                # Extra firing chance on the class-specific channels.
                for channel in range(first_channel, last_channel):
                    if torch.rand(1) < rate * 3:
                        spikes[sample_idx, step, channel] = 1.0
        return spikes, labels

    # Training data is drawn first, then validation data, from the same stream.
    train_x, train_y = _sample_split(n_train)
    val_x, val_y = _sample_split(n_val)

    train_set = TensorDataset(train_x, train_y)
    val_set = TensorDataset(val_x, val_y)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, T, D, n_classes
True + + for epoch in range(1, epochs + 1): + # Warmup: gradually increase lambda_reg + if epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / warmup_epochs) + else: + current_lambda_reg = lambda_reg + + # Training + model.train() + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + + for x, y in train_loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False) + + ce = ce_loss(logits, y) + if lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + current_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + converged = False + break + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + optimizer.step() + + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + if not converged: + break + + train_acc = total_correct / total_samples + final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0 + final_grad = np.mean(grad_norms) if grad_norms else 0.0 + + # Track epochs to 90% accuracy + if epochs_to_90 < 0 and train_acc >= 0.9: + epochs_to_90 = epoch + + # Validation + model.eval() + val_correct = 0 + val_total = 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + preds = logits.argmax(dim=1) + val_correct += (preds == y).sum().item() + val_total += x.size(0) + val_acc = val_correct / val_total + best_val_acc = max(best_val_acc, val_acc) + + return GridSearchResult( + depth=depth, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + final_train_acc=train_acc if converged else 0.0, + final_val_acc=best_val_acc if converged else 0.0, + 
def _fit_depth_curve(label, depths_arr, values_arr):
    """Fit values vs depth with a polynomial of degree <= 2; return (coeffs, poly) or None."""
    try:
        coeffs = np.polyfit(depths_arr, values_arr, deg=min(2, len(depths_arr) - 1))
        return coeffs, np.poly1d(coeffs)
    except Exception as e:  # best-effort: a failed fit must not abort the analysis
        print(f"Could not fit {label} curve: {e}")
        return None


def _format_depth_curve(label, coeffs):
    """Render fitted coefficients as a readable quadratic or linear formula."""
    if len(coeffs) == 3:
        return f"{label}(depth) ≈ {coeffs[0]:.4f}·d² + {coeffs[1]:.4f}·d + {coeffs[2]:.4f}"
    return f"{label}(depth) ≈ {coeffs[0]:.4f}·d + {coeffs[1]:.4f}"


def analyze_results(results: "List[GridSearchResult]") -> Dict:
    """
    Analyze grid search results and find optimal hyperparameters per depth.

    For every depth the configuration with the highest final_val_acc is
    selected (non-converged runs count as 0). Depths whose best run exceeds
    0.5 validation accuracy are then used to fit polynomial curves of
    lambda_reg and lambda_target against depth; at least three such depths
    are required.

    Args:
        results: Grid search results (dataclass instances; they are converted
            with dataclasses.asdict).

    Returns:
        Dict with "optimal_per_depth" and "all_results", plus
        "lambda_reg_curve" / "lambda_target_curve" (coefficient lists, highest
        power first) for each fit that succeeded.
    """
    # Group by depth
    by_depth = {}
    for r in results:
        by_depth.setdefault(r.depth, []).append(r)

    analysis = {
        "optimal_per_depth": {},
        "all_results": [asdict(r) for r in results],
    }

    print("\n" + "=" * 80)
    print("GRID SEARCH ANALYSIS")
    print("=" * 80)

    # Find optimal for each depth
    print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}")
    print("-" * 80)

    optimal_lambda_regs = []
    optimal_lambda_targets = []
    depths_list = []

    for depth in sorted(by_depth):
        depth_results = by_depth[depth]
        # Best by validation accuracy; diverged runs rank as 0.
        best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0)

        analysis["optimal_per_depth"][depth] = {
            "lambda_reg": best.lambda_reg,
            "lambda_target": best.lambda_target,
            "val_acc": best.final_val_acc,
            "lyapunov": best.final_lyapunov,
            "epochs_to_90": best.epochs_to_90pct,
        }

        print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} "
              f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}")

        if best.final_val_acc > 0.5:  # Only use successful runs for curve fitting
            depths_list.append(depth)
            optimal_lambda_regs.append(best.lambda_reg)
            optimal_lambda_targets.append(best.lambda_target)

    # Fit adaptive curves
    print("\n" + "=" * 80)
    print("ADAPTIVE HYPERPARAMETER CURVES")
    print("=" * 80)

    if len(depths_list) >= 3:
        depths_arr = np.array(depths_list)

        # Either fit may fail independently; keep the polynomials as None in
        # that case so the recommendation step below can check for them.
        reg_poly = None
        target_poly = None

        reg_fit = _fit_depth_curve("λ_reg", depths_arr, np.array(optimal_lambda_regs))
        if reg_fit is not None:
            reg_coeffs, reg_poly = reg_fit
            print("\n" + _format_depth_curve("λ_reg", reg_coeffs))
            analysis["lambda_reg_curve"] = reg_coeffs.tolist()

        target_fit = _fit_depth_curve("λ_target", depths_arr, np.array(optimal_lambda_targets))
        if target_fit is not None:
            target_coeffs, target_poly = target_fit
            print(_format_depth_curve("λ_target", target_coeffs))
            analysis["lambda_target_curve"] = target_coeffs.tolist()

        # Print recommendations (requires both curves).
        print("\n" + "-" * 80)
        print("RECOMMENDED HYPERPARAMETERS BY DEPTH:")
        print("-" * 80)
        if reg_poly is not None and target_poly is not None:
            for d in [2, 4, 6, 8, 10, 12, 14, 16]:
                # Clamp: λ_reg stays >= 0.01, λ_target stays <= 0.
                rec_reg = max(0.01, reg_poly(d))
                rec_target = min(0.0, target_poly(d))
                print(f" Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}")
        else:
            print("Curve fitting failed; no recommendations available")

    else:
        print("Not enough successful runs to fit curves")

    return analysis
lambda_targets.index(r.lambda_target) + j = lambda_regs.index(r.lambda_reg) + acc_matrix[i, j] = r.final_val_acc + + im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto') + ax.set_xticks(range(len(lambda_regs))) + ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45) + ax.set_yticks(range(len(lambda_targets))) + ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets]) + ax.set_xlabel("λ_reg") + ax.set_ylabel("λ_target") + ax.set_title(f"Depth {depth}") + + # Mark best + best = max(depth_results, key=lambda r: r.final_val_acc) + bi = lambda_targets.index(best.lambda_target) + bj = lambda_regs.index(best.lambda_reg) + ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2) + + # Add colorbar + plt.colorbar(im, ax=ax, label='Val Acc') + + # Hide unused subplots + for idx in range(len(depths), len(axes)): + axes[idx].axis('off') + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight') + plt.close() + + # Plot optimal hyperparameters vs depth + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + optimal_regs = [] + optimal_targets = [] + optimal_accs = [] + for depth in depths: + best = max(by_depth[depth], key=lambda r: r.final_val_acc) + optimal_regs.append(best.lambda_reg) + optimal_targets.append(best.lambda_target) + optimal_accs.append(best.final_val_acc) + + axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8) + axes[0].set_xlabel("Network Depth") + axes[0].set_ylabel("Optimal λ_reg") + axes[0].set_title("Optimal Regularization Strength vs Depth") + axes[0].grid(True, alpha=0.3) + + axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange') + axes[1].set_xlabel("Network Depth") + axes[1].set_ylabel("Optimal λ_target") + axes[1].set_title("Optimal Target Lyapunov vs Depth") + axes[1].grid(True, alpha=0.3) + + axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green') + 
axes[2].set_xlabel("Network Depth") + axes[2].set_ylabel("Best Validation Accuracy") + axes[2].set_title("Best Achievable Accuracy vs Depth") + axes[2].set_ylim(0, 1.05) + axes[2].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight') + plt.close() + + print(f"Plots saved to {output_dir}") + + +def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'): + """Get CIFAR-10 dataloaders with rate encoding for SNN.""" + from torchvision import datasets, transforms + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform) + val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform) + + # Rate encoding: convert images to spike sequences + class RateEncodedDataset(torch.utils.data.Dataset): + def __init__(self, dataset, T): + self.dataset = dataset + self.T = T + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + img, label = self.dataset[idx] + # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D) + flat = img.view(-1) # (3072,) + # Rate encoding: probability of spike = pixel intensity + spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float() + return spikes, label + + train_encoded = RateEncodedDataset(train_ds, T) + val_encoded = RateEncodedDataset(val_ds, T) + + train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_loader, val_loader, T, 3072, 10 # T, D, num_classes + + +def parse_args(): + p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN") + + # Grid search parameters + p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10], + help="Network depths to test") + p.add_argument("--lambda_regs", type=float, 
def parse_args(argv=None):
    """Parse command-line options for the hyperparameter grid search.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN")

    # Grid search parameters
    p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10],
                   help="Network depths to test")
    p.add_argument("--lambda_regs", type=float, nargs="+",
                   default=[0.01, 0.05, 0.1, 0.2, 0.3],
                   help="Lambda_reg values to test")
    p.add_argument("--lambda_targets", type=float, nargs="+",
                   default=[0.0, -0.05, -0.1, -0.2],
                   help="Lambda_target values to test")

    # Model parameters
    p.add_argument("--hidden_dim", type=int, default=256)
    p.add_argument("--epochs", type=int, default=15)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--seed", type=int, default=42)

    # Data
    p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)")
    p.add_argument("--data_dir", type=str, default="./data")
    p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding")

    # Output
    p.add_argument("--out_dir", type=str, default="runs/grid_search")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--no-progress", action="store_true")

    return p.parse_args(argv)


def main():
    """Run the grid search end to end: load data, search, analyze, save, plot.

    Results are written to a timestamped sub-directory of ``--out_dir``.
    """
    args = parse_args()
    device = torch.device(args.device)

    banner = "=" * 80
    n_configs = len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)

    print(banner)
    print("HYPERPARAMETER GRID SEARCH")
    print(banner)
    print(f"Depths: {args.depths}")
    print(f"λ_reg values: {args.lambda_regs}")
    print(f"λ_target values: {args.lambda_targets}")
    print(f"Total configurations: {n_configs}")
    print(f"Epochs per config: {args.epochs}")
    print(f"Device: {device}")
    print(banner)

    # Load data (note: --T only affects the CIFAR-10 rate encoding)
    if args.synthetic:
        print("\nUsing synthetic data")
        train_loader, val_loader, T, D, C = create_synthetic_data(
            seed=args.seed, batch_size=args.batch_size
        )
    else:
        print("\nUsing CIFAR-10 with rate encoding")
        train_loader, val_loader, T, D, C = get_cifar10_loaders(
            batch_size=args.batch_size,
            T=args.T,
            data_dir=args.data_dir,
        )

    print(f"Data: T={T}, D={D}, classes={C}\n")

    # Run grid search
    results = run_grid_search(
        depths=args.depths,
        lambda_regs=args.lambda_regs,
        lambda_targets=args.lambda_targets,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=D,
        num_classes=C,
        hidden_dim=args.hidden_dim,
        epochs=args.epochs,
        lr=args.lr,
        device=device,
        seed=args.seed,
        progress=not args.no_progress,
    )

    # Analyze results
    analysis = analyze_results(results)

    # Save results under a timestamped directory so runs never overwrite each other
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(args.out_dir, ts)
    save_results(results, analysis, output_dir, vars(args))

    # Generate plots
    plot_grid_search(results, output_dir)


if __name__ == "__main__":
    main()
# diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py
# new file mode 100644
# index 0000000..05dbcd2
# --- /dev/null
# +++ b/files/experiments/lyapunov_diffonly_benchmark.py
# @@ -0,0 +1,590 @@
"""
Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation.

Optimization B: Instead of storing two full membrane trajectories:
    mems[i][0] = base trajectory
    mems[i][1] = perturbed trajectory

Store only:
    base_mems[i] = base trajectory
    delta_mems[i] = perturbation (perturbed - base)

Benefits:
    - ~2x less memory for membrane states
    - Fewer memory reads/writes during renormalization
    - Better cache utilization
"""

import os
import sys
import time
import torch
import torch.nn as nn
from typing import Tuple, Optional, List

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import snntorch as snn
from snntorch import surrogate


class SpikingVGGBlock(nn.Module):
    """Conv -> BatchNorm -> leaky integrate-and-fire block.

    The membrane state is passed in and returned explicitly
    (``init_hidden=False``), so callers own the state between timesteps.
    """

    def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
        super().__init__()
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)

    def forward(self, x, mem):
        """Return ``(spikes, new_membrane)`` for input ``x`` and state ``mem``."""
        current = self.bn(self.conv(x))
        spk, mem = self.lif(current, mem)
        return spk, mem


class SpikingVGG_Original(nn.Module):
    """Reference implementation storing both trajectories as ``(P, B, C, H, W)``.

    With ``compute_lyapunov=True`` the membrane tensors hold the base
    trajectory at index 0 and the perturbed trajectory at index 1 (P=2);
    otherwise P=1. Spatial sizes are hard-wired for 32x32 inputs.
    """

    def __init__(self, in_channels=3, num_classes=100, base_channels=64,
                 num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage

        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        prev_ch = in_channels
        width = base_channels
        side = 32  # CIFAR

        for stage in range(num_stages):
            blocks = nn.ModuleList()
            for _ in range(blocks_per_stage):
                blocks.append(SpikingVGGBlock(prev_ch, width, beta=beta))
                prev_ch = width
            self.stages.append(blocks)
            self.pools.append(nn.AvgPool2d(2))
            side //= 2
            # Channel width doubles between stages, capped at 512.
            if stage < num_stages - 1:
                width = min(width * 2, 512)

        self.fc = nn.Linear(prev_ch * side * side, num_classes)
        self._channel_sizes = self._compute_channel_sizes(base_channels)

    def _compute_channel_sizes(self, base):
        """Return the output channel count of every block, in execution order."""
        sizes = []
        width = base
        for stage in range(self.num_stages):
            sizes.extend([width] * self.blocks_per_stage)
            if stage < self.num_stages - 1:
                width = min(width * 2, 512)
        return sizes

    def _init_mems(self, batch_size, device, dtype, P=1):
        """Return zero membrane tensors of shape ``(P, B, C, H, W)``, one per block."""
        mems = []
        side = 32
        sizes = iter(self._channel_sizes)
        for _ in range(self.num_stages):
            for _ in range(self.blocks_per_stage):
                mems.append(torch.zeros(P, batch_size, next(sizes), side, side,
                                        device=device, dtype=dtype))
            side //= 2
        return mems

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Run ``T`` timesteps; return ``(logits, lyap_est)``.

        ``lyap_est`` is the batch mean of the time-averaged log growth of the
        global membrane perturbation, or ``None`` when ``compute_lyapunov``
        is false.
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype
        P = 2 if compute_lyapunov else 1

        mems = self._init_mems(B, device, dtype, P=P)

        if compute_lyapunov:
            for mem in mems:
                mem[1] = mem[0] + lyap_eps * torch.randn_like(mem[0])
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        def fold(t):
            # (P, B, ...) -> (P*B, ...) so both trajectories share one kernel launch
            return t.reshape(P * B, *t.shape[2:])

        def unfold(t):
            # (P*B, ...) -> (P, B, ...)
            return t.view(P, B, *t.shape[1:])

        spike_sum = None

        for _ in range(self.T):
            next_mems = []
            h = None

            for blocks, pool in zip(self.stages, self.pools):
                for block in blocks:
                    state = fold(mems[len(next_mems)])
                    if h is None:
                        # First block: the input is identical for both
                        # trajectories, so conv+bn run once on x.
                        shared = block.bn(block.conv(x))
                        current = fold(shared.unsqueeze(0).expand(P, -1, -1, -1, -1))
                    else:
                        current = block.bn(block.conv(fold(h)))
                    spk_flat, mem_flat = block.lif(current, state)
                    h = unfold(spk_flat)
                    next_mems.append(unfold(mem_flat))

                h = unfold(pool(fold(h)))

            mems = next_mems

            # Only the base trajectory feeds the classifier.
            base_spikes = h[0].view(B, -1)
            spike_sum = base_spikes if spike_sum is None else spike_sum + base_spikes

            if compute_lyapunov:
                diffs = [mem[1] - mem[0] for mem in next_mems]

                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for diff in diffs:
                    delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))

                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Rescale the perturbation back to global norm eps.
                scale = (lyap_eps / delta).view(B, 1, 1, 1)
                mems = [
                    torch.stack([mem[0], mem[0] + diff * scale], dim=0)
                    for mem, diff in zip(next_mems, diffs)
                ]

        logits = self.fc(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est
class SpikingVGG_DiffOnly(nn.Module):
    """
    Optimized implementation: stores base + diff instead of 2 full trajectories.

    Memory layout:
        base_mems[i]:  (B, C, H, W) - base trajectory membrane
        delta_mems[i]: (B, C, H, W) - perturbation vector

    The perturbed trajectory is materialized as (base + delta) only when a
    block is evaluated. Spatial sizes are hard-wired for 32x32 inputs.
    """

    def __init__(self, in_channels=3, num_classes=100, base_channels=64,
                 num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage

        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        in_ch = in_channels
        out_ch = base_channels
        current_size = 32

        for stage in range(num_stages):
            stage_blocks = nn.ModuleList()
            for _ in range(blocks_per_stage):
                stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
                in_ch = out_ch
            self.stages.append(stage_blocks)
            self.pools.append(nn.AvgPool2d(2))
            current_size //= 2
            # Channel width doubles between stages, capped at 512.
            if stage < num_stages - 1:
                out_ch = min(out_ch * 2, 512)

        self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
        self._channel_sizes = self._compute_channel_sizes(base_channels)

    def _compute_channel_sizes(self, base):
        """Return the output channel count of every block, in execution order."""
        sizes = []
        ch = base
        for stage in range(self.num_stages):
            sizes.extend([ch] * self.blocks_per_stage)
            if stage < self.num_stages - 1:
                ch = min(ch * 2, 512)
        return sizes

    def _init_mems(self, batch_size, device, dtype):
        """Initialize base membrane states (B, C, H, W), one per block."""
        base_mems = []
        H, W = 32, 32
        layer_idx = 0
        for _ in range(self.num_stages):
            for _ in range(self.blocks_per_stage):
                ch = self._channel_sizes[layer_idx]
                base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
                layer_idx += 1
            H, W = H // 2, W // 2
        return base_mems

    def _init_deltas(self, base_mems, lyap_eps):
        """Draw the initial perturbation for every layer.

        Each element is ``lyap_eps * N(0, 1)``, so the initial *global* norm
        is about ``lyap_eps * sqrt(num_elements)``, not ``lyap_eps``; the norm
        is only brought to ``lyap_eps`` by the renormalization at the end of
        each timestep. ``SpikingVGG_Original`` initializes its perturbed state
        the same way and consumes random numbers in the same order.
        """
        return [lyap_eps * torch.randn_like(base) for base in base_mems]

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Run ``T`` timesteps; return ``(logits, lyap_est)``.

        Args:
            x: Input batch; the first dimension is the batch size.
            compute_lyapunov: Also propagate a perturbed trajectory and
                estimate the largest Lyapunov exponent.
            lyap_eps: Perturbation scale and renormalization target.

        Returns:
            ``logits`` from the base trajectory and ``lyap_est``, the batch
            mean of the time-averaged log growth of the global perturbation
            norm (``None`` when ``compute_lyapunov`` is false).
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype

        base_mems = self._init_mems(B, device, dtype)

        if compute_lyapunov:
            delta_mems = self._init_deltas(base_mems, lyap_eps)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
        else:
            delta_mems = None

        spike_sum = None

        for _ in range(self.T):
            layer = 0
            new_base_mems = []
            new_delta_mems = []

            h_base = None   # spikes of the base trajectory
            h_delta = None  # spikes of the perturbed trajectory minus h_base

            for stage_blocks, pool in zip(self.stages, self.pools):
                for block in stage_blocks:
                    # The first block of a timestep sees the raw input, which
                    # is identical for both trajectories: conv+bn run once.
                    shared_input = h_base is None

                    cur_base = block.bn(block.conv(x if shared_input else h_base))
                    spk_base, mem_base_new = block.lif(cur_base, base_mems[layer])
                    new_base_mems.append(mem_base_new)

                    if compute_lyapunov:
                        if shared_input:
                            cur_pert = cur_base
                        else:
                            cur_pert = block.bn(block.conv(h_base + h_delta))
                        mem_pert = base_mems[layer] + delta_mems[layer]
                        spk_pert, mem_pert_new = block.lif(cur_pert, mem_pert)
                        new_delta_mems.append(mem_pert_new - mem_base_new)
                        h_delta = spk_pert - spk_base

                    h_base = spk_base
                    layer += 1

                # AvgPool is linear, so pooling the delta once is exactly
                # pool(perturbed) - pool(base). Pooling it a second time
                # halves its spatial size again and makes it incompatible
                # with h_base in the next stage.
                h_base = pool(h_base)
                if compute_lyapunov:
                    h_delta = pool(h_delta)

            base_mems = new_base_mems

            # Accumulate spikes from the base trajectory only.
            h_flat = h_base.view(B, -1)
            spike_sum = h_flat if spike_sum is None else spike_sum + h_flat

            if compute_lyapunov:
                # Global squared norm of the perturbation: sum over all layers.
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for delta in new_delta_mems:
                    delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3))

                delta_norm = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)

                # Renormalize: one scale per sample so the global norm is eps.
                scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
                delta_mems = [delta * scale for delta in new_delta_mems]

        logits = self.fc(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est
def count_parameters(model):
    """Return the total number of parameter elements in ``model``."""
    return sum(p.numel() for p in model.parameters())


def _cuda_sync(device):
    """Wait for queued CUDA work on ``device``; do nothing for other devices."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20):
    """Time inference-only forward passes.

    Warm-up and timed passes both run under ``torch.no_grad()``, so the
    measurement excludes autograd graph construction. CUDA synchronization
    is applied only when ``x`` lives on a CUDA device, so the function also
    works on CPU.

    Args:
        model: Module called as ``model(x, compute_lyapunov=...)`` and
            returning ``(logits, lyap)``.
        x: Input batch.
        compute_lyapunov: Forwarded to the model.
        num_warmup: Untimed passes run first.
        num_runs: Timed passes.

    Returns:
        ``(times, lyap)``: per-run wall-clock seconds and the Lyapunov value
        from the last timed run (``None`` if ``num_runs`` is 0).
    """
    device = x.device
    times = []
    lyap = None

    with torch.no_grad():
        # Warmup
        for _ in range(num_warmup):
            model(x, compute_lyapunov=compute_lyapunov)
        _cuda_sync(device)

        # Timed runs
        for _ in range(num_runs):
            _cuda_sync(device)
            start = time.perf_counter()

            _, lyap = model(x, compute_lyapunov=compute_lyapunov)

            # Kernels launch asynchronously; wait before reading the clock.
            _cuda_sync(device)
            times.append(time.perf_counter() - start)

    return times, lyap


def _train_step(model, x, y, criterion, compute_lyapunov, lambda_reg):
    """Run one forward + backward pass with the optional Lyapunov penalty."""
    logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
    loss = criterion(logits, y)
    if compute_lyapunov and lyap is not None:
        loss = loss + lambda_reg * (lyap ** 2)
    loss.backward()


def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov,
                               lambda_reg=0.3, num_warmup=5, num_runs=20):
    """Time forward + backward passes.

    Gradients are zeroed before every pass (outside the timed region). CUDA
    synchronization is applied only when ``x`` lives on a CUDA device.

    Args:
        model: Module called as ``model(x, compute_lyapunov=...)`` and
            returning ``(logits, lyap)``.
        x: Input batch.
        y: Targets passed to ``criterion``.
        criterion: Loss callable ``criterion(logits, y)``.
        compute_lyapunov: Forwarded to the model; when true and the model
            returns a value, ``lambda_reg * lyap**2`` is added to the loss.
        lambda_reg: Weight of the Lyapunov penalty.
        num_warmup: Untimed passes run first.
        num_runs: Timed passes.

    Returns:
        List of per-run wall-clock seconds.
    """
    device = x.device

    # Warmup
    for _ in range(num_warmup):
        model.zero_grad()
        _train_step(model, x, y, criterion, compute_lyapunov, lambda_reg)
    _cuda_sync(device)

    # Timed runs
    times = []
    for _ in range(num_runs):
        model.zero_grad()
        _cuda_sync(device)
        start = time.perf_counter()

        _train_step(model, x, y, criterion, compute_lyapunov, lambda_reg)

        _cuda_sync(device)
        times.append(time.perf_counter() - start)

    return times


def measure_memory(model, x, compute_lyapunov):
    """Measure peak GPU memory (in MB) during one no-grad forward pass.

    Raises:
        RuntimeError: If ``x`` is not on a CUDA device; the peak-memory
            counters only exist for CUDA.
    """
    device = x.device
    if device.type != "cuda":
        raise RuntimeError(f"measure_memory requires a CUDA input, got device '{device}'")

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)

    with torch.no_grad():
        model(x, compute_lyapunov=compute_lyapunov)

    torch.cuda.synchronize(device)
    peak_mem = torch.cuda.max_memory_allocated(device) / 1024**2  # MB
    return peak_mem
def run_benchmark():
    """Compare SpikingVGG_Original and SpikingVGG_DiffOnly.

    For each configuration the two models share weights and are compared on
    output agreement, forward speed, forward+backward speed and (CUDA only)
    peak memory. Everything is reported on stdout.

    The printed depth is the number of spiking blocks in the model actually
    built (``num_stages * blocks_per_stage``), so the label always matches
    the benchmarked network.
    """
    rule = "=" * 70

    def mean_ms(times):
        # Average of per-run seconds, in milliseconds.
        return sum(times) / len(times) * 1000

    print(rule)
    print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage")
    print(rule)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name()}")

    # Test configurations
    configs = [
        {"blocks_per_stage": 1, "batch_size": 64},
        {"blocks_per_stage": 2, "batch_size": 64},
        {"blocks_per_stage": 3, "batch_size": 32},
    ]

    print("\n" + rule)

    for cfg in configs:
        blocks = cfg["blocks_per_stage"]
        batch_size = cfg["batch_size"]

        # Create models; the diff-only model reuses the original's weights
        model_orig = SpikingVGG_Original(blocks_per_stage=blocks, T=4).to(device)
        model_diff = SpikingVGG_DiffOnly(blocks_per_stage=blocks, T=4).to(device)
        model_diff.load_state_dict(model_orig.state_dict())

        depth = model_orig.num_stages * model_orig.blocks_per_stage

        print("\n" + rule)
        print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}")
        print(rule)

        print(f"Parameters: {count_parameters(model_orig):,}")

        # Create input
        x = torch.randn(batch_size, 3, 32, 32, device=device)
        y = torch.randint(0, 100, (batch_size,), device=device)
        criterion = nn.CrossEntropyLoss()

        # Test 1: Verify outputs match
        print("\n--- Output Verification ---")
        model_orig.eval()
        model_diff.eval()

        # Same seed before each call so both draw the same initial perturbation.
        torch.manual_seed(42)
        with torch.no_grad():
            logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4)

        torch.manual_seed(42)
        with torch.no_grad():
            logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4)

        logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5)
        # Allow some difference due to the different implementations.
        lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1

        print(f"Logits match: {logits_match}")
        print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}")
        print(f"Lyapunov close (within 0.1): {lyap_close}")

        # Test 2: Forward-only speed (no grad)
        print("\n--- Forward Speed (no_grad) ---")
        model_orig.eval()
        model_diff.eval()

        # Without Lyapunov
        times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False)
        times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False)

        print(f" Without Lyapunov:")
        print(f" Original: {mean_ms(times_orig_noly):.2f} ms")
        print(f" Diff-only: {mean_ms(times_diff_noly):.2f} ms")

        # With Lyapunov
        times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True)
        times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True)

        mean_orig_ly = mean_ms(times_orig_ly)
        mean_diff_ly = mean_ms(times_diff_ly)
        speedup = mean_orig_ly / mean_diff_ly

        print(f" With Lyapunov:")
        print(f" Original: {mean_orig_ly:.2f} ms")
        print(f" Diff-only: {mean_diff_ly:.2f} ms")
        print(f" Speedup: {speedup:.2f}x")

        # Test 3: Forward + Backward speed (training mode)
        print("\n--- Forward+Backward Speed (training) ---")
        model_orig.train()
        model_diff.train()

        times_orig_train = benchmark_forward_backward(
            model_orig, x, y, criterion, compute_lyapunov=True
        )
        times_diff_train = benchmark_forward_backward(
            model_diff, x, y, criterion, compute_lyapunov=True
        )

        mean_orig_train = mean_ms(times_orig_train)
        mean_diff_train = mean_ms(times_diff_train)
        speedup_train = mean_orig_train / mean_diff_train

        print(f" With Lyapunov + backward:")
        print(f" Original: {mean_orig_train:.2f} ms")
        print(f" Diff-only: {mean_diff_train:.2f} ms")
        print(f" Speedup: {speedup_train:.2f}x")

        # Test 4: Memory usage (peak-memory counters exist only on CUDA)
        if device.type == "cuda":
            print("\n--- Peak GPU Memory ---")

            mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False)
            mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False)

            mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True)
            mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True)

            print(f" Without Lyapunov:")
            print(f" Original: {mem_orig_noly:.1f} MB")
            print(f" Diff-only: {mem_diff_noly:.1f} MB")
            print(f" With Lyapunov:")
            print(f" Original: {mem_orig_ly:.1f} MB")
            print(f" Diff-only: {mem_diff_ly:.1f} MB")
            print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)")

        # Cleanup before the next (possibly larger) configuration
        del model_orig, model_diff, x, y
        if device.type == "cuda":
            torch.cuda.empty_cache()

    print("\n" + rule)
    print("BENCHMARK COMPLETE")
    print(rule)


if __name__ == "__main__":
    run_benchmark()
# diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py
# new file mode 100644
# index 0000000..117009b
# --- /dev/null
# +++ b/files/experiments/lyapunov_speedup_benchmark.py
# @@ -0,0 +1,638 @@
"""
Lyapunov Computation Speedup Benchmark

Tests different optimization approaches for computing Lyapunov exponents
during SNN training. All approaches should produce equivalent results
(within numerical precision) but with different performance characteristics.

Approaches tested:
- Baseline: Current sequential implementation
- Approach A: Trajectory-as-batch (P=2), share first Linear
- Approach B: Global-norm divergence + single-scale renorm
- Approach C: torch.compile the time loop
- Combined: A + B + C together
"""

import os
import sys
import time
from typing import Tuple, Optional, List
from dataclasses import dataclass

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

# Ensure we can import from project
_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)


# =============================================================================
# Baseline Implementation (Current)
# =============================================================================

class BaselineSNN(nn.Module):
    """Current implementation: the perturbed trajectory runs as a separate pass.

    Divergence is measured with the global norm over all layers, while
    renormalization rescales each layer's difference to norm ``lyap_eps``
    separately.
    """

    def __init__(self, in_channels=3, hidden_dims=(64, 128, 256), T=4, beta=0.9):
        """Build a Linear+LIF stack on flattened 32x32 inputs with a 10-way readout.

        ``hidden_dims`` may be any sequence of layer widths; it is copied so
        instances never share (or mutate) the caller's or the default object.
        """
        super().__init__()
        self.T = T
        self.hidden_dims = list(hidden_dims)
        spike_grad = surrogate.fast_sigmoid(slope=25)

        # Simple feedforward for benchmarking (not full VGG)
        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()

        widths = [in_channels * 32 * 32] + self.hidden_dims  # Flattened input
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.linears.append(nn.Linear(fan_in, fan_out))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))

        self.readout = nn.Linear(self.hidden_dims[-1], 10)

    def _step(self, x, mems):
        """Advance every layer by one timestep; return (last spikes, new mems)."""
        h = x
        new_mems = []
        for lin, lif, mem in zip(self.linears, self.lifs, mems):
            spk, mem_new = lif(lin(h), mem)
            new_mems.append(mem_new)
            h = spk
        return h, new_mems

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Run ``T`` timesteps; return ``(logits, lyap_est)``.

        ``lyap_est`` is the batch mean of the time-averaged log growth of the
        perturbation, or ``None`` when ``compute_lyapunov`` is false.
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.reshape(B, -1)  # Flatten (reshape also accepts non-contiguous input)

        # Init membrane potentials
        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]

        if compute_lyapunov:
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for _ in range(self.T):
            # Original trajectory
            spikes, mems = self._step(x, mems)
            spike_sum = spike_sum + spikes

            if compute_lyapunov:
                # Perturbed trajectory (separate pass - this is the slow part)
                _, new_mems_p = self._step(x, mems_p)
                diffs = [mp - m for mp, m in zip(new_mems_p, mems)]

                # Divergence: global norm over all layers
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for diff in diffs:
                    delta_sq = delta_sq + (diff ** 2).sum(dim=1)

                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Renormalize each layer separately to norm eps
                mems_p = [
                    m + lyap_eps * diff / (torch.norm(diff, dim=1, keepdim=True) + 1e-12)
                    for m, diff in zip(mems, diffs)
                ]

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est
# =============================================================================
# Approach A: Trajectory-as-batch (P=2), share first Linear
# =============================================================================

class ApproachA_SNN(nn.Module):
    """Batch both trajectories together and share the first Linear layer.

    State layout is ``(P, B, H)`` with index 0 the original and index 1 the
    perturbed trajectory. Renormalization is per layer.
    """

    def __init__(self, in_channels=3, hidden_dims=(64, 128, 256), T=4, beta=0.9):
        """Build a Linear+LIF stack on flattened 32x32 inputs with a 10-way readout.

        ``hidden_dims`` may be any sequence of layer widths; it is copied so
        instances never share (or mutate) the caller's or the default object.
        """
        super().__init__()
        self.T = T
        self.hidden_dims = list(hidden_dims)
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()

        widths = [in_channels * 32 * 32] + self.hidden_dims
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.linears.append(nn.Linear(fan_in, fan_out))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))

        self.readout = nn.Linear(self.hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Run ``T`` timesteps; return ``(logits, lyap_est)``.

        ``lyap_est`` is the batch mean of the time-averaged log growth of the
        global perturbation norm, or ``None`` when ``compute_lyapunov`` is
        false.
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.reshape(B, -1)  # reshape also accepts non-contiguous input

        P = 2 if compute_lyapunov else 1
        n_layers = len(self.hidden_dims)

        # State layout: (P, B, H) where P=2 for [original, perturbed]
        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]

        if compute_lyapunov:
            # Initialize perturbed state
            for mem, width in zip(mems, self.hidden_dims):
                mem[1] = mem[0] + lyap_eps * torch.randn(B, width, device=device, dtype=dtype)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for _ in range(self.T):
            # Layer 1: the input is shared, so the Linear runs once and is
            # broadcast across trajectories as a zero-copy view.
            h1 = self.linears[0](x)  # (B, H1)
            h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)

            spk, mems[0] = self.lifs[0](h, mems[0])
            h = spk

            # Layers 2+: inputs differ per trajectory; fold P into the batch
            for i in range(1, n_layers):
                h_lin = self.linears[i](h.reshape(P * B, -1)).view(P, B, self.hidden_dims[i])
                spk, mems[i] = self.lifs[i](h_lin, mems[i])
                h = spk

            # Accumulate spikes from original trajectory only
            spike_sum = spike_sum + h[0]

            if compute_lyapunov:
                diffs = [mem[1] - mem[0] for mem in mems]  # each (B, H_i)

                # Global divergence across all layers
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for diff in diffs:
                    delta_sq = delta_sq + diff.square().sum(dim=-1)

                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # Renormalize each layer separately to norm eps
                for mem, diff in zip(mems, diffs):
                    norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
                    mem[1] = mem[0] + lyap_eps * diff / norm

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est
# =============================================================================
# Approach B: Global-norm divergence + single-scale renorm
# =============================================================================

class ApproachB_SNN(nn.Module):
    """Global norm for divergence, single scale factor for renormalization.

    The perturbed trajectory still runs as a separate pass; only the
    renormalization differs from the baseline: one scale per sample is
    applied to every layer so the global perturbation norm returns to
    ``lyap_eps``.
    """

    def __init__(self, in_channels=3, hidden_dims=(64, 128, 256), T=4, beta=0.9):
        """Build a Linear+LIF stack on flattened 32x32 inputs with a 10-way readout.

        ``hidden_dims`` may be any sequence of layer widths; it is copied so
        instances never share (or mutate) the caller's or the default object.
        """
        super().__init__()
        self.T = T
        self.hidden_dims = list(hidden_dims)
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()

        widths = [in_channels * 32 * 32] + self.hidden_dims
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.linears.append(nn.Linear(fan_in, fan_out))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))

        self.readout = nn.Linear(self.hidden_dims[-1], 10)

    def _step(self, x, mems):
        """Advance every layer by one timestep; return (last spikes, new mems)."""
        h = x
        new_mems = []
        for lin, lif, mem in zip(self.linears, self.lifs, mems):
            spk, mem_new = lif(lin(h), mem)
            new_mems.append(mem_new)
            h = spk
        return h, new_mems

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Run ``T`` timesteps; return ``(logits, lyap_est)``.

        ``lyap_est`` is the batch mean of the time-averaged log growth of the
        global perturbation norm, or ``None`` when ``compute_lyapunov`` is
        false.
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.reshape(B, -1)  # reshape also accepts non-contiguous input

        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]

        if compute_lyapunov:
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for _ in range(self.T):
            # Original trajectory
            spikes, mems = self._step(x, mems)
            spike_sum = spike_sum + spikes

            if compute_lyapunov:
                # Perturbed trajectory
                _, new_mems_p = self._step(x, mems_p)
                diffs = [mp - m for mp, m in zip(new_mems_p, mems)]

                # GLOBAL divergence (one delta per batch element)
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for diff in diffs:
                    delta_sq = delta_sq + diff.square().sum(dim=-1)

                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # SINGLE SCALE renormalization (key optimization)
                scale = (lyap_eps / delta).unsqueeze(-1)  # (B, 1)
                mems_p = [m + diff * scale for m, diff in zip(mems, diffs)]

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est
class ApproachAB_SNN(nn.Module):
    """Combined: trajectory-as-batch + global-norm renorm.

    Both trajectories are stacked on a leading axis of size P so each layer
    is called once per timestep, and the perturbed trajectory is rescaled
    with a single per-sample factor computed from the distance over all
    layers.
    """

    def __init__(self, in_channels=3, hidden_dims=(64, 128, 256), T=4, beta=0.9):
        """Build the Linear/LIF stack and the 10-way readout.

        Args:
            in_channels: Channels of the 32x32 input image.
            hidden_dims: Width of every hidden layer (any sequence of ints).
            T: Number of simulation timesteps.
            beta: Membrane decay passed to every ``snn.Leaky``.
        """
        super().__init__()
        self.T = T
        # Keep a private list: the default is immutable and a caller's list
        # is never aliased, so later edits cannot leak between instances.
        self.hidden_dims = list(hidden_dims)
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()

        dims = [in_channels * 32 * 32] + self.hidden_dims
        for fan_in, fan_out in zip(dims, dims[1:]):
            self.linears.append(nn.Linear(fan_in, fan_out))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))

        self.readout = nn.Linear(self.hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        """Return ``(logits, lyap_est)``; ``lyap_est`` is None unless requested.

        Args:
            x: Input batch; everything after the batch axis is flattened.
            compute_lyapunov: Track a perturbed twin trajectory and average
                the per-step log growth of its distance to the reference.
            lyap_eps: Perturbation scale used for seeding and renormalizing.
        """
        B = x.size(0)
        device, dtype = x.device, x.dtype
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, -1)
        num_layers = len(self.hidden_dims)

        P = 2 if compute_lyapunov else 1

        # State: (P, B, H); index 0 is the reference, index 1 the twin.
        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]

        lyap_accum = None
        if compute_lyapunov:
            for i in range(num_layers):
                mems[i][1] = lyap_eps * torch.randn(
                    B, self.hidden_dims[i], device=device, dtype=dtype
                )
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        # Input and weights do not change between timesteps: evaluate the
        # first Linear once and share it across steps and trajectories.
        h1 = self.linears[0](x)
        first_in = h1.unsqueeze(0).expand(P, -1, -1)

        for _ in range(self.T):
            spk, mems[0] = self.lifs[0](first_in, mems[0])
            h = spk

            # Layers 2+: fold P into the batch axis for one Linear call.
            for i in range(1, num_layers):
                h_lin = self.linears[i](h.reshape(P * B, -1)).view(P, B, self.hidden_dims[i])
                spk, mems[i] = self.lifs[i](h_lin, mems[i])
                h = spk

            spike_sum = spike_sum + h[0]

            if compute_lyapunov:
                # Global divergence: one distance per batch element.
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(num_layers):
                    diff = mems[i][1] - mems[i][0]
                    delta_sq = delta_sq + diff.square().sum(dim=-1)

                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # Global scale renorm: the same factor for every layer.
                scale = (lyap_eps / delta).unsqueeze(-1)  # (B, 1)
                for i in range(num_layers):
                    diff = mems[i][1] - mems[i][0]
                    mems[i][1] = mems[i][0] + diff * scale

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est


# =============================================================================
# Approach C: torch.compile wrapper
# =============================================================================

def make_compiled_model(model_class, *args, **kwargs):
    """Instantiate ``model_class(*args, **kwargs)`` and compile its forward.

    The bound ``forward`` is replaced by its ``torch.compile`` version
    (``mode="reduce-overhead"``); the model object itself is returned.
    """
    model = model_class(*args, **kwargs)
    # Compile the forward method
    model.forward = torch.compile(model.forward, mode="reduce-overhead")
    return model
@dataclass
class BenchmarkResult:
    """Averaged timings and bookkeeping for one benchmarked configuration."""

    name: str                # label printed in the results table
    forward_time_ms: float   # mean forward + loss time per iteration
    backward_time_ms: float  # mean backward time per iteration
    total_time_ms: float     # forward_time_ms + backward_time_ms
    lyap_value: float        # mean Lyapunov estimate (0.0 if none reported)
    memory_mb: float         # peak CUDA memory in MiB (0.0 off-GPU)

    def __str__(self):
        return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
                f"Bwd: {self.backward_time_ms:7.2f}ms | "
                f"Total: {self.total_time_ms:7.2f}ms | "
                f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")


def benchmark_model(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    name: str,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> BenchmarkResult:
    """Benchmark a single model configuration.

    Runs ``warmup_iters`` untimed forward/backward passes, then times
    ``bench_iters`` passes with ``compute_lyapunov=True``. CUDA
    synchronization and peak-memory queries are issued only when ``x`` lives
    on a CUDA device, so the CPU fallback chosen by the script no longer
    crashes inside ``torch.cuda``.

    Args:
        model: Module whose call returns ``(logits, lyap)``.
        x: Input batch; its device decides whether CUDA calls are made.
        y: Integer class targets for ``nn.CrossEntropyLoss``.
        name: Label stored in the result.
        warmup_iters: Untimed iterations run first.
        bench_iters: Timed iterations that are averaged.

    Returns:
        A ``BenchmarkResult`` with mean timings in milliseconds.
    """
    device = x.device
    on_cuda = device.type == "cuda"
    criterion = nn.CrossEntropyLoss()

    def sync():
        # Kernels run asynchronously on CUDA; wall-clock readings are only
        # meaningful once the queue is drained. Nothing to wait for on CPU.
        if on_cuda:
            torch.cuda.synchronize(device)

    def step():
        logits, lyap = model(x, compute_lyapunov=True)
        loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
        return loss, lyap

    # Warmup
    for _ in range(warmup_iters):
        loss, _ = step()
        loss.backward()
        model.zero_grad()

    sync()
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    fwd_times = []
    bwd_times = []
    lyap_vals = []

    for _ in range(bench_iters):
        # Forward (includes loss construction)
        sync()
        t0 = time.perf_counter()
        loss, lyap = step()
        sync()
        t1 = time.perf_counter()

        # Backward
        loss.backward()
        sync()
        t2 = time.perf_counter()

        fwd_times.append((t1 - t0) * 1000)
        bwd_times.append((t2 - t1) * 1000)
        if lyap is not None:
            lyap_vals.append(lyap.item())

        model.zero_grad()

    peak_mem = torch.cuda.max_memory_allocated(device) / 1024 / 1024 if on_cuda else 0.0

    mean_fwd = sum(fwd_times) / len(fwd_times)
    mean_bwd = sum(bwd_times) / len(bwd_times)

    return BenchmarkResult(
        name=name,
        forward_time_ms=mean_fwd,
        backward_time_ms=mean_bwd,
        total_time_ms=mean_fwd + mean_bwd,
        lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
        memory_mb=peak_mem,
    )
def run_scaling_test(device: str = "cuda"):
    """Test how approaches scale with batch size and timesteps.

    For every configuration a baseline model and an A+B model are built and
    timed with ``benchmark_model``; one table row is printed per
    configuration.

    Args:
        device: Device string used for the dummy data and the models.
    """
    print("\n" + "=" * 80)
    print("SCALING TESTS")
    print("=" * 80)

    configs = [
        {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]},  # Larger model
    ]

    print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
    print("-" * 80)

    for cfg in configs:
        batch_size = cfg["batch_size"]
        # batch_size only shapes the data. The model constructors take no
        # such keyword, so forwarding the whole cfg dict raises TypeError;
        # pass the model arguments explicitly instead.
        model_kwargs = {"hidden_dims": cfg["hidden_dims"], "T": cfg["T"]}

        x = torch.randn(batch_size, 3, 32, 32, device=device)
        y = torch.randint(0, 10, (batch_size,), device=device)

        # Baseline
        model_base = BaselineSNN(**model_kwargs).to(device)
        r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
        del model_base

        # A+B
        model_ab = ApproachAB_SNN(**model_kwargs).to(device)
        r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
        del model_ab

        torch.cuda.empty_cache()

        speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0

        cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
        print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")
def load_results(results_dir: str) -> Dict:
    """Read ``results.json`` from ``results_dir`` and return the parsed data."""
    results_path = os.path.join(results_dir, "results.json")
    with open(results_path, "r") as handle:
        payload = json.load(handle)
    return payload


def load_config(results_dir: str) -> Dict:
    """Read ``config.json`` from ``results_dir``; return ``{}`` if it is absent."""
    config_path = os.path.join(results_dir, "config.json")
    if not os.path.exists(config_path):
        return {}
    with open(config_path, "r") as handle:
        return json.load(handle)


def plot_training_curves(results: Dict, output_path: str):
    """
    Draw one row of four panels per depth and save the figure.

    Columns, left to right: training loss (log scale), validation accuracy
    (fixed to [0, 1]), Lyapunov exponent (only epochs where one was
    recorded) and gradient norm (log scale). Both methods are overlaid in
    every panel.
    """
    depths = sorted(int(key) for key in results["vanilla"])
    n_depths = len(depths)

    fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths))
    if n_depths == 1:
        # A single row comes back 1-D; restore the (row, col) layout.
        axes = axes.reshape(1, -1)

    colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
    labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"}

    def draw(ax, xs, ys, method):
        ax.plot(xs, ys, color=colors[method], label=labels[method], linewidth=2)

    for row, depth in enumerate(depths):
        loss_ax, acc_ax, lyap_ax, grad_ax = (axes[row, col] for col in range(4))

        for method in ("vanilla", "lyapunov"):
            history = results[method][str(depth)]
            epochs = [entry["epoch"] for entry in history]

            # Training Loss
            draw(loss_ax, epochs, [entry["train_loss"] for entry in history], method)
            loss_ax.set_ylabel("Train Loss")
            loss_ax.set_title(f"Depth={depth}: Training Loss")
            loss_ax.set_yscale("log")
            loss_ax.grid(True, alpha=0.3)

            # Validation Accuracy
            draw(acc_ax, epochs, [entry["val_acc"] for entry in history], method)
            acc_ax.set_ylabel("Val Accuracy")
            acc_ax.set_title(f"Depth={depth}: Validation Accuracy")
            acc_ax.set_ylim(0, 1)
            acc_ax.grid(True, alpha=0.3)

            # Lyapunov Exponent: keep only epochs that recorded a value.
            recorded = [entry for entry in history if entry["lyapunov"] is not None]
            lyap = [entry["lyapunov"] for entry in recorded]
            lyap_epochs = [entry["epoch"] for entry in recorded]
            if lyap:
                draw(lyap_ax, lyap_epochs, lyap, method)
                lyap_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
                lyap_ax.set_ylabel("Lyapunov λ")
                lyap_ax.set_title(f"Depth={depth}: Lyapunov Exponent")
                lyap_ax.grid(True, alpha=0.3)

            # Gradient Norm
            draw(grad_ax, epochs, [entry["grad_norm"] for entry in history], method)
            grad_ax.set_ylabel("Gradient Norm")
            grad_ax.set_title(f"Depth={depth}: Gradient Norm")
            grad_ax.set_yscale("log")
            grad_ax.grid(True, alpha=0.3)

        # Legends only on the first row.
        if row == 0:
            for ax in axes[row]:
                ax.legend(loc="upper right")

    # x-labels only on the bottom row.
    for ax in axes[-1]:
        ax.set_xlabel("Epoch")

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved training curves to {output_path}")
def plot_stability_comparison(results: Dict, output_path: str):
    """
    Plot stability metrics comparison.

    Top row: firing rate and dead-neuron fraction over training, one line
    per depth and method (deeper networks are drawn more opaque). Bottom
    row: final validation accuracy per depth as grouped bars, and the
    accuracy difference (Lyapunov minus vanilla) per depth.

    Args:
        results: Mapping ``method -> depth (as str) -> list of epoch dicts``.
        output_path: File the figure is written to.
    """
    depths = sorted(int(d) for d in results["vanilla"].keys())

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
    methods = ("vanilla", "lyapunov")

    # Metrics over training
    for idx, depth in enumerate(depths):
        # Opacity grows with depth; matplotlib rejects alpha > 1, so cap it
        # (the uncapped ramp overflows from the ninth depth onward).
        alpha = min(1.0, 0.3 + 0.1 * idx)
        for method in methods:
            history = results[method][str(depth)]
            epochs = [m["epoch"] for m in history]
            axes[0, 0].plot(epochs, [m["firing_rate"] for m in history],
                            color=colors[method], alpha=alpha)
            axes[0, 1].plot(epochs, [m["dead_neurons"] for m in history],
                            color=colors[method], alpha=alpha)

    axes[0, 0].set_xlabel("Epoch")
    axes[0, 0].set_ylabel("Firing Rate")
    axes[0, 0].set_title("Firing Rate Over Training")
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].set_xlabel("Epoch")
    axes[0, 1].set_ylabel("Dead Neuron Fraction")
    axes[0, 1].set_title("Dead Neurons Over Training")
    axes[0, 1].grid(True, alpha=0.3)

    # Final metrics bar chart
    van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths]
    lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths]

    x = np.arange(len(depths))
    width = 0.35

    axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"])
    axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"])
    axes[1, 0].set_xlabel("Network Depth")
    axes[1, 0].set_ylabel("Final Validation Accuracy")
    axes[1, 0].set_title("Final Accuracy Comparison")
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(depths)
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3, axis='y')

    # Accuracy difference (Lyapunov - vanilla); green when positive.
    improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)]
    colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]

    axes[1, 1].bar(x, improvements, color=colors_bar)
    axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    axes[1, 1].set_xlabel("Network Depth")
    axes[1, 1].set_ylabel("Accuracy Improvement")
    axes[1, 1].set_title("Lyapunov Improvement over Vanilla")
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(depths)
    axes[1, 1].grid(True, alpha=0.3, axis='y')

    # The line plots carry no labels, so build the legend from proxy lines.
    custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2),
                    Line2D([0], [0], color=colors["lyapunov"], lw=2)]
    axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov'])
    axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov'])

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved stability comparison to {output_path}")


def main():
    """Parse the CLI, load results and config, and write the three figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, required=True,
                        help="Directory containing results.json")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory for plots (default: same as results_dir)")
    args = parser.parse_args()

    output_dir = args.output_dir or args.results_dir
    # A fresh --output_dir does not exist yet; savefig would fail on it.
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading results from {args.results_dir}")
    results = load_results(args.results_dir)
    config = load_config(args.results_dir)

    print(f"Config: {config}")

    # Generate plots
    plot_training_curves(results, os.path.join(output_dir, "training_curves.png"))
    plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png"))
    plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png"))

    print(f"\nAll plots saved to {output_dir}")
def _build_metrics(epoch, train_stats, test_stats, lr, time_sec):
    """Pack one epoch's train/test statistics into a ``TrainingMetrics``."""
    train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_stats
    test_loss, test_acc = test_stats
    return TrainingMetrics(
        epoch=epoch,
        train_loss=train_loss,
        train_acc=train_acc,
        test_loss=test_loss,
        test_acc=test_acc,
        lyapunov=lyap,
        grad_norm=grad_norm,
        grad_max_sv=grad_max_sv,
        grad_min_sv=grad_min_sv,
        grad_condition=grad_cond,
        lr=lr,
        time_sec=time_sec,
    )


def run_posthoc_experiment(
    dataset_name: str,
    depth_config: Tuple[int, int],
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_classes: int,
    in_channels: int,
    T: int,
    pretrain_epochs: int,
    finetune_epochs: int,
    lr: float,
    finetune_lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    reg_type: str = "extreme",
    lyap_threshold: float = 2.0,
    progress: bool = True,
    warmup_epochs: int = 10,
) -> Dict:
    """Run post-hoc fine-tuning experiment.

    Phase 1 trains a ``SpikingVGG`` without Lyapunov regularization; phase 2
    continues from those weights with a fresh optimizer and scheduler at
    ``finetune_lr`` and with regularization enabled. Fine-tuning stops early
    when the training loss becomes NaN.

    Args:
        dataset_name: Accepted for interface compatibility; not read here.
        depth_config: ``(num_stages, blocks_per_stage)``.
        train_loader, test_loader: Data for training and evaluation.
        num_classes, in_channels, T: Forwarded to ``SpikingVGG``.
        pretrain_epochs: Number of vanilla epochs.
        finetune_epochs: Number of regularized epochs (may be 0).
        lr, finetune_lr: Learning rates of the two phases.
        lambda_reg: Final regularization weight.
        lambda_target, reg_type, lyap_threshold: Forwarded to ``train_epoch``.
        device: Device the model is moved to.
        seed: Seed passed to ``torch.manual_seed``.
        progress: Forwarded to ``train_epoch`` / ``evaluate``.
        warmup_epochs: Fine-tuning epochs over which the weight ramps
            linearly up to ``lambda_reg``; values <= 0 disable the ramp.

    Returns:
        Dict with ``depth``, both histories (lists of ``TrainingMetrics``)
        and the best test accuracy of each phase.
    """
    torch.manual_seed(seed)

    num_stages, blocks_per_stage = depth_config
    total_depth = num_stages * blocks_per_stage

    print(f"\n{'='*60}")
    print(f"POST-HOC FINE-TUNING: Depth = {total_depth}")
    print(f"Pretrain: {pretrain_epochs} epochs (vanilla)")
    print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})")
    print(f"{'='*60}")

    model = SpikingVGG(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=64,
        num_stages=num_stages,
        blocks_per_stage=blocks_per_stage,
        T=T,
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {num_params:,}")

    criterion = nn.CrossEntropyLoss()

    # Phase 1: Vanilla pre-training
    print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)

    pretrain_history = []
    best_pretrain_acc = 0.0

    for epoch in range(1, pretrain_epochs + 1):
        t0 = time.time()

        train_stats = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov=False,  # No Lyapunov during pre-training
            lambda_reg=0, lambda_target=0, lyap_eps=1e-4,
            progress=progress,
        )
        test_stats = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        train_acc, test_acc = train_stats[1], test_stats[1]
        best_pretrain_acc = max(best_pretrain_acc, test_acc)

        pretrain_history.append(
            _build_metrics(epoch, train_stats, test_stats, scheduler.get_last_lr()[0], dt)
        )

        if epoch % 10 == 0 or epoch == pretrain_epochs:
            print(f"  Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}")

    print(f"  Best pretrain acc: {best_pretrain_acc:.3f}")

    # Phase 2: Lyapunov fine-tuning
    print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---")
    print(f"  reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}")

    # Reset optimizer with lower learning rate for fine-tuning
    optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)

    finetune_history = []
    best_finetune_acc = 0.0

    for epoch in range(1, finetune_epochs + 1):
        t0 = time.time()

        # Linear warm-up of the regularization weight.
        if warmup_epochs > 0 and epoch <= warmup_epochs:
            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
        else:
            current_lambda_reg = lambda_reg

        train_stats = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov=True,
            lambda_reg=lambda_reg,
            lambda_target=lambda_target,
            lyap_eps=1e-4,
            progress=progress,
            reg_type=reg_type,
            current_lambda_reg=current_lambda_reg,
            lyap_threshold=lyap_threshold,
        )
        test_stats = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        train_loss, train_acc, lyap = train_stats[0], train_stats[1], train_stats[2]
        test_acc = test_stats[1]
        best_finetune_acc = max(best_finetune_acc, test_acc)

        # Epoch numbering continues after the pre-training phase.
        finetune_history.append(
            _build_metrics(pretrain_epochs + epoch, train_stats, test_stats,
                           scheduler.get_last_lr()[0], dt)
        )

        if epoch % 10 == 0 or epoch == finetune_epochs:
            # `is not None`: an estimate of exactly 0.0 is a valid value.
            lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
            print(f"  Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")

        if np.isnan(train_loss):
            print(f"  DIVERGED at epoch {epoch}")
            break

    print(f"  Best finetune acc: {best_finetune_acc:.3f}")
    # The history is empty when finetune_epochs == 0.
    final_lyap = finetune_history[-1].lyapunov if finetune_history else None
    if final_lyap is not None:
        print(f"  Final λ: {final_lyap:.3f}")

    return {
        "depth": total_depth,
        "pretrain_history": pretrain_history,
        "finetune_history": finetune_history,
        "best_pretrain_acc": best_pretrain_acc,
        "best_finetune_acc": best_finetune_acc,
    }
args.batch_size + ) + in_channels = input_shape[0] + + # Convert depths to configs + depth_configs = [] + for d in args.depths: + if d <= 4: + depth_configs.append((d, 1)) + else: + depth_configs.append((4, d // 4)) + + # Run experiments + all_results = [] + for depth_config in depth_configs: + result = run_posthoc_experiment( + dataset_name=args.dataset, + depth_config=depth_config, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=args.T, + pretrain_epochs=args.pretrain_epochs, + finetune_epochs=args.finetune_epochs, + lr=args.lr, + finetune_lr=args.finetune_lr, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + device=device, + seed=args.seed, + reg_type=args.reg_type, + lyap_threshold=args.lyap_threshold, + progress=not args.no_progress, + ) + all_results.append(result) + + # Summary + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}") + print("-" * 80) + + for r in all_results: + pre_acc = r["best_pretrain_acc"] + fine_acc = r["best_finetune_acc"] + change = fine_acc - pre_acc + final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None + lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A" + change_str = f"{change:+.3f}" + + print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}") + + print("=" * 80) + + # Save results + os.makedirs(args.out_dir, exist_ok=True) + ts = time.strftime("%Y%m%d-%H%M%S") + output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json") + + serializable_results = [] + for r in all_results: + sr = { + "depth": r["depth"], + "best_pretrain_acc": r["best_pretrain_acc"], + "best_finetune_acc": r["best_finetune_acc"], + "pretrain_history": [asdict(m) for m in r["pretrain_history"]], + "finetune_history": [asdict(m) for m in r["finetune_history"]], + } + serializable_results.append(sr) 
@dataclass
class ExperimentResult:
    """Outcome of one grid cell: a vanilla baseline paired with a regularized run."""

    depth: int  # network depth the experiment was run with
    reg_type: str  # regularizer name forwarded to compute_lyap_reg_loss
    lambda_reg: float  # weight multiplying the regularization term
    vanilla_acc: float  # best test accuracy seen for the unregularized model
    lyapunov_acc: float  # best test accuracy seen for the regularized model
    final_lyap: Optional[float]  # mean Lyapunov estimate of the last training epoch, if any
    delta: float  # lyapunov_acc - vanilla_acc
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode (and left there) and is called as
    ``model(x, compute_lyapunov=False)``; only the first element of the
    returned triple (the logits) is used.

    Args:
        model: Module returning ``(logits, lyap_est, recordings)``.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when ``loader``
        yields no samples (previously this raised ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _, _ = model(x, compute_lyapunov=False)
            pred = logits.argmax(dim=1)
            correct += pred.eq(y).sum().item()
            total += x.size(0)

    # Accuracy over zero samples is undefined; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return correct / total
def main():
    """Run the λ_reg × reg_type grid for one depth on CIFAR-100 and save a JSON summary.

    Command-line flags select the depth, epoch count, batch size, learning
    rate, data directory and output directory. Each grid cell is delegated to
    ``run_single_experiment``; the collected results are printed as a table
    and written to ``<out_dir>/depth<depth>_results.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12])
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("=" * 70)
    print("SCALED REGULARIZATION GRID SEARCH")
    print("=" * 70)
    print(f"Depth: {args.depth}")
    print(f"Epochs: {args.epochs}")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name()}")
    print("=" * 70)

    # Grid parameters
    lambda_regs = [0.0005, 0.001, 0.002, 0.005]  # smaller values for deeper networks
    reg_types = ["mult_linear", "mult_log"]  # mult_squared too aggressive, kills learning

    print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments")
    print(f"λ_reg values: {lambda_regs}")
    print(f"reg_types: {reg_types}")

    # Load data
    print(f"\nLoading CIFAR-100...")
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)

    print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

    # Run grid search
    results = []

    for lambda_reg, reg_type in product(lambda_regs, reg_types):
        result = run_single_experiment(
            depth=args.depth,
            reg_type=reg_type,
            lambda_reg=lambda_reg,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            epochs=args.epochs,
            lr=args.lr,
        )
        results.append(result)

    # Print summary table
    print("\n" + "=" * 70)
    print(f"SUMMARY: DEPTH = {args.depth}")
    print("=" * 70)
    print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
    print("-" * 70)

    for r in results:
        # `is not None` so that an estimate of exactly 0.0 is still printed.
        lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap is not None else "N/A"
        delta_str = f"{r.delta:+.3f}"
        # General format: the grid values (5e-4 .. 5e-3) would all collapse to
        # "0.00"/"0.01" under a two-decimal fixed format.
        print(f"{r.reg_type:<16} {r.lambda_reg:>8.4g} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")

    # Find best configuration
    best = max(results, key=lambda x: x.lyapunov_acc)
    print("-" * 70)
    print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
    with open(out_file, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to: {out_file}")

    print("\n" + "=" * 70)
    print("GRID SEARCH COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    main()
class RateEncoder(nn.Module):
    """
    Bernoulli rate encoder turning a static image into a spike train.

    Every element of the input is treated as a firing probability
    (after multiplying by ``gain`` and clamping to [0, 1]) and sampled
    independently at each of the ``T`` timesteps.

    Args:
        T: Number of timesteps
        gain: Multiplier applied to the input before clamping (default 1.0)

    Input: (B, C, H, W) image tensor
    Output: (B, T, C, H, W) tensor of 0/1 samples
    """

    def __init__(self, T: int = 25, gain: float = 1.0):
        super().__init__()
        self.T = T
        self.gain = gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C, H, W) image tensor
        Returns:
            (B, T, C, H, W) tensor sampled from Bernoulli(clamp(x * gain, 0, 1))
        """
        firing_prob = torch.clamp(x * self.gain, min=0, max=1)
        # Insert the time axis and broadcast the same probabilities to all T steps.
        per_step_prob = firing_prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
        return torch.bernoulli(per_step_prob)
def _joint_norm(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Per-sample L2 norm of all ``tensors`` taken together; returns shape (B,)."""
    first = tensors[0]
    total = torch.zeros(first.shape[0], device=first.device, dtype=first.dtype)
    for t in tensors:
        total = total + (t.reshape(t.shape[0], -1) ** 2).sum(dim=1)
    # Small floor keeps sqrt (and its gradient) finite if every entry is zero.
    return torch.sqrt(total + 1e-12)


def _displace(
    base: List[torch.Tensor],
    direction: List[torch.Tensor],
    norm: torch.Tensor,
    eps: float,
) -> List[torch.Tensor]:
    """Return ``base[i] + eps * direction[i] / norm`` with ``norm`` (B,) broadcast per sample."""
    moved = []
    for b, d in zip(base, direction):
        per_sample = norm.view(-1, *([1] * (d.ndim - 1)))
        moved.append(b + eps * d / per_sample)
    return moved


class ConvLIFBlock(nn.Module):
    """
    Conv → BatchNorm → LIF block.

    Maintains spatial structure while adding spiking dynamics.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        beta: float = 0.9,
        threshold: float = 1.0,
        spike_grad=None,
    ):
        super().__init__()

        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.lif = snn.Leaky(
            beta=beta,
            threshold=threshold,
            spike_grad=spike_grad,
            init_hidden=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        mem: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, C_in, H, W) input (spikes or current)
            mem: (B, C_out, H', W') membrane potential
        Returns:
            spk: (B, C_out, H', W') output spikes
            mem: (B, C_out, H', W') updated membrane
        """
        cur = self.bn(self.conv(x))
        spk, mem = self.lif(cur, mem)
        return spk, mem


class ConvLyapunovSNN(nn.Module):
    """
    Convolutional SNN with Lyapunov exponent regularization.

    Architecture:
        Input → Encoder → [Conv-LIF-Pool] × N → FC → Output

    The Lyapunov estimate follows a perturbed copy of all membrane
    potentials. The perturbation has joint norm ``lyap_eps`` per sample
    and is rescaled to that norm after every timestep, so each step adds
    ``log(|delta| / lyap_eps)`` to the running sum.

    Args:
        in_channels: Input channels (3 for RGB)
        num_classes: Output classes
        channels: Channel sizes of the conv layers (default [64, 128, 256])
        T: Number of timesteps
        beta: LIF membrane decay
        threshold: LIF firing threshold
        encoding: 'rate', 'direct', or anything else for pre-encoded input
        encoding_gain: Gain for rate encoding
        dropout: Dropout probability before the readout
        input_size: Height/width of the (square) input image; sizes the
            readout layer. Defaults to 32, the previously hard-coded value.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        channels: Optional[List[int]] = None,
        T: int = 25,
        beta: float = 0.9,
        threshold: float = 1.0,
        encoding: str = 'rate',
        encoding_gain: float = 1.0,
        dropout: float = 0.2,
        input_size: int = 32,
    ):
        super().__init__()

        # A list literal as default would be shared between all instances.
        channels = [64, 128, 256] if channels is None else list(channels)

        self.T = T
        self.encoding_type = encoding
        self.channels = channels
        self.num_layers = len(channels)
        self.input_size = input_size

        # Input encoder
        if encoding == 'rate':
            self.encoder = RateEncoder(T=T, gain=encoding_gain)
        elif encoding == 'direct':
            self.encoder = DirectEncoder(T=T)
        else:
            self.encoder = None  # Expect pre-encoded (B, T, C, H, W) input

        # Build conv-LIF layers
        self.blocks = nn.ModuleList()
        self.pools = nn.ModuleList()

        ch_in = in_channels
        for ch_out in channels:
            self.blocks.append(
                ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold)
            )
            self.pools.append(nn.AvgPool2d(2))
            ch_in = ch_out

        # Each pooling stage halves the spatial size (32 -> 16 -> 8 -> 4 for 3 layers).
        spatial_size = input_size // (2 ** len(channels))
        fc_input = channels[-1] * spatial_size * spatial_size

        # Fully connected readout
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fc_input, num_classes)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_mem(
        self,
        batch_size: int,
        device,
        dtype,
        spatial: Optional[Tuple[int, int]] = None,
    ) -> List[torch.Tensor]:
        """Zero membrane potentials, one per layer, at that layer's pre-pool resolution."""
        if spatial is None:
            spatial = (self.input_size, self.input_size)
        H, W = spatial
        mems = []
        for ch in self.channels:
            mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
            H, W = H // 2, W // 2  # the next layer sees the pooled map
        return mems

    def _step(
        self,
        x_t: torch.Tensor,
        mems: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Advance every Conv-LIF-Pool stage by one timestep; returns (pooled output, new membranes)."""
        h = x_t
        new_mems = []
        for block, pool, mem in zip(self.blocks, self.pools, mems):
            h, mem = block(h, mem)
            new_mems.append(mem)
            h = pool(h)  # Spatial downsampling
        return h, new_mems

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
        """
        Forward pass with optional Lyapunov computation.

        Args:
            x: Input tensor
                - If encoder: (B, C, H, W) static image
                - If no encoder: (B, T, C, H, W) pre-encoded spikes
            compute_lyapunov: Whether to compute Lyapunov exponent
            lyap_eps: Perturbation magnitude (joint norm per sample)

        Returns:
            logits: (B, num_classes)
            lyap_est: Scalar Lyapunov estimate or None
            recordings: Always None (kept for interface parity)
        """
        # Encode input if needed
        if self.encoder is not None:
            x = self.encoder(x)  # (B, C, H, W) -> (B, T, C, H, W)

        B, T, C, H, W = x.shape
        device, dtype = x.device, x.dtype

        mems = self._init_mem(B, device, dtype, (H, W))
        spike_sum = None

        if compute_lyapunov:
            # Start the shadow trajectory exactly lyap_eps away (joint norm), so
            # the first log-ratio is measured against the same baseline as the rest.
            noise = [torch.randn_like(m) for m in mems]
            mems_p = _displace(mems, noise, _joint_norm(noise), lyap_eps)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        for t in range(T):
            x_t = x[:, t]  # (B, C, H, W)

            h, mems = self._step(x_t, mems)

            # Accumulate final layer spikes
            flat = h.flatten(1)
            spike_sum = flat if spike_sum is None else spike_sum + flat

            if compute_lyapunov:
                # NOTE: the shadow pass reuses the same BatchNorm layers, so in
                # training mode their running statistics are updated twice per step.
                _, mems_p = self._step(x_t, mems_p)

                diffs = [mp - m for mp, m in zip(mems_p, mems)]
                delta = _joint_norm(diffs)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Pull the shadow trajectory back to distance lyap_eps
                # along the current separation direction.
                mems_p = _displace(mems, diffs, delta, lyap_eps)

        # Readout
        out = self.dropout(spike_sum)
        logits = self.fc(out)

        lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None

        return logits, lyap_est, None


class VGGLyapunovSNN(nn.Module):
    """
    VGG-style deep Conv-SNN with Lyapunov regularization.

    Deeper architecture for more challenging benchmarks.
    Uses multiple conv layers between pooling to increase depth.

    Architecture (VGG-9 style):
        [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC

    The Lyapunov estimate is computed the same way as in ``ConvLyapunovSNN``:
    a perturbed copy of all membrane potentials (conv layers and the FC LIF)
    is propagated through the network and renormalized every timestep.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        T: int = 25,
        beta: float = 0.9,
        threshold: float = 1.0,
        encoding: str = 'rate',
        dropout: float = 0.3,
    ):
        super().__init__()

        self.T = T
        self.encoding_type = encoding

        spike_grad = surrogate.fast_sigmoid(slope=25)

        if encoding == 'rate':
            self.encoder = RateEncoder(T=T)
        elif encoding == 'direct':
            self.encoder = DirectEncoder(T=T)
        else:
            self.encoder = None

        # VGG-style blocks: (in_ch, out_ch, num_convs)
        block_configs = [
            (in_channels, 64, 2),  # 32x32 -> 16x16
            (64, 128, 2),          # 16x16 -> 8x8
            (128, 256, 2),         # 8x8 -> 4x4
        ]

        self.blocks = nn.ModuleList()
        for in_ch, out_ch, n_convs in block_configs:
            layers = []
            for i in range(n_convs):
                ch_in = in_ch if i == 0 else out_ch
                layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False))
                layers.append(nn.BatchNorm2d(out_ch))
            self.blocks.append(nn.ModuleList(layers))

        # One LIF neuron layer per conv layer
        self.lifs = nn.ModuleList([
            snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
            for _ in range(6)  # 2 convs × 3 blocks
        ])

        self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)])

        # FC layers
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, num_classes)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _state_shapes(self, B: int, H: int, W: int) -> List[Tuple[int, ...]]:
        """Membrane shapes for every conv LIF followed by the FC LIF."""
        shapes: List[Tuple[int, ...]] = []
        for block_layers in self.blocks:
            for i in range(0, len(block_layers), 2):  # Conv, BN pairs
                shapes.append((B, block_layers[i].out_channels, H, W))
            H, W = H // 2, W // 2  # pooling after each block
        shapes.append((B, self.fc1.out_features))
        return shapes

    def _step(
        self,
        x_t: torch.Tensor,
        mems: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Advance the whole network by one timestep; returns (FC spikes, new membranes)."""
        new_mems = []
        h = x_t
        lif_idx = 0
        for block_layers, pool in zip(self.blocks, self.pools):
            for i in range(0, len(block_layers), 2):  # Conv, BN pairs
                conv, bn = block_layers[i], block_layers[i + 1]
                h, mem = self.lifs[lif_idx](bn(conv(h)), mems[lif_idx])
                new_mems.append(mem)
                lif_idx += 1
            h = pool(h)

        # FC layers
        h = self.fc1(h.flatten(1))
        spk, mem = self.lif_fc(h, mems[lif_idx])
        new_mems.append(mem)
        return spk, new_mems

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
        """
        Args:
            x: (B, C, H, W) image if an encoder is configured, otherwise
               (B, T, C, H, W) pre-encoded input
            compute_lyapunov: Whether to compute the Lyapunov estimate
               (runs a second, perturbed pass per timestep)
            lyap_eps: Perturbation magnitude (joint norm per sample)

        Returns:
            logits: (B, num_classes)
            lyap_est: Scalar Lyapunov estimate or None
            recordings: Always None (kept for interface parity)
        """
        if self.encoder is not None:
            x = self.encoder(x)

        B, T, C, H, W = x.shape
        device, dtype = x.device, x.dtype

        mems = [
            torch.zeros(s, device=device, dtype=dtype)
            for s in self._state_shapes(B, H, W)
        ]
        spike_sum = torch.zeros(B, self.fc1.out_features, device=device, dtype=dtype)

        if compute_lyapunov:
            noise = [torch.randn_like(m) for m in mems]
            mems_p = _displace(mems, noise, _joint_norm(noise), lyap_eps)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        for t in range(T):
            x_t = x[:, t]

            spk, mems = self._step(x_t, mems)
            spike_sum = spike_sum + spk

            if compute_lyapunov:
                # The shadow trajectory must go through the same dynamics;
                # comparing against the previous membrane plus fresh noise
                # would measure temporal change, not sensitivity.
                _, mems_p = self._step(x_t, mems_p)

                diffs = [mp - m for mp, m in zip(mems_p, mems)]
                delta = _joint_norm(diffs)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
                mems_p = _displace(mems, diffs, delta, lyap_eps)

        out = self.dropout(spike_sum)
        logits = self.fc2(out)

        lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None

        return logits, lyap_est, None
class SurrogateStep(torch.autograd.Function):
    """
    Step function whose backward pass uses a fast-sigmoid-shaped surrogate.
    """
    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        ctx.alpha = alpha
        ctx.save_for_backward(x)
        fired = x > 0
        return fired.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        pre_activation = ctx.saved_tensors[0]
        # Surrogate slope: 1 / (1 + |alpha * x|)^2. alpha receives no gradient.
        damping = 1.0 / (1.0 + (ctx.alpha * pre_activation).abs())
        return grad_output * damping * damping, None


def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    """Apply ``SurrogateStep`` to ``x`` with surrogate sharpness ``alpha``."""
    return SurrogateStep.apply(x, alpha)


class LIFLayer(nn.Module):
    """
    One leaky integrate-and-fire layer, optionally with recurrent input.
    Per-step update:
      v_t = decay * v_{t-1} + W x_t + rec_strength * R s_{t-1} - v_th * s_{t-1}
      s_t = H( v_t - v_th ) with surrogate gradient
    The recurrent matrix R exists only when ``rec_strength`` is non-zero.
    """
    def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
        self.v_threshold = float(v_threshold)
        self.decay = float(decay)
        self.spike_alpha = float(spike_alpha)
        self.rec_strength = float(rec_strength)

        use_recurrence = self.rec_strength != 0.0
        self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False) if use_recurrence else None

        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        if use_recurrence:
            nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)

    def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
        drive = self.linear(x_t)  # (B, H)
        feedback = 0.0 if self.rec is None else self.rec_strength * self.rec(s_prev)
        # Leak, integrate, then subtract the threshold for neurons that fired last step.
        v_t = self.decay * v_prev + drive + feedback - self.v_threshold * s_prev
        s_t = SurrogateStep.apply(v_t - self.v_threshold, self.spike_alpha)
        return v_t, s_t


class SimpleSNN(nn.Module):
    """
    Small SNN for (B,T,D) input:
      - a single LIF hidden layer
      - a linear readout applied to the spikes summed over time
    """
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
        super().__init__()
        self.lif = LIFLayer(
            input_dim,
            hidden_dim,
            v_threshold=v_threshold,
            decay=decay,
            spike_alpha=spike_alpha,
            rec_strength=rec_strength,
            rec_init_scale=rec_init_scale,
        )
        self.readout = nn.Linear(hidden_dim, num_classes)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    @torch.no_grad()
    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
        shape = (batch_size, hidden_dim)
        return (
            torch.zeros(shape, device=device, dtype=dtype),
            torch.zeros(shape, device=device, dtype=dtype),
        )

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-3,
        lyap_safe_eps: float = 1e-8,
        lyap_measure: str = "v",  # "v" or "s"
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (B, T, D)
        Returns:
          logits: (B, C)
          lyap_est: scalar tensor if compute_lyapunov else None

        The estimate is the batch mean of the time-averaged log ratio between
        consecutive separations of a nominal and a perturbed trajectory,
        measured on spikes when lyap_measure == "s" and on membrane
        potentials otherwise.
        """
        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
        batch, steps, _ = x.shape
        hidden = self.readout.in_features

        v, s = self._init_states(batch, hidden, x.device, x.dtype)
        spike_sum = torch.zeros(batch, hidden, device=x.device, dtype=x.dtype)

        def separation(perturbed: torch.Tensor, nominal: torch.Tensor) -> torch.Tensor:
            gap = (perturbed - nominal).reshape(batch, -1)
            return torch.norm(gap, dim=1) + lyap_safe_eps

        if compute_lyapunov:
            v_p = v + lyap_eps * torch.randn_like(v)
            s_p = s.clone()
            previous_gap = separation(v_p, v)
            log_growth = []

        for step in range(steps):
            frame = x[:, step, :]
            v, s = self.lif(frame, v, s)
            spike_sum = spike_sum + s

            if not compute_lyapunov:
                continue

            # Shadow trajectory driven by the same input frame.
            v_p, s_p = self.lif(frame, v_p, s_p)
            current_gap = separation(s_p, s) if lyap_measure == "s" else separation(v_p, v)
            log_growth.append(torch.log(current_gap / previous_gap + lyap_safe_eps))
            previous_gap = current_gap

        logits = self.readout(spike_sum)  # (B, C)

        if not compute_lyapunov:
            return logits, None

        per_sample = torch.stack(log_growth, dim=0).mean(dim=0)  # (B,)
        return logits, per_sample.mean()
class LyapunovSNN(nn.Module):
    """
    Multi-layer SNN using snnTorch with Lyapunov exponent computation.

    Architecture:
        Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits

    Args:
        input_dim: Input feature dimension
        hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers)
        num_classes: Number of output classes
        beta: Membrane potential decay factor (0 < beta < 1)
        threshold: Firing threshold
        spike_grad: Surrogate gradient function (default: fast_sigmoid)
        dropout: Dropout probability between layers (0 = no dropout)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        num_classes: int,
        beta: float = 0.9,
        threshold: float = 1.0,
        spike_grad: Optional[Any] = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)
        self.beta = beta
        self.threshold = threshold

        # Build layers
        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()
        self.dropouts = nn.ModuleList() if dropout > 0 else None

        dims = [input_dim] + hidden_dims
        for i in range(self.num_layers):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.lifs.append(
                snn.Leaky(
                    beta=beta,
                    threshold=threshold,
                    spike_grad=spike_grad,
                    init_hidden=False,
                    reset_mechanism="subtract",
                )
            )
            if dropout > 0:
                self.dropouts.append(nn.Dropout(p=dropout))

        # Readout layer
        self.readout = nn.Linear(hidden_dims[-1], num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for lin in self.linears:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
        """Initialize membrane potentials for all layers."""
        return [
            torch.zeros(batch_size, dim, device=device, dtype=dtype)
            for dim in self.hidden_dims
        ]

    @staticmethod
    def _joint_norm(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Per-sample L2 norm of all ``tensors`` taken together; returns shape (B,)."""
        first = tensors[0]
        total = torch.zeros(first.shape[0], device=first.device, dtype=first.dtype)
        for t in tensors:
            total = total + (t.reshape(t.shape[0], -1) ** 2).sum(dim=1)
        # Small floor keeps sqrt (and its gradient) finite if every entry is zero.
        return torch.sqrt(total + 1e-12)

    def _sample_dropout_masks(self, batch_size: int, device, dtype) -> Optional[List[torch.Tensor]]:
        """Draw one dropout mask per layer, or None when dropout is inactive.

        Passing ones through each ``nn.Dropout`` yields that layer's scaled
        keep-mask, which can then be applied to several trajectories.
        """
        if self.dropouts is None or not self.training:
            return None
        return [
            drop(torch.ones(batch_size, dim, device=device, dtype=dtype))
            for drop, dim in zip(self.dropouts, self.hidden_dims)
        ]

    def _step(
        self,
        x_t: torch.Tensor,
        mems: List[torch.Tensor],
        training: bool = True,
        drop_masks: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Single timestep forward pass.

        Args:
            x_t: Input at this timestep (B, D)
            mems: Membrane potentials, one per layer
            training: Apply the dropout modules (ignored when drop_masks is given)
            drop_masks: Pre-sampled per-layer masks multiplied onto the spikes

        Returns:
            spike_out: Output spikes from last layer (B, H_last)
            new_mems: Updated membrane potentials
            all_mems: Membrane potentials from all layers (for Lyapunov)
        """
        new_mems = []
        all_mems = []

        h = x_t
        for i in range(self.num_layers):
            h = self.linears[i](h)
            spk, mem = self.lifs[i](h, mems[i])
            new_mems.append(mem)
            all_mems.append(mem)
            h = spk
            if drop_masks is not None:
                h = h * drop_masks[i]
            elif self.dropouts is not None and training:
                h = self.dropouts[i](h)

        return h, new_mems, all_mems

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
        lyap_layers: Optional[List[int]] = None,
        record_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
        """
        Forward pass with optional Lyapunov exponent computation.

        Args:
            x: Input tensor (B, T, D)
            compute_lyapunov: Whether to compute Lyapunov exponent
            lyap_eps: Perturbation magnitude, measured as the joint norm over
                the selected layers
            lyap_layers: Which layers to measure (default: all).
                e.g., [0] for first layer only, [-1] for last layer
            record_states: Whether to record spikes and membrane potentials

        Returns:
            logits: Classification logits (B, num_classes)
            lyap_est: Estimated Lyapunov exponent (scalar) or None
            recordings: Dict with 'spikes' (B,T,H) and 'membrane' (B,T,H) or None
        """
        B, T, D = x.shape
        device, dtype = x.device, x.dtype

        # Initialize states
        mems = self._init_states(B, device, dtype)
        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        # Recording setup
        if record_states:
            spike_rec = []
            mem_rec = []

        # Lyapunov setup
        if compute_lyapunov:
            if lyap_layers is None:
                lyap_layers = list(range(self.num_layers))

            # Perturb every layer, scaled so the measured layers start exactly
            # lyap_eps apart; otherwise the first log-ratio is biased by the
            # norm of the raw noise.
            noise = [torch.randn_like(m) for m in mems]
            scale = lyap_eps / self._joint_norm([noise[i] for i in lyap_layers])
            mems_p = [m + scale.unsqueeze(-1) * n for m, n in zip(mems, noise)]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        # Time loop
        for t in range(T):
            x_t = x[:, t, :]

            # Both trajectories must see the same dropout pattern, or their
            # separation reflects the dropout noise instead of the dynamics.
            masks = self._sample_dropout_masks(B, device, dtype)

            # Nominal trajectory
            spk, mems, all_mems = self._step(x_t, mems, training=self.training, drop_masks=masks)
            spike_sum = spike_sum + spk

            if record_states:
                spike_rec.append(spk.detach())
                mem_rec.append(all_mems[-1].detach())  # Last layer membrane

            if compute_lyapunov:
                # Perturbed trajectory
                _, mems_p, _ = self._step(x_t, mems_p, training=self.training, drop_masks=masks)

                # Separation over the selected layers, as one joint vector
                diffs = {i: mems_p[i] - mems[i] for i in lyap_layers}
                delta = self._joint_norm(list(diffs.values()))

                # Accumulate log-divergence
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Renormalization: pull the selected layers back to joint
                # distance lyap_eps along the current separation direction.
                for i, diff in diffs.items():
                    mems_p[i] = mems[i] + lyap_eps * diff / delta.unsqueeze(-1)

        logits = self.readout(spike_sum)

        if compute_lyapunov:
            # Average over time and batch
            lyap_est = (lyap_accum / T).mean()
        else:
            lyap_est = None

        if record_states:
            recordings = {
                "spikes": torch.stack(spike_rec, dim=1),  # (B, T, H)
                "membrane": torch.stack(mem_rec, dim=1),  # (B, T, H)
            }
        else:
            recordings = None

        return logits, lyap_est, recordings
self.linears: + nn.init.xavier_uniform_(lin.weight) + nn.init.zeros_(lin.bias) + nn.init.xavier_uniform_(self.readout.weight) + nn.init.zeros_(self.readout.bias) + + def _init_states(self, batch_size: int, device, dtype): + """Initialize synaptic currents and membrane potentials.""" + syns = [] + mems = [] + for dim in self.hidden_dims: + syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) + mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) + return syns, mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass with optional Lyapunov computation.""" + B, T, D = x.shape + device, dtype = x.device, x.dtype + + syns, mems = self._init_states(B, device, dtype) + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + if compute_lyapunov: + # Perturb both synaptic currents and membrane potentials + syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns] + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + for t in range(T): + x_t = x[:, t, :] + + # Nominal trajectory + h = x_t + new_syns, new_mems = [], [] + for i in range(self.num_layers): + h = self.linears[i](h) + spk, syn, mem = self.neurons[i](h, syns[i], mems[i]) + new_syns.append(syn) + new_mems.append(mem) + h = spk + syns, mems = new_syns, new_mems + spike_sum = spike_sum + h + + if compute_lyapunov: + # Perturbed trajectory + h_p = x_t + new_syns_p, new_mems_p = [], [] + for i in range(self.num_layers): + h_p = self.linears[i](h_p) + spk_p, syn_p, mem_p = self.neurons[i](h_p, syns_p[i], mems_p[i]) + new_syns_p.append(syn_p) + new_mems_p.append(mem_p) + h_p = spk_p + + # Compute divergence (on membrane potentials) + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(self.num_layers): + diff_m = new_mems_p[i] - new_mems[i] + diff_s = 
new_syns_p[i] - new_syns[i] + delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # Renormalize perturbation + total_dim = sum(2 * d for d in self.hidden_dims) # syn + mem + scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12) + + syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i]) + for i in range(self.num_layers)] + mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i]) + for i in range(self.num_layers)] + + logits = self.readout(spike_sum) + + if compute_lyapunov: + lyap_est = (lyap_accum / T).mean() + else: + lyap_est = None + + return logits, lyap_est + + +def create_snn( + model_type: str, + input_dim: int, + hidden_dims: List[int], + num_classes: int, + **kwargs, +) -> nn.Module: + """ + Factory function to create SNN models. + + Args: + model_type: "feedforward" or "recurrent" + input_dim: Input feature dimension + hidden_dims: List of hidden layer sizes + num_classes: Number of output classes + **kwargs: Additional arguments passed to model constructor + + Returns: + SNN model instance + """ + if model_type == "feedforward": + return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs) + elif model_type == "recurrent": + return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs) + else: + raise ValueError(f"Unknown model_type: {model_type}. 
Use 'feedforward' or 'recurrent'") diff --git a/files/tests/conftest.py b/files/tests/conftest.py new file mode 100644 index 0000000..a63c259 --- /dev/null +++ b/files/tests/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + +# Ensure project root is importable so `import files...` works when tests are under files/tests +_HERE = os.path.dirname(__file__) +_FILES_DIR = os.path.dirname(_HERE) +_ROOT = os.path.dirname(_FILES_DIR) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + + diff --git a/files/tests/test_data_io.py b/files/tests/test_data_io.py new file mode 100644 index 0000000..1f2ccd8 --- /dev/null +++ b/files/tests/test_data_io.py @@ -0,0 +1,11 @@ +import torch +from files.data_io.dataset_loader import get_dataloader + + +def test_dataloader_shape(): + """Smoke test: verify dataloader output shape.""" + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + x, y = next(iter(train_loader)) + assert isinstance(x, torch.Tensor) + assert x.ndim == 3 + assert y.ndim == 1 diff --git a/files/tests/test_shd_loader_properties.py b/files/tests/test_shd_loader_properties.py new file mode 100644 index 0000000..237740c --- /dev/null +++ b/files/tests/test_shd_loader_properties.py @@ -0,0 +1,40 @@ +import torch +from files.data_io.dataset_loader import get_dataloader, SHDDataset + + +def test_shd_dataset_global_T_D_consistency(): + # Ensure SHDDataset computes global T and adaptive D once and applies consistently + ds = SHDDataset( + data_dir="/u/yurenh2/ml-projects/snn-training/files/data", + split="train", + dt_ms=1.0, + default_D=700, + ) + assert ds.T >= 1 + assert ds.D >= 700 # by construction at least default_D + x0, y0 = ds[0] + x1, y1 = ds[1] + assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor) + assert x0.shape == (ds.T, ds.D) + assert x1.shape == (ds.T, ds.D) + assert isinstance(y0, int) and isinstance(y1, int) + # values should be finite + assert torch.isfinite(x0).all() + assert torch.isfinite(x1).all() + + +def 
test_dataloader_batch_shapes_and_finiteness(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + assert isinstance(xb, torch.Tensor) + assert xb.ndim == 3 # (B, T, D) + assert yb.ndim == 1 + B, T, D = xb.shape + assert B >= 1 and T >= 1 and D >= 1 + # finiteness and no NaNs + assert torch.isfinite(xb).all() + # After normalization (enabled in config), per-sample per-channel mean over time should be ~0 + mean_t = xb.mean(dim=1) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3) + + diff --git a/files/tests/test_train_smoke.py b/files/tests/test_train_smoke.py new file mode 100644 index 0000000..e04e40b --- /dev/null +++ b/files/tests/test_train_smoke.py @@ -0,0 +1,33 @@ +import torch +from files.models.snn import SimpleSNN +from files.data_io.dataset_loader import get_dataloader +import torch.nn as nn +import torch.optim as optim + + +def _one_step(model, xb, yb, lyapunov=False): + model.train() + opt = optim.Adam(model.parameters(), lr=1e-3) + ce = nn.CrossEntropyLoss() + opt.zero_grad(set_to_none=True) + logits, lyap = model(xb, compute_lyapunov=lyapunov) + loss = ce(logits, yb) + if lyapunov and lyap is not None: + loss = loss + 0.1 * (lyap - 0.0) ** 2 + loss.backward() + opt.step() + assert torch.isfinite(loss).all() + + +def test_train_step_baseline_and_lyapunov(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + B, T, D = xb.shape + C = 20 + model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C) + # baseline + _one_step(model, xb, yb, lyapunov=False) + # lyapunov-regularized + _one_step(model, xb, yb, lyapunov=True) + + diff --git a/files/tests/test_transforms_normalize.py b/files/tests/test_transforms_normalize.py new file mode 100644 index 0000000..cc5a3d4 --- /dev/null +++ b/files/tests/test_transforms_normalize.py @@ -0,0 +1,53 @@ +import torch +from files.data_io.transforms.normalization import Normalize + + +def 
test_normalize_zscore_2d_zero_mean_unitish_var(): + # (T, D) toy example + T, D = 5, 3 + x = torch.arange(T * D, dtype=torch.float32).reshape(T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + # mean over time per channel should be ~0 + mean_t = xz.mean(dim=0) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + # std over time per channel should be ~1 + std_t = xz.std(dim=0, unbiased=False) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5) + + +def test_normalize_zscore_3d_per_sample(): + # (B, T, D) + B, T, D = 2, 6, 4 + x = torch.randn(B, T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + mean_t = xz.mean(dim=1) # (B, D) + std_t = xz.std(dim=1, unbiased=False) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4) + + +def test_normalize_minmax_range_01_2d(): + T, D = 7, 2 + x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D) + norm = Normalize(mode="minmax") + xm = norm(x) + assert xm.min().item() >= -1e-6 + assert xm.max().item() <= 1 + 1e-6 + # Check endpoints map to 0 and 1 + assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6) + assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6) + + +def test_normalize_rejects_bad_ndim(): + x = torch.ones(1, 2, 3, 4) + norm = Normalize() + try: + _ = norm(x) + except ValueError: + pass + else: + raise AssertionError("Normalize should raise on x.ndim not in {2,3}") + + diff --git a/files/train_mvp.py b/files/train_mvp.py new file mode 100644 index 0000000..b89ddc6 --- /dev/null +++ b/files/train_mvp.py @@ -0,0 +1,202 @@ +import os +import sys +import json +import csv +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(_HERE) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) +import argparse +import time +from typing import Optional + +import torch +import torch.nn as nn +import torch.optim as optim + +from files.data_io.dataset_loader import 
get_dataloader, SHDDataset +from files.models.snn import SimpleSNN +from tqdm.auto import tqdm + + +def _prepare_run_dir(base_dir: str): + ts = time.strftime("%Y%m%d-%H%M%S") + run_dir = os.path.join(base_dir, ts) + os.makedirs(run_dir, exist_ok=True) + return run_dir + + +def _append_metrics(csv_path: str, row: dict): + write_header = not os.path.exists(csv_path) + with open(csv_path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=row.keys()) + if write_header: + writer.writeheader() + writer.writerow(row) + + +def parse_args(): + p = argparse.ArgumentParser(description="MVP training: baseline vs Lyapunov-regularized") + p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="YAML config for dataloader") + p.add_argument("--epochs", type=int, default=2) + p.add_argument("--hidden", type=int, default=256) + p.add_argument("--classes", type=int, default=20, help="Number of classes") + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regulation") + p.add_argument("--lambda_reg", type=float, default=0.1, help="Weight for Lyapunov penalty") + p.add_argument("--lambda_target", type=float, default=0.0, help="Target average log growth (≈0 for neutral)") + p.add_argument("--no-progress", action="store_true", help="Disable tqdm progress bar") + p.add_argument("--out_dir", type=str, default="runs/mvp", help="Directory to save metrics/checkpoints") + p.add_argument("--log_batches", action="store_true", help="Also log per-batch metrics to CSV") + # Model dynamics and recurrence controls + p.add_argument("--spike_alpha", type=float, default=5.0, help="Surrogate spike sharpness") + p.add_argument("--decay", type=float, default=0.95, help="Membrane decay") + p.add_argument("--v_threshold", type=float, default=1.0, help="Firing threshold") + p.add_argument("--rec_strength", 
type=float, default=0.0, help="Recurrent coupling strength on spikes") + p.add_argument("--rec_init_scale", type=float, default=1.0, help="Gain for recurrent weight init") + # Lyapunov measurement controls + p.add_argument("--lyap_measure", type=str, default="v", choices=["v", "s"], help="Measure divergence on 'v' or 's'") + p.add_argument("--lyap_eps", type=float, default=1e-3, help="Initial perturbation magnitude") + return p.parse_args() + + +def train_one_epoch( + model: SimpleSNN, + loader, + optimizer, + device, + ce_loss: nn.Module, + lyapunov: bool, + lambda_reg: float, + lambda_target: float, + progress: bool, + run_dir: str | None = None, + epoch_idx: int | None = None, + log_batches: bool = False, + lyap_measure: str = "v", + lyap_eps: float = 1e-3, +): + model.train() + total = 0 + correct = 0 + running_loss = 0.0 + lyap_vals = [] + + iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader + for bidx, (x, y) in enumerate(iterator): + x = x.to(device) # (B, T, D) + y = y.to(device) + optimizer.zero_grad(set_to_none=True) + logits, lyap_est = model( + x, + compute_lyapunov=lyapunov, + lyap_eps=lyap_eps, + lyap_measure=lyap_measure, + ) + ce = ce_loss(logits, y) + if lyapunov and lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.detach().item()) + else: + loss = ce + loss.backward() + optimizer.step() + + running_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + batch_correct = (preds == y).sum().item() + correct += batch_correct + total += x.size(0) + if log_batches and run_dir is not None and epoch_idx is not None: + _append_metrics( + os.path.join(run_dir, "metrics.csv"), + { + "step": "batch", + "epoch": int(epoch_idx), + "batch": int(bidx), + "loss": float(loss.item()), + "acc": float(batch_correct / max(x.size(0), 1)), + "lyap": float(lyap_est.item()) if (lyapunov and lyap_est is not None) else float("nan"), + "time_sec": 
float(0.0), + }, + ) + if progress: + avg_loss = running_loss / max(total, 1) + avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None + postfix = {"loss": f"{avg_loss:.4f}"} + if avg_lyap is not None: + postfix["lyap"] = f"{avg_lyap:.4f}" + iterator.set_postfix(postfix) + + avg_loss = running_loss / max(total, 1) + acc = correct / max(total, 1) + avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None + return avg_loss, acc, avg_lyap + + +def main(): + args = parse_args() + device = torch.device(args.device) + + # Prepare output directory and save run config + run_dir = _prepare_run_dir(args.out_dir) + with open(os.path.join(run_dir, "args.json"), "w") as f: + json.dump(vars(args), f, indent=2) + + train_loader, val_loader = get_dataloader(args.cfg) + + # Infer input dim and classes from a sample and args + xb, yb = next(iter(train_loader)) + _, T, D = xb.shape + C = args.classes + + model = SimpleSNN( + input_dim=D, + hidden_dim=args.hidden, + num_classes=C, + v_threshold=args.v_threshold, + decay=args.decay, + spike_alpha=args.spike_alpha, + rec_strength=args.rec_strength, + rec_init_scale=args.rec_init_scale, + ).to(device) + optimizer = optim.Adam(model.parameters(), lr=args.lr) + ce_loss = nn.CrossEntropyLoss() + + print(f"Starting training on {device} | lyapunov={args.lyapunov} lambda={args.lambda_reg} target={args.lambda_target}") + print(f"Saving run to: {run_dir}") + for epoch in range(1, args.epochs + 1): + t0 = time.time() + tr_loss, tr_acc, tr_lyap = train_one_epoch( + model, train_loader, optimizer, device, ce_loss, + lyapunov=args.lyapunov, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target, + progress=(not args.no_progress), + run_dir=run_dir, + epoch_idx=epoch, + log_batches=args.log_batches, + lyap_measure=args.lyap_measure, + lyap_eps=args.lyap_eps, + ) + dt = time.time() - t0 + lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else "" + print(f"[Epoch {epoch}] loss={tr_loss:.4f} acc={tr_acc:.3f}{lyap_str} 
({dt:.1f}s)") + _append_metrics( + os.path.join(run_dir, "metrics.csv"), + { + "step": "epoch", + "epoch": int(epoch), + "batch": int(-1), + "loss": float(tr_loss), + "acc": float(tr_acc), + "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"), + "time_sec": float(dt), + }, + ) + + +if __name__ == "__main__": + main() + + diff --git a/files/train_snntorch.py b/files/train_snntorch.py new file mode 100644 index 0000000..dbdce37 --- /dev/null +++ b/files/train_snntorch.py @@ -0,0 +1,345 @@ +""" +Training script for snnTorch-based deep SNNs with Lyapunov regularization. + +Usage: + # Baseline (no Lyapunov) + python files/train_snntorch.py --hidden 256 128 --epochs 10 + + # With Lyapunov regularization + python files/train_snntorch.py --hidden 256 128 --epochs 10 --lyapunov --lambda_reg 0.1 + + # Recurrent model + python files/train_snntorch.py --model recurrent --hidden 256 --epochs 10 --lyapunov +""" + +import os +import sys +import json +import csv + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(_HERE) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import time +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm.auto import tqdm + +from files.data_io.dataset_loader import get_dataloader +from files.models.snn_snntorch import create_snn + + +def _prepare_run_dir(base_dir: str) -> str: + ts = time.strftime("%Y%m%d-%H%M%S") + run_dir = os.path.join(base_dir, ts) + os.makedirs(run_dir, exist_ok=True) + return run_dir + + +def _append_metrics(csv_path: str, row: dict): + write_header = not os.path.exists(csv_path) + with open(csv_path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=row.keys()) + if write_header: + writer.writeheader() + writer.writerow(row) + + +def parse_args(): + p = argparse.ArgumentParser( + description="Train deep SNN with snnTorch and optional Lyapunov regularization" + ) + + # Model architecture + p.add_argument( + 
"--model", type=str, default="feedforward", choices=["feedforward", "recurrent"], + help="Model type: 'feedforward' (LIF) or 'recurrent' (RSynaptic)" + ) + p.add_argument( + "--hidden", type=int, nargs="+", default=[256], + help="Hidden layer sizes (e.g., --hidden 256 128 for 2 layers)" + ) + p.add_argument("--classes", type=int, default=20, help="Number of output classes") + p.add_argument("--beta", type=float, default=0.9, help="Membrane decay (beta)") + p.add_argument("--threshold", type=float, default=1.0, help="Firing threshold") + p.add_argument("--dropout", type=float, default=0.0, help="Dropout between layers") + p.add_argument( + "--surrogate_slope", type=float, default=25.0, + help="Slope for fast_sigmoid surrogate gradient" + ) + + # Recurrent-specific (only for --model recurrent) + p.add_argument("--alpha", type=float, default=0.9, help="Synaptic current decay (recurrent only)") + + # Training + p.add_argument("--epochs", type=int, default=10) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--weight_decay", type=float, default=0.0, help="L2 regularization") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config") + + # Lyapunov regularization + p.add_argument("--lyapunov", action="store_true", help="Enable Lyapunov regularization") + p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov penalty weight") + p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent") + p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation magnitude") + p.add_argument( + "--lyap_layers", type=int, nargs="*", default=None, + help="Which layers to measure (default: all). 
E.g., --lyap_layers 0 1" + ) + + # Output + p.add_argument("--out_dir", type=str, default="runs/snntorch", help="Output directory") + p.add_argument("--log_batches", action="store_true", help="Log per-batch metrics") + p.add_argument("--no-progress", action="store_true", help="Disable progress bar") + p.add_argument("--save_model", action="store_true", help="Save model checkpoint") + + return p.parse_args() + + +def train_one_epoch( + model: nn.Module, + loader, + optimizer: optim.Optimizer, + device: torch.device, + ce_loss: nn.Module, + lyapunov: bool, + lambda_reg: float, + lambda_target: float, + lyap_eps: float, + lyap_layers: Optional[List[int]], + progress: bool, + run_dir: Optional[str] = None, + epoch_idx: Optional[int] = None, + log_batches: bool = False, +): + model.train() + total = 0 + correct = 0 + running_loss = 0.0 + lyap_vals = [] + + iterator = tqdm(loader, desc="train", leave=False, dynamic_ncols=True) if progress else loader + + for bidx, (x, y) in enumerate(iterator): + x = x.to(device) # (B, T, D) + y = y.to(device) + + optimizer.zero_grad(set_to_none=True) + + logits, lyap_est = model( + x, + compute_lyapunov=lyapunov, + lyap_eps=lyap_eps, + lyap_layers=lyap_layers, + ) + + ce = ce_loss(logits, y) + + if lyapunov and lyap_est is not None: + # Penalize deviation from target Lyapunov exponent + reg = (lyap_est - lambda_target) ** 2 + loss = ce + lambda_reg * reg + lyap_vals.append(lyap_est.detach().item()) + else: + loss = ce + + loss.backward() + optimizer.step() + + running_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + batch_correct = (preds == y).sum().item() + correct += batch_correct + total += x.size(0) + + if log_batches and run_dir is not None and epoch_idx is not None: + _append_metrics( + os.path.join(run_dir, "metrics.csv"), + { + "step": "batch", + "epoch": int(epoch_idx), + "batch": int(bidx), + "loss": float(loss.item()), + "acc": float(batch_correct / max(x.size(0), 1)), + "lyap": float(lyap_est.item()) if 
(lyapunov and lyap_est is not None) else float("nan"), + "time_sec": 0.0, + }, + ) + + if progress: + avg_loss = running_loss / max(total, 1) + avg_lyap = (sum(lyap_vals) / len(lyap_vals)) if lyap_vals else None + postfix = {"loss": f"{avg_loss:.4f}", "acc": f"{correct / total:.3f}"} + if avg_lyap is not None: + postfix["lyap"] = f"{avg_lyap:.4f}" + iterator.set_postfix(postfix) + + avg_loss = running_loss / max(total, 1) + acc = correct / max(total, 1) + avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None + return avg_loss, acc, avg_lyap + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader, + device: torch.device, + ce_loss: nn.Module, + progress: bool, +): + model.eval() + total = 0 + correct = 0 + running_loss = 0.0 + + iterator = tqdm(loader, desc="eval", leave=False, dynamic_ncols=True) if progress else loader + + for x, y in iterator: + x = x.to(device) + y = y.to(device) + + logits, _ = model(x, compute_lyapunov=False) + loss = ce_loss(logits, y) + + running_loss += loss.item() * x.size(0) + preds = logits.argmax(dim=1) + correct += (preds == y).sum().item() + total += x.size(0) + + avg_loss = running_loss / max(total, 1) + acc = correct / max(total, 1) + return avg_loss, acc + + +def main(): + args = parse_args() + device = torch.device(args.device) + + # Prepare output directory + run_dir = _prepare_run_dir(args.out_dir) + with open(os.path.join(run_dir, "args.json"), "w") as f: + json.dump(vars(args), f, indent=2) + + # Load data + train_loader, val_loader = get_dataloader(args.cfg) + + # Infer dimensions from data + xb, yb = next(iter(train_loader)) + _, T, D = xb.shape + C = args.classes + + print(f"Data: T={T}, D={D}, classes={C}") + print(f"Model: {args.model}, hidden={args.hidden}") + + # Create model + from snntorch import surrogate + spike_grad = surrogate.fast_sigmoid(slope=args.surrogate_slope) + + if args.model == "feedforward": + model = create_snn( + model_type="feedforward", + input_dim=D, + hidden_dims=args.hidden, 
+ num_classes=C, + beta=args.beta, + threshold=args.threshold, + spike_grad=spike_grad, + dropout=args.dropout, + ) + else: # recurrent + model = create_snn( + model_type="recurrent", + input_dim=D, + hidden_dims=args.hidden, + num_classes=C, + alpha=args.alpha, + beta=args.beta, + threshold=args.threshold, + spike_grad=spike_grad, + ) + + model = model.to(device) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Parameters: {num_params:,}") + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + ce_loss = nn.CrossEntropyLoss() + + print(f"\nTraining on {device} | lyapunov={args.lyapunov} λ_reg={args.lambda_reg} λ_target={args.lambda_target}") + print(f"Output: {run_dir}\n") + + best_val_acc = 0.0 + + for epoch in range(1, args.epochs + 1): + t0 = time.time() + + tr_loss, tr_acc, tr_lyap = train_one_epoch( + model=model, + loader=train_loader, + optimizer=optimizer, + device=device, + ce_loss=ce_loss, + lyapunov=args.lyapunov, + lambda_reg=args.lambda_reg, + lambda_target=args.lambda_target, + lyap_eps=args.lyap_eps, + lyap_layers=args.lyap_layers, + progress=(not args.no_progress), + run_dir=run_dir, + epoch_idx=epoch, + log_batches=args.log_batches, + ) + + val_loss, val_acc = evaluate( + model=model, + loader=val_loader, + device=device, + ce_loss=ce_loss, + progress=(not args.no_progress), + ) + + dt = time.time() - t0 + lyap_str = f" lyap={tr_lyap:.4f}" if tr_lyap is not None else "" + + print( + f"[Epoch {epoch:3d}] " + f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f}{lyap_str} | " + f"val_loss={val_loss:.4f} val_acc={val_acc:.3f} ({dt:.1f}s)" + ) + + _append_metrics( + os.path.join(run_dir, "metrics.csv"), + { + "step": "epoch", + "epoch": int(epoch), + "batch": -1, + "loss": float(tr_loss), + "acc": float(tr_acc), + "val_loss": float(val_loss), + "val_acc": float(val_acc), + "lyap": float(tr_lyap) if tr_lyap is not None else float("nan"), + "time_sec": float(dt), + }, + ) + + # Save 
best model + if args.save_model and val_acc > best_val_acc: + best_val_acc = val_acc + torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt")) + + print(f"\nTraining complete. Best val_acc: {best_val_acc:.3f}") + if args.save_model: + torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt")) + print(f"Model saved to {run_dir}") + + +if __name__ == "__main__": + main() |
