""" Grid Search: Multiplier-Scaled Regularization Experiments Tests the new multiplier-scaled regularization approach: loss = (λ_reg × g(relu(λ))) × relu(λ) Where g(x) is the multiplier scaling function: - mult_linear: g(x) = x → loss = λ_reg × relu(λ)² - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³ - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ) Grid: - λ_reg: 0.01, 0.05, 0.1, 0.3 - reg_type: mult_linear, mult_squared, mult_log - depth: specified via command line Usage: python scaled_reg_grid_search.py --depth 4 python scaled_reg_grid_search.py --depth 8 python scaled_reg_grid_search.py --depth 12 """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional from itertools import product _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from tqdm.auto import tqdm # Import from main benchmark from depth_scaling_benchmark import ( SpikingVGG, compute_lyap_reg_loss, ) import snntorch as snn from snntorch import surrogate @dataclass class ExperimentResult: depth: int reg_type: str lambda_reg: float vanilla_acc: float lyapunov_acc: float final_lyap: Optional[float] delta: float def train_epoch(model, loader, optimizer, criterion, device, use_lyapunov, lambda_reg, reg_type, progress=False): """Train one epoch.""" model.train() total_loss = 0.0 correct = 0 total = 0 lyap_vals = [] iterator = tqdm(loader, desc="train", leave=False) if progress else loader for x, y in iterator: x, y = x.to(device), y.to(device) optimizer.zero_grad() logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4) loss = criterion(logits, y) if use_lyapunov and lyap_est is not None: # Target is implicitly 0 for scaled reg types lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0) loss = loss + lambda_reg * lyap_reg lyap_vals.append(lyap_est.item()) loss.backward() optimizer.step() total_loss += loss.item() * x.size(0) _, pred = logits.max(1) correct += pred.eq(y).sum().item() total += x.size(0) avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None return total_loss / total, correct / total, avg_lyap def evaluate(model, loader, device): """Evaluate model.""" model.eval() correct = 0 total = 0 with torch.no_grad(): for x, y in loader: x, y = x.to(device), y.to(device) logits, _, _ = model(x, compute_lyapunov=False) _, pred = logits.max(1) correct += pred.eq(y).sum().item() total += x.size(0) return correct / total def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader, device, epochs=100, lr=0.001): """Run a single experiment configuration.""" # Determine blocks per stage based on depth # depth = num_stages * blocks_per_stage, with num_stages=4 blocks_per_stage = depth // 4 print(f"\n{'='*60}") print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}") print(f"{'='*60}") # --- Run Vanilla baseline --- print(f" Training Vanilla...") model_v = SpikingVGG( num_classes=100, blocks_per_stage=blocks_per_stage, T=4, ).to(device) optimizer_v = optim.Adam(model_v.parameters(), lr=lr) criterion = nn.CrossEntropyLoss() scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs) best_vanilla = 0.0 for epoch in range(epochs): train_epoch(model_v, train_loader, optimizer_v, criterion, device, 


def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader,
                          device, epochs=100, lr=0.001):
    """Run a single configuration: vanilla baseline + Lyapunov-regularized run."""
    # Determine blocks per stage based on depth:
    # depth = num_stages * blocks_per_stage, with num_stages = 4
    blocks_per_stage = depth // 4

    print(f"\n{'='*60}")
    print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}")
    print(f"{'='*60}")

    # --- Run Vanilla baseline ---
    print("  Training Vanilla...")
    model_v = SpikingVGG(
        num_classes=100,
        blocks_per_stage=blocks_per_stage,
        T=4,
    ).to(device)
    optimizer_v = optim.Adam(model_v.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs)

    best_vanilla = 0.0
    for epoch in range(epochs):
        train_epoch(model_v, train_loader, optimizer_v, criterion, device,
                    use_lyapunov=False, lambda_reg=0, reg_type="squared")
        scheduler_v.step()
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            acc = evaluate(model_v, test_loader, device)
            best_vanilla = max(best_vanilla, acc)
            print(f"    Epoch {epoch+1:3d}: test={acc:.3f}")

    del model_v, optimizer_v, scheduler_v
    torch.cuda.empty_cache()

    # --- Run Lyapunov version ---
    print(f"  Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...")
    model_l = SpikingVGG(
        num_classes=100,
        blocks_per_stage=blocks_per_stage,
        T=4,
    ).to(device)
    optimizer_l = optim.Adam(model_l.parameters(), lr=lr)
    scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs)

    best_lyap_acc = 0.0
    final_lyap = None
    for epoch in range(epochs):
        _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device,
                                 use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type)
        scheduler_l.step()
        final_lyap = lyap
        if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            acc = evaluate(model_l, test_loader, device)
            best_lyap_acc = max(best_lyap_acc, acc)
            lyap_str = f"λ={lyap:.3f}" if lyap is not None else "λ=N/A"
            print(f"    Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}")

    del model_l, optimizer_l, scheduler_l
    torch.cuda.empty_cache()

    delta = best_lyap_acc - best_vanilla
    result = ExperimentResult(
        depth=depth,
        reg_type=reg_type,
        lambda_reg=lambda_reg,
        vanilla_acc=best_vanilla,
        lyapunov_acc=best_lyap_acc,
        final_lyap=final_lyap,
        delta=delta,
    )
    print(f"  Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}")
    return result


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12])
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("=" * 70)
    print("SCALED REGULARIZATION GRID SEARCH")
    print("=" * 70)
    print(f"Depth: {args.depth}")
    print(f"Epochs: {args.epochs}")
    print(f"Device: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name()}")
    print("=" * 70)

    # Grid parameters
    lambda_regs = [0.0005, 0.001, 0.002, 0.005]  # smaller values for deeper networks
    reg_types = ["mult_linear", "mult_log"]  # mult_squared is too aggressive, kills learning

    print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = "
          f"{len(lambda_regs) * len(reg_types)} experiments")
    print(f"λ_reg values: {lambda_regs}")
    print(f"reg_types: {reg_types}")

    # Load data
    print("\nLoading CIFAR-100...")
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True,
                                      transform=transform_train)
    test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True,
                                     transform=transform_test)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

    # Run the grid search
    results = []
    for lambda_reg, reg_type in product(lambda_regs, reg_types):
        result = run_single_experiment(
            depth=args.depth,
            reg_type=reg_type,
            lambda_reg=lambda_reg,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            epochs=args.epochs,
            lr=args.lr,
        )
        results.append(result)

    # Print summary table
    print("\n" + "=" * 70)
    print(f"SUMMARY: DEPTH = {args.depth}")
    print("=" * 70)
    print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
    print("-" * 70)
    for r in results:
        lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap is not None else "N/A"
        delta_str = f"{r.delta:+.3f}"
        print(f"{r.reg_type:<16} {r.lambda_reg:>8.4f} {r.vanilla_acc:>8.3f} "
              f"{r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")

    # Find best configuration
    best = max(results, key=lambda x: x.lyapunov_acc)
    print("-" * 70)
    print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
    with open(out_file, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to: {out_file}")

    print("\n" + "=" * 70)
    print("GRID SEARCH COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    main()
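

# Example (illustrative, not executed by this script): downstream analysis can load the
# saved JSON and re-derive the best configuration. The path below assumes the default
# --out_dir and a --depth 8 run; field names follow the ExperimentResult dataclass above.
#
#     import json
#     with open("./runs/scaled_grid/depth8_results.json") as f:
#         results = json.load(f)
#     best = max(results, key=lambda r: r["lyapunov_acc"])
#     print(best["reg_type"], best["lambda_reg"], best["delta"])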