""" Hyperparameter Grid Search for Lyapunov-Regularized SNNs. Goal: Find optimal (lambda_reg, lambda_target) for each network depth and derive an adaptive curve for automatic hyperparameter selection. Usage: python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20 """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Tuple from itertools import product _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from tqdm.auto import tqdm from files.models.snn_snntorch import LyapunovSNN @dataclass class GridSearchResult: """Result from a single grid search configuration.""" depth: int lambda_reg: float lambda_target: float final_train_acc: float final_val_acc: float final_lyapunov: float final_grad_norm: float converged: bool # Did training succeed (not NaN)? epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never) def create_synthetic_data( n_train: int = 2000, n_val: int = 500, T: int = 50, D: int = 100, n_classes: int = 10, seed: int = 42, batch_size: int = 128, ) -> Tuple[DataLoader, DataLoader, int, int, int]: """Create synthetic spike data.""" torch.manual_seed(seed) np.random.seed(seed) def generate_data(n_samples): x = torch.zeros(n_samples, T, D) y = torch.randint(0, n_classes, (n_samples,)) for i in range(n_samples): label = y[i].item() base_rate = 0.05 + 0.02 * label class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) for t in range(T): x[i, t] = (torch.rand(D) < base_rate).float() for c in class_channels: if torch.rand(1) < base_rate * 3: x[i, t, c] = 1.0 return x, y x_train, y_train = generate_data(n_train) x_val, y_val = generate_data(n_val) train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False) return train_loader, val_loader, T, D, n_classes def train_and_evaluate( depth: int, lambda_reg: float, lambda_target: float, train_loader: DataLoader, val_loader: DataLoader, input_dim: int, num_classes: int, hidden_dim: int, epochs: int, lr: float, device: torch.device, seed: int = 42, warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early ) -> GridSearchResult: """Train a single configuration and return results.""" torch.manual_seed(seed) # Create model hidden_dims = [hidden_dim] * depth model = LyapunovSNN( input_dim=input_dim, hidden_dims=hidden_dims, num_classes=num_classes, beta=0.9, threshold=1.0, ).to(device) optimizer = optim.Adam(model.parameters(), lr=lr) ce_loss = nn.CrossEntropyLoss() best_val_acc = 0.0 epochs_to_90 = -1 final_lyap = 0.0 final_grad = 0.0 converged = True for epoch in range(1, epochs + 1): # Warmup: gradually increase lambda_reg if epoch <= warmup_epochs: current_lambda_reg = lambda_reg * (epoch / warmup_epochs) else: current_lambda_reg = lambda_reg # Training model.train() total_correct = 0 total_samples = 0 lyap_vals = [] grad_norms = [] for x, y in train_loader: x, y = x.to(device), y.to(device) optimizer.zero_grad() logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False) ce = ce_loss(logits, y) if lyap_est is not None: reg = (lyap_est - lambda_target) ** 2 loss = ce + current_lambda_reg * reg 
            ce = ce_loss(logits, y)
            if lyap_est is not None:
                reg = (lyap_est - lambda_target) ** 2
                loss = ce + current_lambda_reg * reg
                lyap_vals.append(lyap_est.item())
            else:
                loss = ce

            if torch.isnan(loss):
                converged = False
                break

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            optimizer.step()
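            # Diagnostic only: global gradient norm sqrt(sum_p ||grad_p||^2) over all
            # parameters. It is read after clip_grad_norm_ (and optimizer.step() leaves
            # .grad untouched), so the recorded value is the post-clipping norm.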
            grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
            grad_norms.append(grad_norm)

            preds = logits.argmax(dim=1)
            total_correct += (preds == y).sum().item()
            total_samples += x.size(0)

        if not converged:
            break

        train_acc = total_correct / total_samples
        final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0
        final_grad = np.mean(grad_norms) if grad_norms else 0.0

        # Track epochs to 90% accuracy
        if epochs_to_90 < 0 and train_acc >= 0.9:
            epochs_to_90 = epoch

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits, _, _ = model(x, compute_lyapunov=False, record_states=False)
                preds = logits.argmax(dim=1)
                val_correct += (preds == y).sum().item()
                val_total += x.size(0)

        val_acc = val_correct / val_total
        best_val_acc = max(best_val_acc, val_acc)

    return GridSearchResult(
        depth=depth,
        lambda_reg=lambda_reg,
        lambda_target=lambda_target,
        final_train_acc=train_acc if converged else 0.0,
        final_val_acc=best_val_acc if converged else 0.0,
        final_lyapunov=final_lyap,
        final_grad_norm=final_grad,
        converged=converged,
        epochs_to_90pct=epochs_to_90,
    )


def run_grid_search(
    depths: List[int],
    lambda_regs: List[float],
    lambda_targets: List[float],
    train_loader: DataLoader,
    val_loader: DataLoader,
    input_dim: int,
    num_classes: int,
    hidden_dim: int,
    epochs: int,
    lr: float,
    device: torch.device,
    seed: int = 42,
    progress: bool = True,
) -> List[GridSearchResult]:
    """Run full grid search."""
    results = []

    # Total configurations
    configs = list(product(depths, lambda_regs, lambda_targets))
    total = len(configs)

    iterator = tqdm(configs, desc="Grid Search", disable=not progress)
    for depth, lambda_reg, lambda_target in iterator:
        if progress:
            iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target})

        result = train_and_evaluate(
            depth=depth,
            lambda_reg=lambda_reg,
            lambda_target=lambda_target,
            train_loader=train_loader,
            val_loader=val_loader,
            input_dim=input_dim,
            num_classes=num_classes,
            hidden_dim=hidden_dim,
            epochs=epochs,
            lr=lr,
            device=device,
            seed=seed,
        )
        results.append(result)

        if progress:
            iterator.set_postfix({
                "d": depth,
                "λr": lambda_reg,
                "λt": lambda_target,
                "acc": f"{result.final_val_acc:.2f}",
            })

    return results


def analyze_results(results: List[GridSearchResult]) -> Dict:
    """Analyze grid search results and find optimal hyperparameters per depth."""
    # Group by depth
    by_depth = {}
    for r in results:
        if r.depth not in by_depth:
            by_depth[r.depth] = []
        by_depth[r.depth].append(r)

    analysis = {
        "optimal_per_depth": {},
        "all_results": [asdict(r) for r in results],
    }

    print("\n" + "=" * 80)
    print("GRID SEARCH ANALYSIS")
    print("=" * 80)

    # Find optimal for each depth
    print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}")
    print("-" * 80)

    optimal_lambda_regs = []
    optimal_lambda_targets = []
    depths_list = []

    for depth in sorted(by_depth.keys()):
        depth_results = by_depth[depth]
        # Find best by validation accuracy
        best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0)

        analysis["optimal_per_depth"][depth] = {
            "lambda_reg": best.lambda_reg,
            "lambda_target": best.lambda_target,
            "val_acc": best.final_val_acc,
            "lyapunov": best.final_lyapunov,
            "epochs_to_90": best.epochs_to_90pct,
        }

        print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} "
              f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}")

        if best.final_val_acc > 0.5:  # Only use successful runs for curve fitting
            depths_list.append(depth)
            optimal_lambda_regs.append(best.lambda_reg)
            optimal_lambda_targets.append(best.lambda_target)

    # Fit adaptive curves
    print("\n" + "=" * 80)
    print("ADAPTIVE HYPERPARAMETER CURVES")
    print("=" * 80)

    if len(depths_list) >= 3:
        # Fit polynomial curves
        depths_arr = np.array(depths_list)
        lambda_regs_arr = np.array(optimal_lambda_regs)
        lambda_targets_arr = np.array(optimal_lambda_targets)

        reg_poly = None
        target_poly = None

        # Fit lambda_reg vs depth (expect increasing with depth)
        try:
            reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=min(2, len(depths_arr) - 1))
            reg_poly = np.poly1d(reg_coeffs)
            if len(reg_coeffs) == 3:
                print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}")
            else:
                print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}")
            analysis["lambda_reg_curve"] = reg_coeffs.tolist()
        except Exception as e:
            print(f"Could not fit λ_reg curve: {e}")

        # Fit lambda_target vs depth (expect decreasing / more negative with depth)
        try:
            target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=min(2, len(depths_arr) - 1))
            target_poly = np.poly1d(target_coeffs)
            if len(target_coeffs) == 3:
                print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}")
            else:
                print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}")
            analysis["lambda_target_curve"] = target_coeffs.tolist()
        except Exception as e:
            print(f"Could not fit λ_target curve: {e}")

        # Print recommendations (only if both curve fits succeeded)
        if reg_poly is not None and target_poly is not None:
            print("\n" + "-" * 80)
            print("RECOMMENDED HYPERPARAMETERS BY DEPTH:")
            print("-" * 80)
            for d in [2, 4, 6, 8, 10, 12, 14, 16]:
                rec_reg = max(0.01, reg_poly(d))
                rec_target = min(0.0, target_poly(d))
                print(f"  Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}")
    else:
        print("Not enough successful runs to fit curves")

    return analysis
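

# Illustrative sketch, not called anywhere in this script: once analyze_results()
# has populated "lambda_reg_curve" / "lambda_target_curve", the fitted curves can
# be evaluated for an unseen depth. The helper name is hypothetical; the clamping
# mirrors the recommendation loop in analyze_results above.
def _recommend_from_curves(analysis: Dict, depth: int) -> Tuple[float, float]:
    """Evaluate the fitted adaptive hyperparameter curves at a given depth (sketch)."""
    reg_curve = analysis.get("lambda_reg_curve")
    target_curve = analysis.get("lambda_target_curve")
    if reg_curve is None or target_curve is None:
        raise ValueError("analysis does not contain fitted hyperparameter curves")
    lambda_reg = max(0.01, float(np.polyval(reg_curve, depth)))
    lambda_target = min(0.0, float(np.polyval(target_curve, depth)))
    return lambda_reg, lambda_target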


def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict):
    """Save grid search results."""
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f:
        json.dump(analysis, f, indent=2)

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nResults saved to {output_dir}")


def plot_grid_search(results: List[GridSearchResult], output_dir: str):
    """Generate visualization of grid search results."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not available, skipping plots")
        return

    # Group by depth
    by_depth = {}
    for r in results:
        if r.depth not in by_depth:
            by_depth[r.depth] = []
        by_depth[r.depth].append(r)

    depths = sorted(by_depth.keys())

    # Get unique lambda values
    lambda_regs = sorted(set(r.lambda_reg for r in results))
    lambda_targets = sorted(set(r.lambda_target for r in results))

    # Create heatmaps for each depth
    n_depths = len(depths)
    fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10))
    axes = axes.flatten()

    for idx, depth in enumerate(depths):
        ax = axes[idx]
        depth_results = by_depth[depth]

        # Create accuracy matrix
        acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs)))
        for r in depth_results:
            i = lambda_targets.index(r.lambda_target)
            j = lambda_regs.index(r.lambda_reg)
            acc_matrix[i, j] = r.final_val_acc

        im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
        ax.set_xticks(range(len(lambda_regs)))
        ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45)
        ax.set_yticks(range(len(lambda_targets)))
        ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets])
        ax.set_xlabel("λ_reg")
        ax.set_ylabel("λ_target")
        ax.set_title(f"Depth {depth}")

        # Mark best
        best = max(depth_results, key=lambda r: r.final_val_acc)
        bi = lambda_targets.index(best.lambda_target)
        bj = lambda_regs.index(best.lambda_reg)
        ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2)

        # Add colorbar
        plt.colorbar(im, ax=ax, label='Val Acc')

    # Hide unused subplots
    for idx in range(len(depths), len(axes)):
        axes[idx].axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight')
    plt.close()

    # Plot optimal hyperparameters vs depth
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    optimal_regs = []
    optimal_targets = []
    optimal_accs = []
    for depth in depths:
        best = max(by_depth[depth], key=lambda r: r.final_val_acc)
        optimal_regs.append(best.lambda_reg)
        optimal_targets.append(best.lambda_target)
        optimal_accs.append(best.final_val_acc)

    axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8)
    axes[0].set_xlabel("Network Depth")
    axes[0].set_ylabel("Optimal λ_reg")
    axes[0].set_title("Optimal Regularization Strength vs Depth")
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange')
    axes[1].set_xlabel("Network Depth")
    axes[1].set_ylabel("Optimal λ_target")
    axes[1].set_title("Optimal Target Lyapunov vs Depth")
    axes[1].grid(True, alpha=0.3)

    axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green')
    axes[2].set_xlabel("Network Depth")
    axes[2].set_ylabel("Best Validation Accuracy")
    axes[2].set_title("Best Achievable Accuracy vs Depth")
    axes[2].set_ylim(0, 1.05)
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight')
    plt.close()

    print(f"Plots saved to {output_dir}")


def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'):
    """Get CIFAR-10 dataloaders with rate encoding for SNN."""
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
    val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform)

    # Rate encoding: convert images to spike sequences
    class RateEncodedDataset(torch.utils.data.Dataset):
        def __init__(self, dataset, T):
            self.dataset = dataset
            self.T = T

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            img, label = self.dataset[idx]
            # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D)
            flat = img.view(-1)  # (3072,)
            # Rate encoding: probability of spike = pixel intensity
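            # (each of the T timesteps draws an independent Bernoulli sample per pixel,
            # so a pixel with intensity 0.7 is expected to spike on roughly 70% of steps)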
            spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float()
            return spikes, label

    train_encoded = RateEncodedDataset(train_ds, T)
    val_encoded = RateEncodedDataset(val_ds, T)

    train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, T, 3072, 10  # T, D, num_classes


def parse_args():
    p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN")

    # Grid search parameters
    p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10],
                   help="Network depths to test")
    p.add_argument("--lambda_regs", type=float, nargs="+", default=[0.01, 0.05, 0.1, 0.2, 0.3],
                   help="Lambda_reg values to test")
    p.add_argument("--lambda_targets", type=float, nargs="+", default=[0.0, -0.05, -0.1, -0.2],
                   help="Lambda_target values to test")

    # Model parameters
    p.add_argument("--hidden_dim", type=int, default=256)
    p.add_argument("--epochs", type=int, default=15)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--seed", type=int, default=42)

    # Data
    p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)")
    p.add_argument("--data_dir", type=str, default="./data")
    p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding")

    # Output
    p.add_argument("--out_dir", type=str, default="runs/grid_search")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--no-progress", action="store_true")

    return p.parse_args()


def main():
    args = parse_args()
    device = torch.device(args.device)

    print("=" * 80)
    print("HYPERPARAMETER GRID SEARCH")
    print("=" * 80)
    print(f"Depths: {args.depths}")
    print(f"λ_reg values: {args.lambda_regs}")
    print(f"λ_target values: {args.lambda_targets}")
    print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}")
    print(f"Epochs per config: {args.epochs}")
    print(f"Device: {device}")
    print("=" * 80)

    # Load data
    if args.synthetic:
        print("\nUsing synthetic data")
        train_loader, val_loader, T, D, C = create_synthetic_data(
            seed=args.seed, batch_size=args.batch_size
        )
    else:
        print("\nUsing CIFAR-10 with rate encoding")
        train_loader, val_loader, T, D, C = get_cifar10_loaders(
            batch_size=args.batch_size, T=args.T, data_dir=args.data_dir
        )

    print(f"Data: T={T}, D={D}, classes={C}\n")

    # Run grid search
    results = run_grid_search(
        depths=args.depths,
        lambda_regs=args.lambda_regs,
        lambda_targets=args.lambda_targets,
        train_loader=train_loader,
        val_loader=val_loader,
        input_dim=D,
        num_classes=C,
        hidden_dim=args.hidden_dim,
        epochs=args.epochs,
        lr=args.lr,
        device=device,
        seed=args.seed,
        progress=not args.no_progress,
    )

    # Analyze results
    analysis = analyze_results(results)

    # Save results
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(args.out_dir, ts)
    save_results(results, analysis, output_dir, vars(args))

    # Generate plots
    plot_grid_search(results, output_dir)


if __name__ == "__main__":
    main()