| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/depth_comparison.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/experiments/depth_comparison.py')
| -rw-r--r-- | files/experiments/depth_comparison.py | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/files/experiments/depth_comparison.py b/files/experiments/depth_comparison.py
new file mode 100644
index 0000000..48c62d8
--- /dev/null
+++ b/files/experiments/depth_comparison.py
@@ -0,0 +1,542 @@
"""
Experiment: Compare Vanilla vs Lyapunov-Regularized SNN across network depths.

Hypothesis:
- Shallow networks (1-2 layers): both methods train successfully
- Deep networks (4+ layers): Vanilla fails (gradient issues), Lyapunov succeeds

Usage:
    # Quick test (synthetic data)
    python files/experiments/depth_comparison.py --synthetic --epochs 20

    # Full experiment with SHD data
    python files/experiments/depth_comparison.py --epochs 50

    # Specific depths to test
    python files/experiments/depth_comparison.py --depths 1 2 4 6 8 --epochs 30
"""

import os
import sys
import json
import time
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm

from files.models.snn_snntorch import LyapunovSNN
from files.analysis.stability_monitor import StabilityMonitor


@dataclass
class ExperimentConfig:
    """Configuration for a single experiment run."""
    depth: int
    hidden_dim: int
    use_lyapunov: bool
    lambda_reg: float
    lambda_target: float
    lyap_eps: float
    epochs: int
    lr: float
    batch_size: int
    beta: float
    threshold: float
    seed: int


@dataclass
class EpochMetrics:
    """Metrics collected per epoch."""
    epoch: int
    train_loss: float
    train_acc: float
    val_loss: float
    val_acc: float
    lyapunov: Optional[float]
    grad_norm: float
    firing_rate: float
    dead_neurons: float
    time_sec: float


def create_synthetic_data(
    n_train: int = 2000,
    n_val: int = 500,
    T: int = 50,
    D: int = 100,
    n_classes: int = 10,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, int, int, int]:
    """Create synthetic spike data for testing."""
    torch.manual_seed(seed)
    np.random.seed(seed)

    def generate_data(n_samples):
        # Generate class-conditional spike patterns
        x = torch.zeros(n_samples, T, D)
        y = torch.randint(0, n_classes, (n_samples,))

        for i in range(n_samples):
            label = y[i].item()
            # Each class has a different baseline firing rate
            base_rate = 0.05 + 0.02 * label
            # Class-specific channels fire more often
            class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
            for t in range(T):
                # Background activity
                x[i, t] = (torch.rand(D) < base_rate).float()
                # Enhanced activity on class-specific channels
                for c in class_channels:
                    if torch.rand(1) < base_rate * 3:
                        x[i, t, c] = 1.0

        return x, y

    x_train, y_train = generate_data(n_train)
    x_val, y_val = generate_data(n_val)

    train_loader = DataLoader(
        TensorDataset(x_train, y_train),
        batch_size=64,
        shuffle=True,
    )
    val_loader = DataLoader(
        TensorDataset(x_val, y_val),
        batch_size=64,
        shuffle=False,
    )

    return train_loader, val_loader, T, D, n_classes
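
# Editor's sketch (not part of the original commit): a quick smoke test for the
# synthetic loaders above. It assumes only definitions from this file; the
# helper name `_smoke_test_synthetic_data` is hypothetical.
def _smoke_test_synthetic_data():
    """Check that batches are binary spike tensors of shape (B, T, D)."""
    train_loader, _, T, D, C = create_synthetic_data(n_train=128, n_val=64)
    xb, yb = next(iter(train_loader))
    assert xb.shape == (64, T, D)                    # batch_size=64 is hard-coded above
    assert set(xb.unique().tolist()) <= {0.0, 1.0}   # spikes are binary
    assert 0 <= int(yb.min()) and int(yb.max()) < C  # labels lie in [0, C)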
def create_model(
    input_dim: int,
    num_classes: int,
    depth: int,
    hidden_dim: int = 128,
    beta: float = 0.9,
    threshold: float = 1.0,
) -> LyapunovSNN:
    """Create SNN with specified depth."""
    # Build the hidden-dims list based on depth; optionally shrink deeper
    # layers to keep the parameter count reasonable.
    hidden_dims = []
    current_dim = hidden_dim
    for _ in range(depth):
        hidden_dims.append(current_dim)
        # Optionally decrease dim in deeper layers:
        # current_dim = max(64, current_dim // 2)

    return LyapunovSNN(
        input_dim=input_dim,
        hidden_dims=hidden_dims,
        num_classes=num_classes,
        beta=beta,
        threshold=threshold,
    )


def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    ce_loss: nn.Module,
    device: torch.device,
    use_lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    monitor: StabilityMonitor,  # accepted but not used in this loop
) -> Tuple[float, float, Optional[float], float, float, float]:
    """Train for one epoch, return metrics."""
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    lyap_vals = []
    grad_norms = []
    firing_rates = []
    dead_fracs = []

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits, lyap_est, recordings = model(
            x,
            compute_lyapunov=use_lyapunov,
            lyap_eps=lyap_eps,
            record_states=True,
        )

        ce = ce_loss(logits, y)

        if use_lyapunov and lyap_est is not None:
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.item())
        else:
            loss = ce

        # Check for NaN
        if torch.isnan(loss):
            return float('nan'), 0.0, float('nan'), float('nan'), 0.0, 1.0

        loss.backward()

        # Gradient clipping, applied to both methods so the comparison stays fair
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        optimizer.step()

        # Collect metrics
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

        # Stability metrics: global L2 norm over all parameter gradients
        grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
        grad_norms.append(grad_norm)

        if recordings is not None:
            spikes = recordings['spikes']
            fr = spikes.mean().item()
            # Fraction of neurons that emit (almost) no spikes over the window
            dead = (spikes.sum(dim=1).mean(dim=0) < 0.01).float().mean().item()
            firing_rates.append(fr)
            dead_fracs.append(dead)

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    # Cast to plain floats so json.dump in save_results does not choke on np.float64
    avg_lyap = float(np.mean(lyap_vals)) if lyap_vals else None
    avg_grad = float(np.mean(grad_norms))
    avg_fr = float(np.mean(firing_rates)) if firing_rates else 0.0
    avg_dead = float(np.mean(dead_fracs)) if dead_fracs else 0.0

    return avg_loss, avg_acc, avg_lyap, avg_grad, avg_fr, avg_dead


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    ce_loss: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """Evaluate model on validation set."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits, _, _ = model(x, compute_lyapunov=False, record_states=False)

        loss = ce_loss(logits, y)
        if torch.isnan(loss):
            return float('nan'), 0.0

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

    return total_loss / total_samples, total_correct / total_samples
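
# Editor's sketch (not part of the original commit): train_epoch above collapses
# all gradients into one global norm. When probing the depth hypothesis, a
# per-parameter breakdown shows *where* gradients vanish or explode. A minimal
# helper, assuming a model on which .backward() has just been called; the name
# `_per_layer_grad_norms` is hypothetical.
def _per_layer_grad_norms(model: nn.Module) -> Dict[str, float]:
    """Map parameter name -> L2 norm of its current gradient (0.0 if absent)."""
    return {
        name: (p.grad.norm().item() if p.grad is not None else 0.0)
        for name, p in model.named_parameters()
    }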
def run_single_experiment(
    config: ExperimentConfig,
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    device: torch.device,
    progress: bool = True,
) -> List[EpochMetrics]:
    """Run a single experiment with given configuration."""
    torch.manual_seed(config.seed)

    model = create_model(
        input_dim=input_dim,
        num_classes=num_classes,
        depth=config.depth,
        hidden_dim=config.hidden_dim,
        beta=config.beta,
        threshold=config.threshold,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    ce_loss = nn.CrossEntropyLoss()
    monitor = StabilityMonitor()

    metrics_history = []
    method = "Lyapunov" if config.use_lyapunov else "Vanilla"

    iterator = range(1, config.epochs + 1)
    if progress:
        iterator = tqdm(iterator, desc=f"Depth={config.depth} {method}", leave=False)

    for epoch in iterator:
        t0 = time.time()

        train_loss, train_acc, lyap, grad_norm, fr, dead = train_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            ce_loss=ce_loss,
            device=device,
            use_lyapunov=config.use_lyapunov,
            lambda_reg=config.lambda_reg,
            lambda_target=config.lambda_target,
            lyap_eps=config.lyap_eps,
            monitor=monitor,
        )

        val_loss, val_acc = evaluate(model, val_loader, ce_loss, device)

        dt = time.time() - t0

        metrics = EpochMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            val_loss=val_loss,
            val_acc=val_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            firing_rate=fr,
            dead_neurons=dead,
            time_sec=dt,
        )
        metrics_history.append(metrics)

        # Early stopping if training diverged
        if np.isnan(train_loss):
            print(f"  Training diverged at epoch {epoch}")
            break

    return metrics_history


def run_depth_comparison(
    depths: List[int],
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    device: torch.device,
    epochs: int = 30,
    hidden_dim: int = 128,
    lr: float = 1e-3,
    lambda_reg: float = 0.1,
    lambda_target: float = 0.0,
    lyap_eps: float = 1e-4,
    beta: float = 0.9,
    seed: int = 42,
    progress: bool = True,
) -> Dict[str, Dict[int, List[EpochMetrics]]]:
    """
    Run comparison experiments across depths.

    Returns:
        Dictionary with structure:
        {
            "vanilla": {1: [metrics...], 2: [metrics...], ...},
            "lyapunov": {1: [metrics...], 2: [metrics...], ...}
        }
    """
    results = {"vanilla": {}, "lyapunov": {}}

    for depth in depths:
        print(f"\n{'='*50}")
        print(f"Depth = {depth} layers")
        print(f"{'='*50}")

        for use_lyap in [False, True]:
            method = "lyapunov" if use_lyap else "vanilla"
            print(f"\n  Training {method.upper()}...")

            config = ExperimentConfig(
                depth=depth,
                hidden_dim=hidden_dim,
                use_lyapunov=use_lyap,
                lambda_reg=lambda_reg,
                lambda_target=lambda_target,
                lyap_eps=lyap_eps,
                epochs=epochs,
                lr=lr,
                batch_size=64,
                beta=beta,
                threshold=1.0,
                seed=seed,
            )

            metrics = run_single_experiment(
                config=config,
                train_loader=train_loader,
                val_loader=val_loader,
                input_dim=input_dim,
                num_classes=num_classes,
                device=device,
                progress=progress,
            )

            results[method][depth] = metrics

            # Print final metrics (check against None: a Lyapunov estimate of
            # exactly 0.0 is valid and must not be reported as N/A)
            final = metrics[-1]
            lyap_str = f"λ={final.lyapunov:.3f}" if final.lyapunov is not None else "λ=N/A"
            print(f"  Final: loss={final.train_loss:.4f} acc={final.train_acc:.3f} "
                  f"val_acc={final.val_acc:.3f} {lyap_str} ∇={final.grad_norm:.2f}")

    return results
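
# Editor's sketch (not part of the original commit): run_depth_comparison can
# also be driven programmatically, e.g. from a notebook, without the CLI. A
# minimal call on the synthetic data path; `_quick_synthetic_run` is a
# hypothetical name.
def _quick_synthetic_run():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, T, D, C = create_synthetic_data(n_train=256, n_val=128)
    return run_depth_comparison(
        depths=[1, 2],
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=D,
        num_classes=C,
        device=device,
        epochs=2,  # just enough to exercise both training loops
        progress=False,
    )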
def save_results(results: Dict, output_dir: str, config: dict):
    """Save experiment results to JSON."""
    os.makedirs(output_dir, exist_ok=True)

    # Convert metrics dataclasses to plain dicts
    serializable = {}
    for method, depth_results in results.items():
        serializable[method] = {}
        for depth, metrics_list in depth_results.items():
            serializable[method][str(depth)] = [asdict(m) for m in metrics_list]

    with open(os.path.join(output_dir, "results.json"), "w") as f:
        json.dump(serializable, f, indent=2)

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nResults saved to {output_dir}")


def print_summary(results: Dict[str, Dict[int, List[EpochMetrics]]]):
    """Print summary comparison table."""
    print("\n" + "=" * 70)
    print("SUMMARY: Final Validation Accuracy by Depth")
    print("=" * 70)
    print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Difference':<15}")
    print("-" * 70)

    depths = sorted(results["vanilla"].keys())
    for depth in depths:
        van_metrics = results["vanilla"][depth]
        lyap_metrics = results["lyapunov"][depth]

        van_acc = van_metrics[-1].val_acc if not np.isnan(van_metrics[-1].val_acc) else 0.0
        lyap_acc = lyap_metrics[-1].val_acc if not np.isnan(lyap_metrics[-1].val_acc) else 0.0

        van_str = f"{van_acc:.3f}" if not np.isnan(van_metrics[-1].train_loss) else "DIVERGED"
        lyap_str = f"{lyap_acc:.3f}" if not np.isnan(lyap_metrics[-1].train_loss) else "DIVERGED"

        diff = lyap_acc - van_acc
        diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"

        print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")

    print("=" * 70)

    # Gradient analysis
    print("\nGradient Norm Analysis (final epoch):")
    print("-" * 70)
    print(f"{'Depth':<8} {'Vanilla ∇':<15} {'Lyapunov ∇':<15}")
    print("-" * 70)
    for depth in depths:
        van_grad = results["vanilla"][depth][-1].grad_norm
        lyap_grad = results["lyapunov"][depth][-1].grad_norm
        print(f"{depth:<8} {van_grad:<15.2f} {lyap_grad:<15.2f}")


def parse_args():
    p = argparse.ArgumentParser(description="Compare Vanilla vs Lyapunov SNN across depths")
    p.add_argument("--depths", type=int, nargs="+", default=[1, 2, 3, 4, 6],
                   help="Network depths to test")
    p.add_argument("--hidden_dim", type=int, default=128, help="Hidden dimension per layer")
    p.add_argument("--epochs", type=int, default=30, help="Training epochs per experiment")
    p.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    p.add_argument("--lambda_reg", type=float, default=0.1, help="Lyapunov regularization weight")
    p.add_argument("--lambda_target", type=float, default=0.0, help="Target Lyapunov exponent")
    p.add_argument("--lyap_eps", type=float, default=1e-4, help="Perturbation for Lyapunov")
    p.add_argument("--beta", type=float, default=0.9, help="Membrane decay")
    p.add_argument("--seed", type=int, default=42, help="Random seed")
    p.add_argument("--synthetic", action="store_true", help="Use synthetic data for quick testing")
    p.add_argument("--cfg", type=str, default="data_io/configs/shd.yaml", help="Dataset config")
    p.add_argument("--out_dir", type=str, default="runs/depth_comparison", help="Output directory")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--no-progress", action="store_true", help="Disable progress bars")
    return p.parse_args()
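
# Editor's sketch (not part of the original commit): the results.json written by
# save_results can be loaded back into EpochMetrics objects for later analysis,
# since the dataclass round-trips through asdict/**kwargs. The helper name
# `_load_results` is hypothetical.
def _load_results(output_dir: str) -> Dict[str, Dict[int, List[EpochMetrics]]]:
    """Inverse of save_results: restore {method: {depth: [EpochMetrics, ...]}}."""
    with open(os.path.join(output_dir, "results.json")) as f:
        raw = json.load(f)
    return {
        method: {
            int(depth): [EpochMetrics(**m) for m in metrics_list]
            for depth, metrics_list in depth_results.items()
        }
        for method, depth_results in raw.items()
    }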
def main():
    args = parse_args()
    device = torch.device(args.device)

    print("=" * 70)
    print("Experiment: Vanilla vs Lyapunov-Regularized SNN")
    print("=" * 70)
    print(f"Depths: {args.depths}")
    print(f"Hidden dim: {args.hidden_dim}")
    print(f"Epochs: {args.epochs}")
    print(f"Lambda_reg: {args.lambda_reg}")
    print(f"Device: {device}")

    # Load data
    if args.synthetic:
        print("\nUsing SYNTHETIC data for quick testing")
        train_loader, val_loader, T, D, C = create_synthetic_data(seed=args.seed)
    else:
        print(f"\nLoading data from {args.cfg}")
        from files.data_io.dataset_loader import get_dataloader
        train_loader, val_loader = get_dataloader(args.cfg)
        xb, _ = next(iter(train_loader))
        _, T, D = xb.shape
        C = 20  # SHD has 20 classes

    print(f"Data: T={T}, D={D}, classes={C}")

    # Run experiments
    results = run_depth_comparison(
        depths=args.depths,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=D,
        num_classes=C,
        device=device,
        epochs=args.epochs,
        hidden_dim=args.hidden_dim,
        lr=args.lr,
        lambda_reg=args.lambda_reg,
        lambda_target=args.lambda_target,
        lyap_eps=args.lyap_eps,
        beta=args.beta,
        seed=args.seed,
        progress=not args.no_progress,
    )

    # Print summary
    print_summary(results)

    # Save results
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(args.out_dir, ts)
    save_results(results, output_dir, vars(args))


if __name__ == "__main__":
    main()
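
# Editor's sketch (not part of the original commit): a minimal plot of the
# summary table produced above, assuming matplotlib is available in the
# environment. The helper name `_plot_val_acc_by_depth` is hypothetical; it
# takes the in-memory results dict returned by run_depth_comparison.
def _plot_val_acc_by_depth(results, out_path="val_acc_by_depth.png"):
    """Plot final validation accuracy vs. depth for both methods."""
    import matplotlib.pyplot as plt

    depths = sorted(results["vanilla"].keys())
    for method in ("vanilla", "lyapunov"):
        accs = [results[method][d][-1].val_acc for d in depths]
        plt.plot(depths, accs, marker="o", label=method)
    plt.xlabel("Depth (layers)")
    plt.ylabel("Final validation accuracy")
    plt.legend()
    plt.savefig(out_path, dpi=150)
    plt.close()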
