Diffstat (limited to 'files/experiments/depth_comparison.py')
-rw-r--r--  files/experiments/depth_comparison.py  542
1 file changed, 542 insertions(+), 0 deletions(-)
diff --git a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py
new file mode 100644
index 0000000..48c62d8
--- /dev/null
+++ b/files/experiments/depth_comparison.py
@@ -0,0 +1,542 @@
+"""
+Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths.
+
+Hypothesis:
+- Shallow networks (1-2 layers): Both methods train successfully
+- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds
+
+Usage:
+ # Quick test (synthetic data)
+ python files/experiments/depth_comparison.py --synthetic --epochs 20
+
+ # Full experiment with SHD data
+ python files/experiments/depth_comparison.py --epochs 50
+
+ # Specific depths to test
+ python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30
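+
+Outputs:
+    results.json and config.json are written to <out_dir>/<timestamp>/
+    (default out_dir: runs/depth_comparison).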
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+from files.analysis.stability_monitor import StabilityMonitor
+
+
+@dataclass
+class ExperimentConfig:
+ """Configuration for a single experiment run."""
+ depth: int
+ hidden_dim: int
+ use_lyapunov: bool
+ lambda_reg: float
+ lambda_target: float
+ lyap_eps: float
+ epochs: int
+ lr: float
+ batch_size: int
+ beta: float
+ threshold: float
+ seed: int
+
+
+@dataclass
+class EpochMetrics:
+ """Metrics collected per epoch."""
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ firing_rate: float
+ dead_neurons: float
+ time_sec: float
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+) -> Tuple[DataLoader, DataLoader, int, int, int]:
+    """Create synthetic spike data for testing.
+
+    Returns (train_loader, val_loader, T, D, n_classes).
+    """
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ # Generate class-conditional spike patterns
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+
+ for i in range(n_samples):
+ label = y[i].item()
+ # Each class has different firing rate pattern
+ base_rate = 0.05 + 0.02 * label
+ # Class-specific channels fire more
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
+ for t in range(T):
+ # Background activity
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ # Enhanced activity for class-specific channels
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
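+    # Note: batch size is fixed at 64 here; ExperimentConfig.batch_size is recorded
+    # in the saved config but is not used to build these loaders.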
+ train_loader = DataLoader(
+ TensorDataset(x_train, y_train),
+ batch_size=64,
+ shuffle=True,
+ )
+ val_loader = DataLoader(
+ TensorDataset(x_val, y_val),
+ batch_size=64,
+ shuffle=False,
+ )
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def create_model(
+ input_dim: int,
+ num_classes: int,
+ depth: int,
+ hidden_dim: int = 128,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+) -> LyapunovSNN:
+ """Create SNN with specified depth."""
+    # All hidden layers share the same width by default.
+    hidden_dims = []
+    current_dim = hidden_dim
+    for _ in range(depth):
+        hidden_dims.append(current_dim)
+        # Optionally taper deeper layers to keep the parameter count reasonable:
+        # current_dim = max(64, current_dim // 2)
+
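+    # e.g. depth=4, hidden_dim=128 -> hidden_dims == [128, 128, 128, 128]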
+ return LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=beta,
+ threshold=threshold,
+ )
+
+
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+    monitor: StabilityMonitor,  # currently unused in this loop
+) -> Tuple[float, float, Optional[float], float, float, float]:
+ """Train for one epoch, return metrics."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+ firing_rates = []
+ dead_fracs = []
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, recordings = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ record_states=True,
+ )
+
+ ce = ce_loss(logits, y)
+
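+            # Objective with the quadratic Lyapunov penalty (lyap_est is assumed to
+            # approximate the largest Lyapunov exponent of the hidden dynamics):
+            #     loss = CE + lambda_reg * (lyap_est - lambda_target)**2
+            # lambda_target = 0.0 steers the dynamics toward neutral stability.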
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+            # Abort the epoch if the loss diverged; the sentinel return marks
+            # loss/lyapunov/grad as NaN and reports all neurons as dead.
+            if torch.isnan(loss):
+                return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0
+
+ loss.backward()
+
+            # Clip gradients in both settings so the comparison stays fair
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+
+ optimizer.step()
+
+ # Collect metrics
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ # Stability metrics
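+            # Global L2 gradient norm: sqrt of the summed squared per-parameter norms,
+            # read after step() (grads persist until the next zero_grad()).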
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ if recordings is not None:
+ spikes = recordings['spikes']
+ fr = spikes.mean().item()
+ dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item()
+ firing_rates.append(fr)
+ dead_fracs.append(dead)
+
+ avg_loss = total_loss / total_samples
+ avg_acc = total_correct / total_samples
+ avg_lyap = np.mean(lyap_vals) if lyap_vals else None
+ avg_grad = np.mean(grad_norms)
+ avg_fr = np.mean(firing_rates) if firing_rates else 0.0
+ avg_dead = np.mean(dead_fracs) if dead_fracs else 0.0
+
+ return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+) -> Tuple[float, float]:
+ """Evaluate model on validation set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+
+ loss = ce_loss(logits, y)
+ if torch.isnan(loss):
+ return float('nan'), 0.0
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_single_experiment(
+ config: ExperimentConfig,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run a single experiment with given configuration."""
+ torch.manual_seed(config.seed)
+
+ model = create_model(
+ input_dim=input_dim,
+ num_classes=num_classes,
+ depth=config.depth,
+ hidden_dim=config.hidden_dim,
+ beta=config.beta,
+ threshold=config.threshold,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=config.lr)
+ ce_loss = nn.CrossEntropyLoss()
+ monitor = StabilityMonitor()
+
+ metrics_history = []
+ method = "Lyapunov" if config.use_lyapunov else "Vanilla"
+
+ iterator = range(1, config.epochs + 1)
+ if progress:
+ iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False)
+
+ for epoch in iterator:
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch(
+ model=model,
+ loader=train_loader,
+ optimizer=optimizer,
+ ce_loss=ce_loss,
+ device=device,
+ use_lyapunov=config.use_lyapunov,
+ lambda_reg=config.lambda_reg,
+ lambda_target=config.lambda_target,
+ lyap_eps=config.lyap_eps,
+ monitor=monitor,
+ )
+
+ val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)
+
+ dt = time.time() - t0
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=val_loss,
+ val_acc=val_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ firing_rate=fr,
+ dead_neurons=dead,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+        # Stop early if training diverged
+ if np.isnan(train_loss):
+ print(f" Training diverged at epoch {epoch}")
+ break
+
+ return metrics_history
+
+
+def run_depth_comparison(
+ depths: List[int],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ device: torch.device,
+ epochs: int = 30,
+ hidden_dim: int = 128,
+ lr: float = 1e-3,
+ lambda_reg: float = 0.1,
+ lambda_target: float = 0.0,
+ lyap_eps: float = 1e-4,
+ beta: float = 0.9,
+ seed: int = 42,
+ progress: bool = True,
+) -> Dict[str, Dict[int, List[EpochMetrics]]]:
+ """
+ Run comparison experiments across depths.
+
+ Returns:
+ Dictionary with structure:
+ {
+ "vanilla": {1: [metrics...], 2: [metrics...], ...},
+ "lyapunov": {1: [metrics...], 2: [metrics...], ...}
+ }
+ """
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for depth in depths:
+ print(f"\n{'='*50}")
+ print(f"Depth = {depth} layers")
+ print(f"{'='*50}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ config = ExperimentConfig(
+ depth=depth,
+ hidden_dim=hidden_dim,
+ use_lyapunov=use_lyap,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=lyap_eps,
+ epochs=epochs,
+ lr=lr,
+ batch_size=64,
+ beta=beta,
+ threshold=1.0,
+ seed=seed,
+ )
+
+ metrics = run_single_experiment(
+ config=config,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ device=device,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ # Print final metrics
+ final = metrics[-1]
+            lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov is not None else "λ=N/A"
+ print(f" Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
+ f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}")
+
+ return results
+
+
+def save_results(results: Dict, output_dir: str, config: dict):
+ """Save experiment results to JSON."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Convert metrics to dicts
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
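+
+# Expected on-disk layout (timestamp illustrative):
+#   <out_dir>/<YYYYmmdd-HHMMSS>/results.json  # per-epoch metrics per method/depth
+#   <out_dir>/<YYYYmmdd-HHMMSS>/config.json   # argparse namespace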
+
+
+def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]):
+ """Print summary comparison table."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: Final Validation Accuracy by Depth")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
+ van_metrics = results["vanilla"][depth]
+ lyap_metrics = results["lyapunov"][depth]
+
+ van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0
+ lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0
+
+ van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED"
+ lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED"
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+ # Gradient analysis
+ print("\nGradient Norm Analysis (final epoch):")
+ print("-" * 70)
+ print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}")
+ print("-" * 70)
+ for depth in depths:
+ van_grad = results["vanilla"][depth][-1].grad_norm
+ lyap_grad = results["lyapunov"][depth][-1].grad_norm
+ print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths")
+ p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6],
+ help="Network depths to test")
+ p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer")
+ p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment")
+ p.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
+ p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight")
+ p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
+ p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov")
+ p.add_argument("--beta", type=float, default=0.9, help="Membrane decay")
+ p.add_argument("--seed", type=int, default=42, help="Random seed")
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing")
+ p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
+ p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true", help="Disable progress bars")
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("Experiment: Vanilla vs Lyapunov-Regularized SNN")
+ print("=" * 70)
+ print(f"Depths: {args.depths}")
+ print(f"Hidden dim: {args.hidden_dim}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Lambda_reg: {args.lambda_reg}")
+ print(f"Device: {device}")
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing SYNTHETIC data for quick testing")
+ train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed)
+ else:
+ print(f"\nLoading data from {args.cfg}")
+ from files.data_io.dataset_loader import get_dataloader
+ train_loader, val_loader = get_dataloader(args.cfg)
+ xb, _ = next(iter(train_loader))
+ _, T, D = xb.shape
+ C = 20 # SHD has 20 classes
+
+ print(f"Data: T={T}, D={D}, classes={C}")
+
+ # Run experiments
+ results = run_depth_comparison(
+ depths=args.depths,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ device=device,
+ epochs=args.epochs,
+ hidden_dim=args.hidden_dim,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ lyap_eps=args.lyap_eps,
+ beta=args.beta,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Print summary
+ print_summary(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()