Diffstat (limited to 'files/experiments/cifar10_conv_experiment.py')
-rw-r--r--  files/experiments/cifar10_conv_experiment.py  448
1 files changed, 448 insertions, 0 deletions
diff --git a/files/experiments/cifar10_conv_experiment.py b/files/experiments/cifar10_conv_experiment.py
new file mode 100644
index 0000000..a582f9f
--- /dev/null
+++ b/files/experiments/cifar10_conv_experiment.py
@@ -0,0 +1,448 @@
+"""
+CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization.
+
+Uses a convolutional architecture that preserves the spatial structure of the input.
+Tests whether Lyapunov regularization helps train deeper Conv-SNNs.
+
+Architecture:
+ Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output
+
+Each run trains both a vanilla and a Lyapunov-regularized model for every depth
+configuration, so the comparison needs no separate flag.
+
+Usage:
+ python files/experiments/cifar10_conv_experiment.py --model simple --T 25
+ python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --channels 64 128 256
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+from files.models.conv_snn import create_conv_snn
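+
+# Interface assumed of the model returned by create_conv_snn (inferred from its
+# use below, not from the module itself): forward(x, compute_lyapunov=...,
+# lyap_eps=...) returns a 3-tuple (logits, lyapunov_estimate, aux), where
+# lyapunov_estimate may be None (e.g. when compute_lyapunov=False).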
+
+
+@dataclass
+class EpochMetrics:
+ epoch: int
+ train_loss: float
+ train_acc: float
+ val_loss: float
+ val_acc: float
+ lyapunov: Optional[float]
+ grad_norm: float
+ time_sec: float
+
+
+def get_cifar10_loaders(
+ data_dir: str = './data',
+ batch_size: int = 128,
+ num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader]:
+ """
+ Get CIFAR-10 dataloaders with standard normalization.
+
+ Images normalized to [0, 1] for rate encoding.
+ """
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ # Note: For rate encoding, we keep values in [0, 1]
+ # No normalization to negative values
+ ])
+
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_dataset = datasets.CIFAR10(
+ root=data_dir, train=True, download=True, transform=transform_train
+ )
+ test_dataset = datasets.CIFAR10(
+ root=data_dir, train=False, download=True, transform=transform_test
+ )
+
+ train_loader = DataLoader(
+ train_dataset, batch_size=batch_size, shuffle=True,
+ num_workers=num_workers, pin_memory=True
+ )
+ test_loader = DataLoader(
+ test_dataset, batch_size=batch_size, shuffle=False,
+ num_workers=num_workers, pin_memory=True
+ )
+
+ return train_loader, test_loader
+
+
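+# Reference sketch (an assumption, not called anywhere in this script): with
+# encoding='rate', the conv_snn models are expected to treat each pixel
+# intensity in [0, 1] as a per-timestep spike probability, which is why
+# get_cifar10_loaders applies no mean/std normalization. The actual encoder
+# lives in files/models/conv_snn.py.
+def _rate_encode_reference(x: torch.Tensor, T: int) -> torch.Tensor:
+ """Return (T, B, C, H, W) binary spikes with P(spike) = pixel intensity."""
+ return (torch.rand(T, *x.shape, device=x.device) < x).float()
+
+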
+def train_epoch(
+ model: nn.Module,
+ loader: DataLoader,
+ optimizer: optim.Optimizer,
+ ce_loss: nn.Module,
+ device: torch.device,
+ use_lyapunov: bool,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ progress: bool = True,
+) -> Tuple[float, float, Optional[float], float]:
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device) # x: (B, 3, 32, 32)
+
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(
+ x,
+ compute_lyapunov=use_lyapunov,
+ lyap_eps=lyap_eps,
+ )
+
+ ce = ce_loss(logits, y)
+
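+ # Lyapunov regularization: a quadratic penalty pulling the estimated exponent
+ # toward lambda_target (slightly negative by default, i.e. mildly contractive
+ # dynamics), scaled by lambda_reg.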
+ if use_lyapunov and lyap_est is not None:
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ return float('nan'), 0.0, None, float('nan')
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
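+ # Global gradient norm (measured after clipping) for stability monitoring.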
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if progress:
+ iterator.set_postfix({
+ "loss": f"{loss.item():.3f}",
+ "acc": f"{total_correct/total_samples:.3f}",
+ })
+
+ return (
+ total_loss / total_samples,
+ total_correct / total_samples,
+ np.mean(lyap_vals) if lyap_vals else None,
+ np.mean(grad_norms),
+ )
+
+
+@torch.no_grad()
+def evaluate(
+ model: nn.Module,
+ loader: DataLoader,
+ ce_loss: nn.Module,
+ device: torch.device,
+ progress: bool = True,
+) -> Tuple[float, float]:
+ """Evaluate on test set."""
+ model.eval()
+ total_loss = 0.0
+ total_correct = 0
+ total_samples = 0
+
+ iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+
+ loss = ce_loss(logits, y)
+ total_loss += loss.item() * x.size(0)
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+ model_type: str,
+ channels: List[int],
+ T: int,
+ use_lyapunov: bool,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ lyap_eps: float,
+ device: torch.device,
+ seed: int,
+ progress: bool = True,
+) -> List[EpochMetrics]:
+ """Run single experiment."""
+ torch.manual_seed(seed)
+
+ model = create_conv_snn(
+ model_type=model_type,
+ in_channels=3,
+ num_classes=10,
+ channels=channels,
+ T=T,
+ encoding='rate',
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f" Model: {model_type}, params: {num_params:,}")
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ ce_loss = nn.CrossEntropyLoss()
+
+ metrics_history = []
+ best_acc = 0.0
+
+ for epoch in range(1, epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm = train_epoch(
+ model, train_loader, optimizer, ce_loss, device,
+ use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_acc = max(best_acc, test_acc)
+
+ metrics = EpochMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ val_loss=test_loss,
+ val_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ time_sec=dt,
+ )
+ metrics_history.append(metrics)
+
+ lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)")
+
+ if np.isnan(train_loss):
+ print(" Training diverged!")
+ break
+
+ print(f" Best test accuracy: {best_acc:.3f}")
+ return metrics_history
+
+
+def run_comparison(
+ model_type: str,
+ channels_configs: List[List[int]],
+ T: int,
+ train_loader: DataLoader,
+ test_loader: DataLoader,
+ epochs: int,
+ lr: float,
+ lambda_reg: float,
+ lambda_target: float,
+ device: torch.device,
+ seed: int,
+ progress: bool,
+) -> Dict:
+ """Compare vanilla vs Lyapunov across different depths."""
+ results = {"vanilla": {}, "lyapunov": {}}
+
+ for channels in channels_configs:
+ depth = len(channels)
+ print(f"\n{'='*60}")
+ print(f"Depth = {depth} conv layers, channels = {channels}")
+ print(f"{'='*60}")
+
+ for use_lyap in [False, True]:
+ method = "lyapunov" if use_lyap else "vanilla"
+ print(f"\n Training {method.upper()}...")
+
+ metrics = run_experiment(
+ model_type=model_type,
+ channels=channels,
+ T=T,
+ use_lyapunov=use_lyap,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=epochs,
+ lr=lr,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ lyap_eps=1e-4,
+ device=device,
+ seed=seed,
+ progress=progress,
+ )
+
+ results[method][depth] = metrics
+
+ return results
+
+
+def print_summary(results: Dict):
+ """Print comparison summary."""
+ print("\n" + "=" * 70)
+ print("SUMMARY: CIFAR-10 Conv-SNN Results")
+ print("=" * 70)
+ print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}")
+ print("-" * 70)
+
+ depths = sorted(results["vanilla"].keys())
+ for depth in depths:
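+ # Compare final-epoch validation accuracy; runs that diverged (NaN train
+ # loss) are reported as FAILED.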
+ van = results["vanilla"][depth][-1]
+ lyap = results["lyapunov"][depth][-1]
+
+ van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+ lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+ diff = lyap_acc - van_acc
+ diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+ van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+ lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+ print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+ print("=" * 70)
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+ """Save results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ serializable = {}
+ for method, depth_results in results.items():
+ serializable[method] = {}
+ for depth, metrics_list in depth_results.items():
+ serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+ with open(os.path.join(output_dir, "results.json"), "w") as f:
+ json.dump(serializable, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+ p = argparse.ArgumentParser()
+
+ # Model
+ p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"])
+ p.add_argument("--channels", type=int, nargs="+", default=None,
+ help="Channel sizes (default: test multiple depths)")
+ p.add_argument("--T", type=int, default=25, help="Timesteps")
+
+ # Training
+ p.add_argument("--epochs", type=int, default=50)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--lr", type=float, default=1e-3)
+
+ # Lyapunov
+ p.add_argument("--lambda_reg", type=float, default=0.3)
+ p.add_argument("--lambda_target", type=float, default=-0.1)
+
+ # Other
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--out_dir", type=str, default="runs/cifar10_conv")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--seed", type=int, default=42)
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 70)
+ print("CIFAR-10 Conv-SNN Experiment")
+ print("=" * 70)
+ print(f"Model: {args.model}")
+ print(f"Timesteps: {args.T}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 70)
+
+ # Load data
+ print("\nLoading CIFAR-10...")
+ train_loader, test_loader = get_cifar10_loaders(
+ data_dir=args.data_dir,
+ batch_size=args.batch_size,
+ )
+ print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+ # Define depth configurations to test
+ if args.channels:
+ channels_configs = [args.channels]
+ else:
+ # Test increasing depths
+ channels_configs = [
+ [64, 128], # 2 conv layers (shallow)
+ [64, 128, 256], # 3 conv layers
+ [64, 128, 256, 512], # 4 conv layers (deep)
+ ]
+
+ # Run comparison
+ results = run_comparison(
+ model_type=args.model,
+ channels_configs=channels_configs,
+ T=args.T,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ epochs=args.epochs,
+ lr=args.lr,
+ lambda_reg=args.lambda_reg,
+ lambda_target=args.lambda_target,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Summary
+ print_summary(results)
+
+ # Save
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+ main()