""" CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization. Uses proper convolutional architecture that preserves spatial structure. Tests whether Lyapunov regularization helps train deeper Conv-SNNs. Architecture: Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output Usage: python files/experiments/cifar10_conv_experiment.py --model simple --T 25 python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional, Tuple _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from tqdm.auto import tqdm from files.models.conv_snn import create_conv_snn @dataclass class EpochMetrics: epoch: int train_loss: float train_acc: float val_loss: float val_acc: float lyapunov: Optional[float] grad_norm: float time_sec: float def get_cifar10_loaders( data_dir: str = './data', batch_size: int = 128, num_workers: int = 4, ) -> Tuple[DataLoader, DataLoader]: """ Get CIFAR-10 dataloaders with standard normalization. Images normalized to [0, 1] for rate encoding. """ transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), # Note: For rate encoding, we keep values in [0, 1] # No normalization to negative values ]) transform_test = transforms.Compose([ transforms.ToTensor(), ]) train_dataset = datasets.CIFAR10( root=data_dir, train=True, download=True, transform=transform_train ) test_dataset = datasets.CIFAR10( root=data_dir, train=False, download=True, transform=transform_test ) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True ) return train_loader, test_loader def train_epoch( model: nn.Module, loader: DataLoader, optimizer: optim.Optimizer, ce_loss: nn.Module, device: torch.device, use_lyapunov: bool, lambda_reg: float, lambda_target: float, lyap_eps: float, progress: bool = True, ) -> Tuple[float, float, Optional[float], float]: """Train one epoch.""" model.train() total_loss = 0.0 total_correct = 0 total_samples = 0 lyap_vals = [] grad_norms = [] iterator = tqdm(loader, desc="train", leave=False) if progress else loader for x, y in iterator: x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32) optimizer.zero_grad() logits, lyap_est, _ = model( x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps, ) ce = ce_loss(logits, y) if use_lyapunov and lyap_est is not None: reg = (lyap_est - lambda_target) ** 2 loss = ce + lambda_reg * reg lyap_vals.append(lyap_est.item()) else: loss = ce if torch.isnan(loss): return float('nan'), 0.0, None, float('nan') loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) optimizer.step() grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 grad_norms.append(grad_norm) total_loss += loss.item() * x.size(0) preds = logits.argmax(dim=1) total_correct += (preds == y).sum().item() total_samples += x.size(0) if progress: iterator.set_postfix({ "loss": f"{loss.item():.3f}", "acc": f"{total_correct/total_samples:.3f}", }) return 
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    ce_loss: nn.Module,
    device: torch.device,
    use_lyapunov: bool,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    progress: bool = True,
) -> Tuple[float, float, Optional[float], float]:
    """Train one epoch. Returns (loss, acc, mean Lyapunov estimate, mean grad norm)."""
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    lyap_vals = []
    grad_norms = []

    iterator = tqdm(loader, desc="train", leave=False) if progress else loader
    for x, y in iterator:
        x, y = x.to(device), y.to(device)  # x: (B, 3, 32, 32)

        optimizer.zero_grad()
        logits, lyap_est, _ = model(
            x,
            compute_lyapunov=use_lyapunov,
            lyap_eps=lyap_eps,
        )
        ce = ce_loss(logits, y)

        if use_lyapunov and lyap_est is not None:
            # Quadratic penalty pulling the estimated Lyapunov exponent
            # toward lambda_target.
            reg = (lyap_est - lambda_target) ** 2
            loss = ce + lambda_reg * reg
            lyap_vals.append(lyap_est.item())
        else:
            loss = ce

        if torch.isnan(loss):
            # Signal divergence to the caller via NaN sentinels.
            return float('nan'), 0.0, None, float('nan')

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        # Record the post-clipping gradient norm actually used by the update.
        grad_norm = sum(
            p.grad.norm().item() ** 2
            for p in model.parameters() if p.grad is not None
        ) ** 0.5
        grad_norms.append(grad_norm)

        optimizer.step()

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

        if progress:
            iterator.set_postfix({
                "loss": f"{loss.item():.3f}",
                "acc": f"{total_correct / total_samples:.3f}",
            })

    return (
        total_loss / total_samples,
        total_correct / total_samples,
        float(np.mean(lyap_vals)) if lyap_vals else None,
        float(np.mean(grad_norms)),
    )


@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    ce_loss: nn.Module,
    device: torch.device,
    progress: bool = True,
) -> Tuple[float, float]:
    """Evaluate on the test set. Returns (loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
    for x, y in iterator:
        x, y = x.to(device), y.to(device)
        logits, _, _ = model(x, compute_lyapunov=False)
        loss = ce_loss(logits, y)

        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total_samples += x.size(0)

    return total_loss / total_samples, total_correct / total_samples


def run_experiment(
    model_type: str,
    channels: List[int],
    T: int,
    use_lyapunov: bool,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    lyap_eps: float,
    device: torch.device,
    seed: int,
    progress: bool = True,
) -> List[EpochMetrics]:
    """Run a single experiment."""
    torch.manual_seed(seed)

    model = create_conv_snn(
        model_type=model_type,
        in_channels=3,
        num_classes=10,
        channels=channels,
        T=T,
        encoding='rate',
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Model: {model_type}, params: {num_params:,}")

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    ce_loss = nn.CrossEntropyLoss()

    metrics_history = []
    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        t0 = time.time()
        train_loss, train_acc, lyap, grad_norm = train_epoch(
            model, train_loader, optimizer, ce_loss, device,
            use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress
        )
        test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress)
        scheduler.step()
        dt = time.time() - t0

        best_acc = max(best_acc, test_acc)
        metrics = EpochMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            val_loss=test_loss,
            val_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            time_sec=dt,
        )
        metrics_history.append(metrics)

        # Use an explicit None check so a legitimate estimate of 0.0 still prints.
        lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
        print(f"  Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} "
              f"{lyap_str} ({dt:.1f}s)")

        if np.isnan(train_loss):
            print("  Training diverged!")
            break

    print(f"  Best test accuracy: {best_acc:.3f}")
    return metrics_history


def run_comparison(
    model_type: str,
    channels_configs: List[List[int]],
    T: int,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    progress: bool,
) -> Dict:
    """Compare vanilla vs. Lyapunov-regularized training across depths."""
    results = {"vanilla": {}, "lyapunov": {}}

    for channels in channels_configs:
        depth = len(channels)
        print(f"\n{'=' * 60}")
        print(f"Depth = {depth} conv layers, channels = {channels}")
        print(f"{'=' * 60}")

        for use_lyap in [False, True]:
            method = "lyapunov" if use_lyap else "vanilla"
            print(f"\n  Training {method.upper()}...")
            metrics = run_experiment(
                model_type=model_type,
                channels=channels,
                T=T,
                use_lyapunov=use_lyap,
                train_loader=train_loader,
                test_loader=test_loader,
                epochs=epochs,
                lr=lr,
                lambda_reg=lambda_reg,
                lambda_target=lambda_target,
                lyap_eps=1e-4,
                device=device,
                seed=seed,
                progress=progress,
            )
            results[method][depth] = metrics

    return results
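
# --- Reference sketch (not called by the experiment) --------------------------
# The regularizer applied in train_epoch is the squared distance between the
# model's Lyapunov-exponent estimate and lambda_target (default -0.1, i.e.
# mildly contractive dynamics), scaled by lambda_reg. Restated here as a
# standalone hypothetical helper for clarity; defaults mirror parse_args.
def _lyapunov_penalty_sketch(lyap_est: torch.Tensor,
                             lambda_target: float = -0.1,
                             lambda_reg: float = 0.3) -> torch.Tensor:
    """E.g. lyap_est=0.4, target=-0.1 -> 0.3 * 0.5**2 = 0.075 added to CE."""
    return lambda_reg * (lyap_est - lambda_target) ** 2
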
print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}") print("-" * 70) depths = sorted(results["vanilla"].keys()) for depth in depths: van = results["vanilla"][depth][-1] lyap = results["lyapunov"][depth][-1] van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0 lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0 diff = lyap_acc - van_acc diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}" van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}") print("=" * 70) def save_results(results: Dict, output_dir: str, config: Dict): """Save results.""" os.makedirs(output_dir, exist_ok=True) serializable = {} for method, depth_results in results.items(): serializable[method] = {} for depth, metrics_list in depth_results.items(): serializable[method][str(depth)] = [asdict(m) for m in metrics_list] with open(os.path.join(output_dir, "results.json"), "w") as f: json.dump(serializable, f, indent=2) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) print(f"\nResults saved to {output_dir}") def parse_args(): p = argparse.ArgumentParser() # Model p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"]) p.add_argument("--channels", type=int, nargs="+", default=None, help="Channel sizes (default: test multiple depths)") p.add_argument("--T", type=int, default=25, help="Timesteps") # Training p.add_argument("--epochs", type=int, default=50) p.add_argument("--batch_size", type=int, default=128) p.add_argument("--lr", type=float, default=1e-3) # Lyapunov p.add_argument("--lambda_reg", type=float, default=0.3) p.add_argument("--lambda_target", type=float, default=-0.1) # Other p.add_argument("--data_dir", type=str, default="./data") p.add_argument("--out_dir", type=str, default="runs/cifar10_conv") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--seed", type=int, default=42) p.add_argument("--no-progress", action="store_true") return p.parse_args() def main(): args = parse_args() device = torch.device(args.device) print("=" * 70) print("CIFAR-10 Conv-SNN Experiment") print("=" * 70) print(f"Model: {args.model}") print(f"Timesteps: {args.T}") print(f"Epochs: {args.epochs}") print(f"Device: {device}") print("=" * 70) # Load data print("\nLoading CIFAR-10...") train_loader, test_loader = get_cifar10_loaders( data_dir=args.data_dir, batch_size=args.batch_size, ) print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") # Define depth configurations to test if args.channels: channels_configs = [args.channels] else: # Test increasing depths channels_configs = [ [64, 128], # 2 conv layers (shallow) [64, 128, 256], # 3 conv layers [64, 128, 256, 512], # 4 conv layers (deep) ] # Run comparison results = run_comparison( model_type=args.model, channels_configs=channels_configs, T=args.T, train_loader=train_loader, test_loader=test_loader, epochs=args.epochs, lr=args.lr, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target, device=device, seed=args.seed, progress=not args.no_progress, ) # Summary print_summary(results) # Save ts = time.strftime("%Y%m%d-%H%M%S") output_dir = os.path.join(args.out_dir, ts) save_results(results, output_dir, vars(args)) if __name__ == "__main__": main()