author    YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
commit    00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree      77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/hyperparameter_grid_search.py
parent    c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/experiments/hyperparameter_grid_search.py')
-rw-r--r--  files/experiments/hyperparameter_grid_search.py  597
1 file changed, 597 insertions, 0 deletions
diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py
new file mode 100644
index 0000000..011387f
--- /dev/null
+++ b/files/experiments/hyperparameter_grid_search.py
@@ -0,0 +1,597 @@
+"""
+Hyperparameter Grid Search for Lyapunov-Regularized SNNs.
+
+Goal: Find optimal (lambda_reg, lambda_target) for each network depth
+ and derive an adaptive curve for automatic hyperparameter selection.
+
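+Training objective (per batch), as implemented in train_and_evaluate below:
+
+    loss = CE(logits, y) + lambda_reg * (lyap_est - lambda_target)**2
+
+where lyap_est is the Lyapunov exponent estimate returned by the model and
+lambda_reg is linearly warmed up over the first warmup_epochs epochs.
+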
+Usage:
+ python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20
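+    python files/experiments/hyperparameter_grid_search.py --epochs 15              # CIFAR-10 (default dataset)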
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Tuple
+from itertools import product
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+from tqdm.auto import tqdm
+
+from files.models.snn_snntorch import LyapunovSNN
+
+
+@dataclass
+class GridSearchResult:
+ """Result from a single grid search configuration."""
+ depth: int
+ lambda_reg: float
+ lambda_target: float
+ final_train_acc: float
+ final_val_acc: float
+ final_lyapunov: float
+ final_grad_norm: float
+ converged: bool # Did training succeed (not NaN)?
+ epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never)
+
+
+def create_synthetic_data(
+ n_train: int = 2000,
+ n_val: int = 500,
+ T: int = 50,
+ D: int = 100,
+ n_classes: int = 10,
+ seed: int = 42,
+ batch_size: int = 128,
+) -> Tuple[DataLoader, DataLoader, int, int, int]:
+ """Create synthetic spike data."""
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ def generate_data(n_samples):
+ x = torch.zeros(n_samples, T, D)
+ y = torch.randint(0, n_classes, (n_samples,))
+ for i in range(n_samples):
+ label = y[i].item()
+ base_rate = 0.05 + 0.02 * label
+ class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes))
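+            # Background spikes everywhere at base_rate, plus boosted firing on this class's channel block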
+ for t in range(T):
+ x[i, t] = (torch.rand(D) < base_rate).float()
+ for c in class_channels:
+ if torch.rand(1) < base_rate * 3:
+ x[i, t, c] = 1.0
+ return x, y
+
+ x_train, y_train = generate_data(n_train)
+ x_val, y_val = generate_data(n_val)
+
+ train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False)
+
+ return train_loader, val_loader, T, D, n_classes
+
+
+def train_and_evaluate(
+ depth: int,
+ lambda_reg: float,
+ lambda_target: float,
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early
+) -> GridSearchResult:
+ """Train a single configuration and return results."""
+ torch.manual_seed(seed)
+
+ # Create model
+ hidden_dims = [hidden_dim] * depth
+ model = LyapunovSNN(
+ input_dim=input_dim,
+ hidden_dims=hidden_dims,
+ num_classes=num_classes,
+ beta=0.9,
+ threshold=1.0,
+ ).to(device)
+
+ optimizer = optim.Adam(model.parameters(), lr=lr)
+ ce_loss = nn.CrossEntropyLoss()
+
+ best_val_acc = 0.0
+ epochs_to_90 = -1
+ final_lyap = 0.0
+ final_grad = 0.0
+ converged = True
+
+ for epoch in range(1, epochs + 1):
+ # Warmup: gradually increase lambda_reg
+ if epoch <= warmup_epochs:
+ current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+ else:
+ current_lambda_reg = lambda_reg
+
+ # Training
+ model.train()
+ total_correct = 0
+ total_samples = 0
+ lyap_vals = []
+ grad_norms = []
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ optimizer.zero_grad()
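+            # Forward pass with Lyapunov estimation enabled; lyap_est may be None if no estimate is available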
+
+ logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False)
+
+ ce = ce_loss(logits, y)
+ if lyap_est is not None:
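+                # Quadratic penalty pulls the estimated exponent toward lambda_target (scaled by the warmed-up coefficient)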
+ reg = (lyap_est - lambda_target) ** 2
+ loss = ce + current_lambda_reg * reg
+ lyap_vals.append(lyap_est.item())
+ else:
+ loss = ce
+
+ if torch.isnan(loss):
+ converged = False
+ break
+
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+ optimizer.step()
+
+ grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+ grad_norms.append(grad_norm)
+
+ preds = logits.argmax(dim=1)
+ total_correct += (preds == y).sum().item()
+ total_samples += x.size(0)
+
+ if not converged:
+ break
+
+ train_acc = total_correct / total_samples
+ final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0
+ final_grad = np.mean(grad_norms) if grad_norms else 0.0
+
+ # Track epochs to 90% accuracy
+ if epochs_to_90 < 0 and train_acc >= 0.9:
+ epochs_to_90 = epoch
+
+ # Validation
+ model.eval()
+ val_correct = 0
+ val_total = 0
+ with torch.no_grad():
+ for x, y in val_loader:
+ x, y = x.to(device), y.to(device)
+ logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
+ preds = logits.argmax(dim=1)
+ val_correct += (preds == y).sum().item()
+ val_total += x.size(0)
+ val_acc = val_correct / val_total
+ best_val_acc = max(best_val_acc, val_acc)
+
+ return GridSearchResult(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ final_train_acc=train_acc if converged else 0.0,
+ final_val_acc=best_val_acc if converged else 0.0,
+ final_lyapunov=final_lyap,
+ final_grad_norm=final_grad,
+ converged=converged,
+ epochs_to_90pct=epochs_to_90,
+ )
+
+
+def run_grid_search(
+ depths: List[int],
+ lambda_regs: List[float],
+ lambda_targets: List[float],
+ train_loader: DataLoader,
+ val_loader: DataLoader,
+ input_dim: int,
+ num_classes: int,
+ hidden_dim: int,
+ epochs: int,
+ lr: float,
+ device: torch.device,
+ seed: int = 42,
+ progress: bool = True,
+) -> List[GridSearchResult]:
+ """Run full grid search."""
+ results = []
+
+    # Enumerate every (depth, lambda_reg, lambda_target) configuration
+    configs = list(product(depths, lambda_regs, lambda_targets))
+
+ iterator = tqdm(configs, desc="Grid Search", disable=not progress)
+
+ for depth, lambda_reg, lambda_target in iterator:
+ if progress:
+ iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target})
+
+ result = train_and_evaluate(
+ depth=depth,
+ lambda_reg=lambda_reg,
+ lambda_target=lambda_target,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=input_dim,
+ num_classes=num_classes,
+ hidden_dim=hidden_dim,
+ epochs=epochs,
+ lr=lr,
+ device=device,
+ seed=seed,
+ )
+ results.append(result)
+
+ if progress:
+ iterator.set_postfix({
+ "d": depth,
+ "λr": lambda_reg,
+ "λt": lambda_target,
+ "acc": f"{result.final_val_acc:.2f}"
+ })
+
+ return results
+
+
+def analyze_results(results: List[GridSearchResult]) -> Dict:
+ """Analyze grid search results and find optimal hyperparameters per depth."""
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ analysis = {
+ "optimal_per_depth": {},
+ "all_results": [asdict(r) for r in results],
+ }
+
+ print("\n" + "=" * 80)
+ print("GRID SEARCH ANALYSIS")
+ print("=" * 80)
+
+ # Find optimal for each depth
+ print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}")
+ print("-" * 80)
+
+ optimal_lambda_regs = []
+ optimal_lambda_targets = []
+ depths_list = []
+
+ for depth in sorted(by_depth.keys()):
+ depth_results = by_depth[depth]
+ # Find best by validation accuracy
+ best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0)
+
+ analysis["optimal_per_depth"][depth] = {
+ "lambda_reg": best.lambda_reg,
+ "lambda_target": best.lambda_target,
+ "val_acc": best.final_val_acc,
+ "lyapunov": best.final_lyapunov,
+ "epochs_to_90": best.epochs_to_90pct,
+ }
+
+ print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} "
+ f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}")
+
+ if best.final_val_acc > 0.5: # Only use successful runs for curve fitting
+ depths_list.append(depth)
+ optimal_lambda_regs.append(best.lambda_reg)
+ optimal_lambda_targets.append(best.lambda_target)
+
+ # Fit adaptive curves
+ print("\n" + "=" * 80)
+ print("ADAPTIVE HYPERPARAMETER CURVES")
+ print("=" * 80)
+
+    if len(depths_list) >= 3:
+        # Fit polynomial curves to the per-depth optima
+        depths_arr = np.array(depths_list)
+        lambda_regs_arr = np.array(optimal_lambda_regs)
+        lambda_targets_arr = np.array(optimal_lambda_targets)
+        deg = min(2, len(depths_arr) - 1)
+        reg_poly = None
+        target_poly = None
+
+        # Fit lambda_reg vs depth (expected to increase with depth)
+        try:
+            reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=deg)
+            reg_poly = np.poly1d(reg_coeffs)
+            if len(reg_coeffs) == 3:
+                print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}")
+            else:
+                print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}")
+            analysis["lambda_reg_curve"] = reg_coeffs.tolist()
+        except Exception as e:
+            print(f"Could not fit λ_reg curve: {e}")
+
+        # Fit lambda_target vs depth (expected to decrease / become more negative with depth)
+        try:
+            target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=deg)
+            target_poly = np.poly1d(target_coeffs)
+            if len(target_coeffs) == 3:
+                print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}")
+            else:
+                print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}")
+            analysis["lambda_target_curve"] = target_coeffs.tolist()
+        except Exception as e:
+            print(f"Could not fit λ_target curve: {e}")
+
+        # Print recommendations only if both fits succeeded (avoids a NameError when a fit fails)
+        if reg_poly is not None and target_poly is not None:
+            print("\n" + "-" * 80)
+            print("RECOMMENDED HYPERPARAMETERS BY DEPTH:")
+            print("-" * 80)
+            for d in [2, 4, 6, 8, 10, 12, 14, 16]:
+                rec_reg = max(0.01, reg_poly(d))       # keep regularization strength positive
+                rec_target = min(0.0, target_poly(d))  # keep target exponent non-positive
+                print(f"  Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}")
+
+ else:
+ print("Not enough successful runs to fit curves")
+
+ return analysis
+
+
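+# Hypothetical convenience helper (a sketch, not part of the original search loop): shows how the
+# "lambda_reg_curve" / "lambda_target_curve" coefficients saved by analyze_results() could be
+# evaluated to pick hyperparameters for a new depth, mirroring the clamping in the recommendation
+# printout above. Example:
+#     rec_reg, rec_target = predict_hyperparameters(12, analysis["lambda_reg_curve"], analysis["lambda_target_curve"])
+def predict_hyperparameters(depth: int, reg_coeffs: List[float], target_coeffs: List[float]) -> Tuple[float, float]:
+    rec_reg = max(0.01, float(np.poly1d(reg_coeffs)(depth)))       # keep regularization strength positive
+    rec_target = min(0.0, float(np.poly1d(target_coeffs)(depth)))  # keep target exponent non-positive
+    return rec_reg, rec_target
+
+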
+def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict):
+ """Save grid search results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f:
+ json.dump(analysis, f, indent=2)
+
+ with open(os.path.join(output_dir, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ print(f"\nResults saved to {output_dir}")
+
+
+def plot_grid_search(results: List[GridSearchResult], output_dir: str):
+ """Generate visualization of grid search results."""
+ try:
+ import matplotlib.pyplot as plt
+ except ImportError:
+ print("matplotlib not available, skipping plots")
+ return
+
+ # Group by depth
+ by_depth = {}
+ for r in results:
+ if r.depth not in by_depth:
+ by_depth[r.depth] = []
+ by_depth[r.depth].append(r)
+
+ depths = sorted(by_depth.keys())
+
+ # Get unique lambda values
+ lambda_regs = sorted(set(r.lambda_reg for r in results))
+ lambda_targets = sorted(set(r.lambda_target for r in results))
+
+ # Create heatmaps for each depth
+ n_depths = len(depths)
+ fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10))
+ axes = axes.flatten()
+
+ for idx, depth in enumerate(depths):
+ ax = axes[idx]
+ depth_results = by_depth[depth]
+
+ # Create accuracy matrix
+ acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs)))
+ for r in depth_results:
+ i = lambda_targets.index(r.lambda_target)
+ j = lambda_regs.index(r.lambda_reg)
+ acc_matrix[i, j] = r.final_val_acc
+
+ im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
+ ax.set_xticks(range(len(lambda_regs)))
+ ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45)
+ ax.set_yticks(range(len(lambda_targets)))
+ ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets])
+ ax.set_xlabel("λ_reg")
+ ax.set_ylabel("λ_target")
+ ax.set_title(f"Depth {depth}")
+
+ # Mark best
+ best = max(depth_results, key=lambda r: r.final_val_acc)
+ bi = lambda_targets.index(best.lambda_target)
+ bj = lambda_regs.index(best.lambda_reg)
+ ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2)
+
+ # Add colorbar
+ plt.colorbar(im, ax=ax, label='Val Acc')
+
+ # Hide unused subplots
+ for idx in range(len(depths), len(axes)):
+ axes[idx].axis('off')
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ # Plot optimal hyperparameters vs depth
+ fig, axes = plt.subplots(1, 3, figsize=(15, 4))
+
+ optimal_regs = []
+ optimal_targets = []
+ optimal_accs = []
+ for depth in depths:
+ best = max(by_depth[depth], key=lambda r: r.final_val_acc)
+ optimal_regs.append(best.lambda_reg)
+ optimal_targets.append(best.lambda_target)
+ optimal_accs.append(best.final_val_acc)
+
+ axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8)
+ axes[0].set_xlabel("Network Depth")
+ axes[0].set_ylabel("Optimal λ_reg")
+ axes[0].set_title("Optimal Regularization Strength vs Depth")
+ axes[0].grid(True, alpha=0.3)
+
+ axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange')
+ axes[1].set_xlabel("Network Depth")
+ axes[1].set_ylabel("Optimal λ_target")
+ axes[1].set_title("Optimal Target Lyapunov vs Depth")
+ axes[1].grid(True, alpha=0.3)
+
+ axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green')
+ axes[2].set_xlabel("Network Depth")
+ axes[2].set_ylabel("Best Validation Accuracy")
+ axes[2].set_title("Best Achievable Accuracy vs Depth")
+ axes[2].set_ylim(0, 1.05)
+ axes[2].grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight')
+ plt.close()
+
+ print(f"Plots saved to {output_dir}")
+
+
+def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'):
+ """Get CIFAR-10 dataloaders with rate encoding for SNN."""
+ from torchvision import datasets, transforms
+
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
+ val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)
+
+ # Rate encoding: convert images to spike sequences
+ class RateEncodedDataset(torch.utils.data.Dataset):
+ def __init__(self, dataset, T):
+ self.dataset = dataset
+ self.T = T
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ img, label = self.dataset[idx]
+ # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D)
+ flat = img.view(-1) # (3072,)
+ # Rate encoding: probability of spike = pixel intensity
+ spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float()
+ return spikes, label
+
+ train_encoded = RateEncodedDataset(train_ds, T)
+ val_encoded = RateEncodedDataset(val_ds, T)
+
+ train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, T, 3072, 10 # T, D, num_classes
+
+
+def parse_args():
+ p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN")
+
+ # Grid search parameters
+ p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10],
+ help="Network depths to test")
+ p.add_argument("--lambda_regs", type=float, nargs="+",
+ default=[0.01, 0.05, 0.1, 0.2, 0.3],
+ help="Lambda_reg values to test")
+ p.add_argument("--lambda_targets", type=float, nargs="+",
+ default=[0.0, -0.05, -0.1, -0.2],
+ help="Lambda_target values to test")
+
+ # Model parameters
+ p.add_argument("--hidden_dim", type=int, default=256)
+ p.add_argument("--epochs", type=int, default=15)
+ p.add_argument("--lr", type=float, default=1e-3)
+ p.add_argument("--batch_size", type=int, default=128)
+ p.add_argument("--seed", type=int, default=42)
+
+ # Data
+ p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)")
+ p.add_argument("--data_dir", type=str, default="./data")
+ p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding")
+
+ # Output
+ p.add_argument("--out_dir", type=str, default="runs/grid_search")
+ p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+ p.add_argument("--no-progress", action="store_true")
+
+ return p.parse_args()
+
+
+def main():
+ args = parse_args()
+ device = torch.device(args.device)
+
+ print("=" * 80)
+ print("HYPERPARAMETER GRID SEARCH")
+ print("=" * 80)
+ print(f"Depths: {args.depths}")
+ print(f"λ_reg values: {args.lambda_regs}")
+ print(f"λ_target values: {args.lambda_targets}")
+ print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}")
+ print(f"Epochs per config: {args.epochs}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Load data
+ if args.synthetic:
+ print("\nUsing synthetic data")
+ train_loader, val_loader, T, D, C = create_synthetic_data(
+ seed=args.seed, batch_size=args.batch_size
+ )
+ else:
+ print("\nUsing CIFAR-10 with rate encoding")
+ train_loader, val_loader, T, D, C = get_cifar10_loaders(
+ batch_size=args.batch_size,
+ T=args.T,
+ data_dir=args.data_dir
+ )
+
+ print(f"Data: T={T}, D={D}, classes={C}\n")
+
+ # Run grid search
+ results = run_grid_search(
+ depths=args.depths,
+ lambda_regs=args.lambda_regs,
+ lambda_targets=args.lambda_targets,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ input_dim=D,
+ num_classes=C,
+ hidden_dim=args.hidden_dim,
+ epochs=args.epochs,
+ lr=args.lr,
+ device=device,
+ seed=args.seed,
+ progress=not args.no_progress,
+ )
+
+ # Analyze results
+ analysis = analyze_results(results)
+
+ # Save results
+ ts = time.strftime("%Y%m%d-%H%M%S")
+ output_dir = os.path.join(args.out_dir, ts)
+ save_results(results, analysis, output_dir, vars(args))
+
+ # Generate plots
+ plot_grid_search(results, output_dir)
+
+
+if __name__ == "__main__":
+ main()