""" Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths. Hypothesis: - Shallow networks (1-2 layers): Both methods train successfully - Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds Usage: # Quick test (synthetic data) python files/experiments/depth_comparison.py --synthetic --epochs 20 # Full experiment with SHD data python files/experiments/depth_comparison.py --epochs 50 # Specific depths to test python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30 """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional, Tuple _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from tqdm.auto import tqdm from files.models.snn_snntorch import LyapunovSNN from files.analysis.stability_monitor import StabilityMonitor @dataclass class ExperimentConfig: """Configuration for a single experiment run.""" depth: int hidden_dim: int use_lyapunov: bool lambda_reg: float lambda_target: float lyap_eps: float epochs: int lr: float batch_size: int beta: float threshold: float seed: int @dataclass class EpochMetrics: """Metrics collected per epoch.""" epoch: int train_loss: float train_acc: float val_loss: float val_acc: float lyapunov: Optional[float] grad_norm: float firing_rate: float dead_neurons: float time_sec: float def create_synthetic_data( n_train: int = 2000, n_val: int = 500, T: int = 50, D: int = 100, n_classes: int = 10, seed: int = 42, ) -> Tuple[DataLoader, DataLoader]: """Create synthetic spike data for testing.""" torch.manual_seed(seed) np.random.seed(seed) def generate_data(n_samples): # Generate class-conditional spike patterns x = torch.zeros(n_samples, T, D) y = torch.randint(0, n_classes, (n_samples,)) for i in range(n_samples): label = y[i].item() # Each class has different firing rate pattern base_rate = 0.05 + 0.02 * label # Class-specific channels fire more class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) for t in range(T): # Background activity x[i, t] = (torch.rand(D) < base_rate).float() # Enhanced activity for class-specific channels for c in class_channels: if torch.rand(1) < base_rate * 3: x[i, t, c] = 1.0 return x, y x_train, y_train = generate_data(n_train) x_val, y_val = generate_data(n_val) train_loader = DataLoader( TensorDataset(x_train, y_train), batch_size=64, shuffle=True, ) val_loader = DataLoader( TensorDataset(x_val, y_val), batch_size=64, shuffle=False, ) return train_loader, val_loader, T, D, n_classes def create_model( input_dim: int, num_classes: int, depth: int, hidden_dim: int = 128, beta: float = 0.9, threshold: float = 1.0, ) -> LyapunovSNN: """Create SNN with specified depth.""" # Create hidden dims list based on depth # Gradually decrease size for deeper networks to keep param count reasonable hidden_dims = [] current_dim = hidden_dim for i in range(depth): hidden_dims.append(current_dim) # Optionally decrease dim in deeper layers # current_dim = max(64, current_dim // 2) return LyapunovSNN( input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes, beta=beta, threshold=threshold, ) def train_epoch( model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, ce_loss: nn.Module, device: torch.device, use_lyapunov: bool, lambda_reg: 
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    ce_loss: nn.Module,
    device: torch.device,
    use_lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    monitor: StabilityMonitor,
) -> Tuple[float, float, Optional[float], float, float, float]:
    """Train for one epoch; return (loss, acc, lyapunov, grad_norm, firing_rate, dead_frac)."""
    # Note: `monitor` is accepted but not used inside this function.
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    lyap_vals = []
    grad_norms = []
    firing_rates = []
    dead_fracs = []

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits, lyap_est, recordings = model(
            x,
            compute_lyapunov=use_lyapunov,
            lyap_eps=lyap_eps,
            record_states=True,
        )

        ce = ce_loss(logits, y)
        if use_lyapunov and lyap_est is not None:
            # Penalize deviation of the estimated Lyapunov exponent from the target
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.item())
        else:
            loss = ce

        # Abort the epoch immediately if training has diverged
        if torch.isnan(loss):
            return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0

        loss.backward()
        # Gradient clipping, applied to both methods so the comparison stays fair
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()

        # Collect metrics
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

        # Stability metrics (the gradient norm is measured after clipping,
        # so the reported value is capped at max_norm=10.0)
        grad_norm = sum(p.grad.norm().item() ** 2
                        for p in model.parameters() if p.grad is not None) ** 0.5
        grad_norms.append(grad_norm)

        if recordings is not None:
            spikes = recordings['spikes']
            fr = spikes.mean().item()
            # Fraction of neurons that are effectively silent over the batch
            dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item()
            firing_rates.append(fr)
            dead_fracs.append(dead)

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    avg_lyap = np.mean(lyap_vals) if lyap_vals else None
    avg_grad = np.mean(grad_norms)
    avg_fr = np.mean(firing_rates) if firing_rates else 0.0
    avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0
    return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    ce_loss: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """Evaluate model on the validation set; return (loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
        loss = ce_loss(logits, y)
        if torch.isnan(loss):
            return float('nan'), 0.0
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)
    return total_loss / total_samples, total_correct / total_samples

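
# The regularized objective used in train_epoch, written out (this restates
# the code above; lambda_hat denotes the model's estimate of the largest
# Lyapunov exponent):
#
#     loss = CE(logits, y) + lambda_reg * (lambda_hat - lambda_target) ** 2
#
# With the default lambda_target = 0.0, the quadratic penalty discourages
# both expanding (lambda_hat > 0) and strongly contracting (lambda_hat < 0)
# dynamics.
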
def run_single_experiment(
    config: ExperimentConfig,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    device: torch.device,
    progress: bool = True,
) -> List[EpochMetrics]:
    """Run a single experiment with the given configuration."""
    torch.manual_seed(config.seed)

    model = create_model(
        input_dim=input_dim,
        num_classes=num_classes,
        depth=config.depth,
        hidden_dim=config.hidden_dim,
        beta=config.beta,
        threshold=config.threshold,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    ce_loss = nn.CrossEntropyLoss()
    monitor = StabilityMonitor()

    metrics_history = []
    method = "Lyapunov" if config.use_lyapunov else "Vanilla"

    iterator = range(1, config.epochs + 1)
    if progress:
        iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False)

    for epoch in iterator:
        t0 = time.time()
        train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            ce_loss=ce_loss,
            device=device,
            use_lyapunov=config.use_lyapunov,
            lambda_reg=config.lambda_reg,
            lambda_target=config.lambda_target,
            lyap_eps=config.lyap_eps,
            monitor=monitor,
        )
        val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
        dt = time.time() - t0

        metrics = EpochMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            val_loss=val_loss,
            val_acc=val_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            firing_rate=fr,
            dead_neurons=dead,
            time_sec=dt,
        )
        metrics_history.append(metrics)

        # Stop early if training diverged
        if np.isnan(train_loss):
            print(f"  Training diverged at epoch {epoch}")
            break

    return metrics_history


def run_depth_comparison(
    depths: List[int],
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    device: torch.device,
    epochs: int = 30,
    hidden_dim: int = 128,
    lr: float = 1e-3,
    lambda_reg: float = 0.1,
    lambda_target: float = 0.0,
    lyap_eps: float = 1e-4,
    beta: float = 0.9,
    seed: int = 42,
    progress: bool = True,
) -> Dict[str, Dict[int, List[EpochMetrics]]]:
    """
    Run comparison experiments across depths.

    Returns:
        Dictionary with structure:
        {
            "vanilla":  {1: [metrics...], 2: [metrics...], ...},
            "lyapunov": {1: [metrics...], 2: [metrics...], ...}
        }
    """
    results = {"vanilla": {}, "lyapunov": {}}

    for depth in depths:
        print(f"\n{'='*50}")
        print(f"Depth = {depth} layers")
        print(f"{'='*50}")

        for use_lyap in [False, True]:
            method = "lyapunov" if use_lyap else "vanilla"
            print(f"\n  Training {method.upper()}...")

            config = ExperimentConfig(
                depth=depth,
                hidden_dim=hidden_dim,
                use_lyapunov=use_lyap,
                lambda_reg=lambda_reg,
                lambda_target=lambda_target,
                lyap_eps=lyap_eps,
                epochs=epochs,
                lr=lr,
                batch_size=64,
                beta=beta,
                threshold=1.0,
                seed=seed,
            )
            metrics = run_single_experiment(
                config=config,
                train_loader=train_loader,
                val_loader=val_loader,
                input_dim=input_dim,
                num_classes=num_classes,
                device=device,
                progress=progress,
            )
            results[method][depth] = metrics

            # Print final metrics (check against None explicitly, since a
            # legitimate Lyapunov estimate of 0.0 is falsy)
            final = metrics[-1]
            lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov is not None else "λ=N/A"
            print(f"  Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
                  f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}")

    return results


def save_results(results: Dict, output_dir: str, config: dict):
    """Save experiment results and the run configuration to JSON."""
    os.makedirs(output_dir, exist_ok=True)

    # Convert dataclass metrics to plain dicts for JSON serialization
    serializable = {}
    for method, depth_results in results.items():
        serializable[method] = {}
        for depth, metrics_list in depth_results.items():
            serializable[method][str(depth)] = [asdict(m) for m in metrics_list]

    with open(os.path.join(output_dir, "results.json"), "w") as f:
        json.dump(serializable, f, indent=2)
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nResults saved to {output_dir}")

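
# Sketch: reloading a saved run for offline analysis. The key layout
# ("vanilla"/"lyapunov", stringified depths, list of per-epoch dicts)
# mirrors the serialization in save_results above; the timestamped path
# is a placeholder.
#
#   import json
#   with open("runs/depth_comparison/<timestamp>/results.json") as f:
#       saved = json.load(f)
#   final_val_acc = {int(d): ms[-1]["val_acc"]
#                    for d, ms in saved["lyapunov"].items()}
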
print("=" * 70) # Gradient analysis print("\nGradient Norm Analysis (final epoch):") print("-" * 70) print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}") print("-" * 70) for depth in depths: van_grad = results["vanilla"][depth][-1].grad_norm lyap_grad = results["lyapunov"][depth][-1].grad_norm print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}") def parse_args(): p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths") p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6], help="Network depths to test") p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer") p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment") p.add_argument("--lr", type=float, default=1e-3, help="Learning rate") p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight") p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent") p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov") p.add_argument("--beta", type=float, default=0.9, help="Membrane decay") p.add_argument("--seed", type=int, default=42, help="Random seed") p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing") p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config") p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--no-progress", action="store_true", help="Disable progress bars") return p.parse_args() def main(): args = parse_args() device = torch.device(args.device) print("=" * 70) print("Experiment: Vanilla vs Lyapunov-Regularized SNN") print("=" * 70) print(f"Depths: {args.depths}") print(f"Hidden dim: {args.hidden_dim}") print(f"Epochs: {args.epochs}") print(f"Lambda_reg: {args.lambda_reg}") print(f"Device: {device}") # Load data if args.synthetic: print("\nUsing SYNTHETIC data for quick testing") train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed) else: print(f"\nLoading data from {args.cfg}") from files.data_io.dataset_loader import get_dataloader train_loader, val_loader = get_dataloader(args.cfg) xb, _ = next(iter(train_loader)) _, T, D = xb.shape C = 20 # SHD has 20 classes print(f"Data: T={T}, D={D}, classes={C}") # Run experiments results = run_depth_comparison( depths=args.depths, train_loader=train_loader, val_loader=val_loader, input_dim=D, num_classes=C, device=device, epochs=args.epochs, hidden_dim=args.hidden_dim, lr=args.lr, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target, lyap_eps=args.lyap_eps, beta=args.beta, seed=args.seed, progress=not args.no_progress, ) # Print summary print_summary(results) # Save results ts = time.strftime("%Y%m%d-%H%M%S") output_dir = os.path.join(args.out_dir, ts) save_results(results, output_dir, vars(args)) if __name__ == "__main__": main()