path: root/files/experiments/scaled_reg_grid_search.py
author    YurenHao0426 <blackhao0426@gmail.com>    2026-01-13 23:50:59 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2026-01-13 23:50:59 -0600
commit    00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree      77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/scaled_reg_grid_search.py
parent    c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/experiments/scaled_reg_grid_search.py')
-rw-r--r--    files/experiments/scaled_reg_grid_search.py    301
1 file changed, 301 insertions, 0 deletions
diff --git a/files/experiments/scaled_reg_grid_search.py b/files/experiments/scaled_reg_grid_search.py
new file mode 100644
index 0000000..928caff
--- /dev/null
+++ b/files/experiments/scaled_reg_grid_search.py
@@ -0,0 +1,301 @@
+"""
+Grid Search: Multiplier-Scaled Regularization Experiments
+
+Tests the new multiplier-scaled regularization approach:
+ loss = (λ_reg × g(relu(λ))) × relu(λ)
+
+Where g(x) is the multiplier scaling function:
+ - mult_linear: g(x) = x → loss = λ_reg × relu(λ)²
+ - mult_squared: g(x) = x² → loss = λ_reg × relu(λ)³
+ - mult_log: g(x) = log(1+x) → loss = λ_reg × log(1+relu(λ)) × relu(λ)
+
+Grid:
+ - λ_reg: 0.0005, 0.001, 0.002, 0.005
+ - reg_type: mult_linear, mult_log (mult_squared is skipped; it is too aggressive and suppresses learning)
+ - depth: specified via command line
+
+Usage:
+ python scaled_reg_grid_search.py --depth 4
+ python scaled_reg_grid_search.py --depth 8
+ python scaled_reg_grid_search.py --depth 12
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+# Import from main benchmark
+from depth_scaling_benchmark import (
+ SpikingVGG,
+ compute_lyap_reg_loss,
+)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
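+# Reference sketch (illustrative assumption, not used below): compute_lyap_reg_loss,
+# imported above, is expected to return g(relu(λ)) × relu(λ) for the mult_* reg types,
+# with λ_reg applied by the caller (see train_epoch) and lambda_target implicitly 0.
+# The helper name below is hypothetical and only documents that assumed behaviour.
+def _reference_mult_reg(lyap_est: torch.Tensor, reg_type: str) -> torch.Tensor:
+    """Sketch of the multiplier-scaled penalty g(relu(λ)) * relu(λ)."""
+    x = torch.relu(lyap_est)                  # relu(λ)
+    if reg_type == "mult_linear":             # g(x) = x        -> relu(λ)²
+        return x * x
+    if reg_type == "mult_squared":            # g(x) = x²       -> relu(λ)³
+        return x.pow(2) * x
+    if reg_type == "mult_log":                # g(x) = log(1+x) -> log(1+relu(λ)) · relu(λ)
+        return torch.log1p(x) * x
+    raise ValueError(f"Unknown reg_type: {reg_type}")
+
+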
+@dataclass
+class ExperimentResult:
+ depth: int
+ reg_type: str
+ lambda_reg: float
+ vanilla_acc: float
+ lyapunov_acc: float
+ final_lyap: Optional[float]
+ delta: float
+
+
+def train_epoch(model, loader, optimizer, criterion, device,
+ use_lyapunov, lambda_reg, reg_type, progress=False):
+ """Train one epoch."""
+ model.train()
+ total_loss = 0.0
+ correct = 0
+ total = 0
+ lyap_vals = []
+
+ iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+ for x, y in iterator:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=1e-4)
+ loss = criterion(logits, y)
+
+ if use_lyapunov and lyap_est is not None:
+ # Target is implicitly 0 for scaled reg types
+ lyap_reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target=0.0)
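+ # combined objective: CE + λ_reg × g(relu(λ)) × relu(λ) (see module docstring)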
+ loss = loss + lambda_reg * lyap_reg
+ lyap_vals.append(lyap_est.item())
+
+ loss.backward()
+ optimizer.step()
+
+ total_loss += loss.item() * x.size(0)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ avg_lyap = sum(lyap_vals) / len(lyap_vals) if lyap_vals else None
+ return total_loss / total, correct / total, avg_lyap
+
+
+def evaluate(model, loader, device):
+ """Evaluate model."""
+ model.eval()
+ correct = 0
+ total = 0
+
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False)
+ _, pred = logits.max(1)
+ correct += pred.eq(y).sum().item()
+ total += x.size(0)
+
+ return correct / total
+
+
+def run_single_experiment(depth, reg_type, lambda_reg, train_loader, test_loader,
+ device, epochs=100, lr=0.001):
+ """Run a single experiment configuration."""
+
+ # Determine blocks per stage based on depth
+ # depth = num_stages * blocks_per_stage, with num_stages=4
+ blocks_per_stage = depth // 4
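+ # e.g. depth=4 -> 1 block per stage, depth=8 -> 2, depth=12 -> 3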
+
+ print(f"\n{'='*60}")
+ print(f"Config: depth={depth}, reg_type={reg_type}, λ_reg={lambda_reg}")
+ print(f"{'='*60}")
+
+ # --- Run Vanilla baseline ---
+ print(f" Training Vanilla...")
+ model_v = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_v = optim.Adam(model_v.parameters(), lr=lr)
+ criterion = nn.CrossEntropyLoss()
+ scheduler_v = optim.lr_scheduler.CosineAnnealingLR(optimizer_v, T_max=epochs)
+
+ best_vanilla = 0.0
+ for epoch in range(epochs):
+ train_epoch(model_v, train_loader, optimizer_v, criterion, device,
+ use_lyapunov=False, lambda_reg=0, reg_type="squared")  # reg_type unused when use_lyapunov=False
+ scheduler_v.step()
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_v, test_loader, device)
+ best_vanilla = max(best_vanilla, acc)
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f}")
+
+ del model_v, optimizer_v, scheduler_v
+ torch.cuda.empty_cache()
+
+ # --- Run Lyapunov version ---
+ print(f" Training Lyapunov ({reg_type}, λ_reg={lambda_reg})...")
+ model_l = SpikingVGG(
+ num_classes=100,
+ blocks_per_stage=blocks_per_stage,
+ T=4,
+ ).to(device)
+
+ optimizer_l = optim.Adam(model_l.parameters(), lr=lr)
+ scheduler_l = optim.lr_scheduler.CosineAnnealingLR(optimizer_l, T_max=epochs)
+
+ best_lyap_acc = 0.0
+ final_lyap = None
+
+ for epoch in range(epochs):
+ _, _, lyap = train_epoch(model_l, train_loader, optimizer_l, criterion, device,
+ use_lyapunov=True, lambda_reg=lambda_reg, reg_type=reg_type)
+ scheduler_l.step()
+ final_lyap = lyap
+
+ if (epoch + 1) % 10 == 0 or epoch == epochs - 1:
+ acc = evaluate(model_l, test_loader, device)
+ best_lyap_acc = max(best_lyap_acc, acc)
+ lyap_str = f"λ={lyap:.3f}" if lyap is not None else "λ=N/A"
+ print(f" Epoch {epoch+1:3d}: test={acc:.3f} {lyap_str}")
+
+ del model_l, optimizer_l, scheduler_l
+ torch.cuda.empty_cache()
+
+ delta = best_lyap_acc - best_vanilla
+
+ result = ExperimentResult(
+ depth=depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ vanilla_acc=best_vanilla,
+ lyapunov_acc=best_lyap_acc,
+ final_lyap=final_lyap,
+ delta=delta,
+ )
+
+ print(f" Result: Vanilla={best_vanilla:.3f}, Lyap={best_lyap_acc:.3f}, Δ={delta:+.3f}")
+
+ return result
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--depth", type=int, required=True, choices=[4, 8, 12])
+ parser.add_argument("--epochs", type=int, default=100)
+ parser.add_argument("--batch_size", type=int, default=128)
+ parser.add_argument("--lr", type=float, default=0.001)
+ parser.add_argument("--data_dir", type=str, default="./data")
+ parser.add_argument("--out_dir", type=str, default="./runs/scaled_grid")
+ args = parser.parse_args()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ print("=" * 70)
+ print("SCALED REGULARIZATION GRID SEARCH")
+ print("=" * 70)
+ print(f"Depth: {args.depth}")
+ print(f"Epochs: {args.epochs}")
+ print(f"Device: {device}")
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+ print("=" * 70)
+
+ # Grid parameters
+ lambda_regs = [0.0005, 0.001, 0.002, 0.005] # smaller values for deeper networks
+ reg_types = ["mult_linear", "mult_log"] # mult_squared too aggressive, kills learning
+
+ print(f"\nGrid: {len(lambda_regs)} λ_reg × {len(reg_types)} reg_types = {len(lambda_regs) * len(reg_types)} experiments")
+ print(f"λ_reg values: {lambda_regs}")
+ print(f"reg_types: {reg_types}")
+
+ # Load data
+ print(f"\nLoading CIFAR-100...")
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+
+ train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=True, transform=transform_train)
+ test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=True, transform=transform_test)
+
+ train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
+ num_workers=4, pin_memory=True)
+ test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
+ num_workers=4, pin_memory=True)
+
+ print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
+
+ # Run grid search
+ results = []
+
+ for lambda_reg, reg_type in product(lambda_regs, reg_types):
+ result = run_single_experiment(
+ depth=args.depth,
+ reg_type=reg_type,
+ lambda_reg=lambda_reg,
+ train_loader=train_loader,
+ test_loader=test_loader,
+ device=device,
+ epochs=args.epochs,
+ lr=args.lr,
+ )
+ results.append(result)
+
+ # Print summary table
+ print("\n" + "=" * 70)
+ print(f"SUMMARY: DEPTH = {args.depth}")
+ print("=" * 70)
+ print(f"{'reg_type':<16} {'λ_reg':>8} {'Vanilla':>8} {'Lyapunov':>8} {'Δ':>8} {'Final λ':>8}")
+ print("-" * 70)
+
+ for r in results:
+ lyap_str = f"{r.final_lyap:.3f}" if r.final_lyap is not None else "N/A"
+ delta_str = f"{r.delta:+.3f}"
+ print(f"{r.reg_type:<16} {r.lambda_reg:>8.2f} {r.vanilla_acc:>8.3f} {r.lyapunov_acc:>8.3f} {delta_str:>8} {lyap_str:>8}")
+
+ # Find best configuration
+ best = max(results, key=lambda x: x.lyapunov_acc)
+ print("-" * 70)
+ print(f"BEST: {best.reg_type}, λ_reg={best.lambda_reg} → {best.lyapunov_acc:.3f} (Δ={best.delta:+.3f})")
+
+ # Save results
+ os.makedirs(args.out_dir, exist_ok=True)
+ out_file = os.path.join(args.out_dir, f"depth{args.depth}_results.json")
+ with open(out_file, "w") as f:
+ json.dump([asdict(r) for r in results], f, indent=2)
+ print(f"\nResults saved to: {out_file}")
+
+ print("\n" + "=" * 70)
+ print("GRID SEARCH COMPLETE")
+ print("=" * 70)
+
+
+if __name__ == "__main__":
+ main()