Diffstat (limited to 'files/experiments/scaled_reg_grid_search.py')
| -rw-r--r-- | files/experiments/scaled_reg_grid_search.py | 301 |
1 file changed, 301 insertions, 0 deletions
diff --git a/files/experiments/scaled_reg_grid_search.py b/files/experiments/scaled_reg_grid_search.py
new file mode 100644
index 0000000..928caff
--- /dev/null
+++ b/files/experiments/scaled_reg_grid_search.py
@@ -0,0 +1,301 @@
"""
Grid Search: Multiplier-Scaled Regularization Experiments

Tests the new multiplier-scaled regularization approach:
    loss = (λ_reg × g(relu(λ))) × relu(λ)

Where g(x) is the multiplier scaling function:
    - mult_linear: g(x) = x → loss = λ_reg × relu(λ)²
    - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³
    - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ)

Grid:
    - λ_reg: 0.0005, 0.001, 0.002, 0.005
    - reg_type: mult_linear, mult_log (mult_squared excluded: too aggressive, kills learning)
    - depth: specified via command line

Usage:
    python scaled_reg_grid_search.py --depth 4
    python scaled_reg_grid_search.py --depth 8
    python scaled_reg_grid_search.py --depth 12
"""

import os
import sys
import json
import time
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional
from itertools import product

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Import from main benchmark
from depth_scaling_benchmark import (
    SpikingVGG,
    compute_lyap_reg_loss,
)

import snntorch as snn
from snntorch import surrogate


@dataclass
class ExperimentResult:
    depth: int
    reg_type: str
    lambda_reg: float
    vanilla_acc: float
    lyapunov_acc: float
    final_lyap: Optional[float]
    delta: float


def train_epoch(model, loader, optimizer, criterion, device,
                use_lyapunov, lambda_reg, reg_type, progress=False):
    """Train one epoch."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    lyap_vals = []

    iterator = tqdm(loader, desc="train", leave=False) if progress else loader

    for x, y in iterator:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4)
        loss = criterion(logits, y)

        if use_lyapunov and lyap_est is not None:
            # Target is implicitly 0 for scaled reg types
            lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0)
            loss = loss + lambda_reg * lyap_reg
            lyap_vals.append(lyap_est.item())

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        _, pred = logits.max(1)
        correct += pred.eq(y).sum().item()
        total += x.size(0)

    avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
    return total_loss / total, correct / total, avg_lyap


def evaluate(model, loader, device):
    """Evaluate model."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits, _, _ = model(x, compute_lyapunov=False)
            _, pred = logits.max(1)
            correct += pred.eq(y).sum().item()
            total += x.size(0)

    return correct / total


def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader,
                          device, epochs=100, lr=0.001):
    """Run a single experiment configuration."""

    # Determine blocks per stage based on depth
    # depth = num_stages * blocks_per_stage, with num_stages=4
    blocks_per_stage = depth // 4

    print(f"\n{'='*60}")
    print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}")
print(f"{'='*60}") + + # --- Run Vanilla baseline --- + print(f" Training Vanilla...") + model_v = SpikingVGG( + num_classes=100, + blocks_per_stage=blocks_per_stage, + T=4, + ).to(device) + + optimizer_v = optim.Adam(model_v.parameters(), lr=lr) + criterion = nn.CrossEntropyLoss() + scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs) + + best_vanilla = 0.0 + for epoch in range(epochs): + train_epoch(model_v, train_loader, optimizer_v, criterion, device, + use_lyapunov=False, lambda_reg=0, reg_type="squared") + scheduler_v.step() + + if (epoch + 1) % 10 == 0 or epoch == epochs - 1: + acc = evaluate(model_v, test_loader, device) + best_vanilla = max(best_vanilla, acc) + print(f" Epoch {epoch+1:3d}: test={acc:.3f}") + + del model_v, optimizer_v, scheduler_v + torch.cuda.empty_cache() + + # --- Run Lyapunov version --- + print(f" Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...") + model_l = SpikingVGG( + num_classes=100, + blocks_per_stage=blocks_per_stage, + T=4, + ).to(device) + + optimizer_l = optim.Adam(model_l.parameters(), lr=lr) + scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs) + + best_lyap_acc = 0.0 + final_lyap = None + + for epoch in range(epochs): + _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device, + use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type) + scheduler_l.step() + final_lyap = lyap + + if (epoch + 1) % 10 == 0 or epoch == epochs - 1: + acc = evaluate(model_l, test_loader, device) + best_lyap_acc = max(best_lyap_acc, acc) + lyap_str = f"λ={lyap:.3f}" if lyap else "λ=N/A" + print(f" Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}") + + del model_l, optimizer_l, scheduler_l + torch.cuda.empty_cache() + + delta = best_lyap_acc - best_vanilla + + result = ExperimentResult( + depth=depth, + reg_type=reg_type, + lambda_reg=lambda_reg, + vanilla_acc=best_vanilla, + lyapunov_acc=best_lyap_acc, + final_lyap=final_lyap, + delta=delta, + ) + + print(f" Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}") + + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12]) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--batch_size", type=int, default=128) + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--data_dir", type=str, default="./data") + parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + print("=" * 70) + print("SCALED REGULARIZATION GRID SEARCH") + print("=" * 70) + print(f"Depth: {args.depth}") + print(f"Epochs: {args.epochs}") + print(f"Device: {device}") + if device.type == "cuda": + print(f"GPU: {torch.cuda.get_device_name()}") + print("=" * 70) + + # Grid parameters + lambda_regs = [0.0005, 0.001, 0.002, 0.005] # smaller values for deeper networks + reg_types = ["mult_linear", "mult_log"] # mult_squared too aggressive, kills learning + + print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments") + print(f"λ_reg values: {lambda_regs}") + print(f"reg_types: {reg_types}") + + # Load data + print(f"\nLoading CIFAR-100...") + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 
0.2565, 0.2761)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)

    print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")

    # Run grid search
    results = []

    for lambda_reg, reg_type in product(lambda_regs, reg_types):
        result = run_single_experiment(
            depth=args.depth,
            reg_type=reg_type,
            lambda_reg=lambda_reg,
            train_loader=train_loader,
            test_loader=test_loader,
            device=device,
            epochs=args.epochs,
            lr=args.lr,
        )
        results.append(result)

    # Print summary table
    print("\n" + "=" * 70)
    print(f"SUMMARY: DEPTH = {args.depth}")
    print("=" * 70)
    print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
    print("-" * 70)

    for r in results:
        # Use an explicit None check (0.0 is a valid estimate) and enough
        # decimal places to distinguish λ_reg values down to 0.0005.
        lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap is not None else "N/A"
        delta_str = f"{r.delta:+.3f}"
        print(f"{r.reg_type:<16} {r.lambda_reg:>8.4f} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")

    # Find best configuration
    best = max(results, key=lambda x: x.lyapunov_acc)
    print("-" * 70)
    print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
    with open(out_file, "w") as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    print(f"\nResults saved to: {out_file}")

    print("\n" + "=" * 70)
    print("GRID SEARCH COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    main()
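
Note: compute_lyap_reg_loss is imported from depth_scaling_benchmark and is not defined in this file. Below is a minimal sketch of what the mult_* branches could look like, assuming the function returns g(relu(λ - target)) × relu(λ - target) exactly as the module docstring describes, with the λ_reg factor applied by the caller in train_epoch. The function name and signature follow the import above; the body is an assumption and the actual implementation in depth_scaling_benchmark.py may differ (e.g. it may also handle the plain "squared" reg_type passed by the vanilla baseline).

import torch
import torch.nn.functional as F

def compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0):
    """Sketch (assumed): multiplier-scaled penalty g(x) * x with x = relu(λ - target)."""
    x = F.relu(lyap_est - lambda_target)   # penalize only estimates above the target
    if reg_type == "mult_linear":          # g(x) = x        -> x**2
        return x * x
    if reg_type == "mult_squared":         # g(x) = x**2     -> x**3
        return x ** 3
    if reg_type == "mult_log":             # g(x) = log(1+x) -> log(1+x) * x
        return torch.log1p(x) * x
    raise ValueError(f"unknown reg_type: {reg_type}")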
