| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/hyperparameter_grid_search.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/experiments/hyperparameter_grid_search.py')
| -rw-r--r-- | files/experiments/hyperparameter_grid_search.py | 597 |
1 file changed, 597 insertions, 0 deletions
diff --git a/files/experiments/hyperparameter_grid_search.py b/files/experiments/hyperparameter_grid_search.py new file mode 100644 index 0000000..011387f --- /dev/null +++ b/files/experiments/hyperparameter_grid_search.py @@ -0,0 +1,597 @@ +""" +Hyperparameter Grid Search for Lyapunov-Regularized SNNs. + +Goal: Find optimal (lambda_reg, lambda_target) for each network depth + and derive an adaptive curve for automatic hyperparameter selection. + +Usage: + python files/experiments/hyperparameter_grid_search.py --synthetic --epochs 20 +""" + +import os +import sys +import json +import time +from dataclasses import dataclass, asdict +from typing import Dict, List, Tuple +from itertools import product + +_HERE = os.path.dirname(__file__) +_ROOT = os.path.dirname(os.path.dirname(_HERE)) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm + +from files.models.snn_snntorch import LyapunovSNN + + +@dataclass +class GridSearchResult: + """Result from a single grid search configuration.""" + depth: int + lambda_reg: float + lambda_target: float + final_train_acc: float + final_val_acc: float + final_lyapunov: float + final_grad_norm: float + converged: bool # Did training succeed (not NaN)? + epochs_to_90pct: int # Epochs to reach 90% train accuracy (-1 if never) + + +def create_synthetic_data( + n_train: int = 2000, + n_val: int = 500, + T: int = 50, + D: int = 100, + n_classes: int = 10, + seed: int = 42, + batch_size: int = 128, +) -> Tuple[DataLoader, DataLoader, int, int, int]: + """Create synthetic spike data.""" + torch.manual_seed(seed) + np.random.seed(seed) + + def generate_data(n_samples): + x = torch.zeros(n_samples, T, D) + y = torch.randint(0, n_classes, (n_samples,)) + for i in range(n_samples): + label = y[i].item() + base_rate = 0.05 + 0.02 * label + class_channels = range(label * (D // n_classes), (label + 1) * (D // n_classes)) + for t in range(T): + x[i, t] = (torch.rand(D) < base_rate).float() + for c in class_channels: + if torch.rand(1) < base_rate * 3: + x[i, t, c] = 1.0 + return x, y + + x_train, y_train = generate_data(n_train) + x_val, y_val = generate_data(n_val) + + train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True) + val_loader = DataLoader(TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False) + + return train_loader, val_loader, T, D, n_classes + + +def train_and_evaluate( + depth: int, + lambda_reg: float, + lambda_target: float, + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + device: torch.device, + seed: int = 42, + warmup_epochs: int = 5, # Warmup λ_reg to avoid killing learning early +) -> GridSearchResult: + """Train a single configuration and return results.""" + torch.manual_seed(seed) + + # Create model + hidden_dims = [hidden_dim] * depth + model = LyapunovSNN( + input_dim=input_dim, + hidden_dims=hidden_dims, + num_classes=num_classes, + beta=0.9, + threshold=1.0, + ).to(device) + + optimizer = optim.Adam(model.parameters(), lr=lr) + ce_loss = nn.CrossEntropyLoss() + + best_val_acc = 0.0 + epochs_to_90 = -1 + final_lyap = 0.0 + final_grad = 0.0 + converged = True + + for epoch in range(1, epochs + 1): + # Warmup: gradually increase lambda_reg + if epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / 
warmup_epochs) + else: + current_lambda_reg = lambda_reg + + # Training + model.train() + total_correct = 0 + total_samples = 0 + lyap_vals = [] + grad_norms = [] + + for x, y in train_loader: + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=True, lyap_eps=1e-4, record_states=False) + + ce = ce_loss(logits, y) + if lyap_est is not None: + reg = (lyap_est - lambda_target) ** 2 + loss = ce + current_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + else: + loss = ce + + if torch.isnan(loss): + converged = False + break + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + optimizer.step() + + grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 + grad_norms.append(grad_norm) + + preds = logits.argmax(dim=1) + total_correct += (preds == y).sum().item() + total_samples += x.size(0) + + if not converged: + break + + train_acc = total_correct / total_samples + final_lyap = np.mean(lyap_vals) if lyap_vals else 0.0 + final_grad = np.mean(grad_norms) if grad_norms else 0.0 + + # Track epochs to 90% accuracy + if epochs_to_90 < 0 and train_acc >= 0.9: + epochs_to_90 = epoch + + # Validation + model.eval() + val_correct = 0 + val_total = 0 + with torch.no_grad(): + for x, y in val_loader: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, compute_lyapunov=False, record_states=False) + preds = logits.argmax(dim=1) + val_correct += (preds == y).sum().item() + val_total += x.size(0) + val_acc = val_correct / val_total + best_val_acc = max(best_val_acc, val_acc) + + return GridSearchResult( + depth=depth, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + final_train_acc=train_acc if converged else 0.0, + final_val_acc=best_val_acc if converged else 0.0, + final_lyapunov=final_lyap, + final_grad_norm=final_grad, + converged=converged, + epochs_to_90pct=epochs_to_90, + ) + + +def run_grid_search( + depths: List[int], + lambda_regs: List[float], + lambda_targets: List[float], + train_loader: DataLoader, + val_loader: DataLoader, + input_dim: int, + num_classes: int, + hidden_dim: int, + epochs: int, + lr: float, + device: torch.device, + seed: int = 42, + progress: bool = True, +) -> List[GridSearchResult]: + """Run full grid search.""" + results = [] + + # Total configurations + configs = list(product(depths, lambda_regs, lambda_targets)) + total = len(configs) + + iterator = tqdm(configs, desc="Grid Search", disable=not progress) + + for depth, lambda_reg, lambda_target in iterator: + if progress: + iterator.set_postfix({"d": depth, "λr": lambda_reg, "λt": lambda_target}) + + result = train_and_evaluate( + depth=depth, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + train_loader=train_loader, + val_loader=val_loader, + input_dim=input_dim, + num_classes=num_classes, + hidden_dim=hidden_dim, + epochs=epochs, + lr=lr, + device=device, + seed=seed, + ) + results.append(result) + + if progress: + iterator.set_postfix({ + "d": depth, + "λr": lambda_reg, + "λt": lambda_target, + "acc": f"{result.final_val_acc:.2f}" + }) + + return results + + +def analyze_results(results: List[GridSearchResult]) -> Dict: + """Analyze grid search results and find optimal hyperparameters per depth.""" + + # Group by depth + by_depth = {} + for r in results: + if r.depth not in by_depth: + by_depth[r.depth] = [] + by_depth[r.depth].append(r) + + analysis = { + "optimal_per_depth": {}, + "all_results": [asdict(r) for r in results], + } + + print("\n" + 
"=" * 80) + print("GRID SEARCH ANALYSIS") + print("=" * 80) + + # Find optimal for each depth + print(f"\n{'Depth':<8} {'Best λ_reg':<12} {'Best λ_target':<14} {'Val Acc':<10} {'Lyapunov':<10}") + print("-" * 80) + + optimal_lambda_regs = [] + optimal_lambda_targets = [] + depths_list = [] + + for depth in sorted(by_depth.keys()): + depth_results = by_depth[depth] + # Find best by validation accuracy + best = max(depth_results, key=lambda r: r.final_val_acc if r.converged else 0) + + analysis["optimal_per_depth"][depth] = { + "lambda_reg": best.lambda_reg, + "lambda_target": best.lambda_target, + "val_acc": best.final_val_acc, + "lyapunov": best.final_lyapunov, + "epochs_to_90": best.epochs_to_90pct, + } + + print(f"{depth:<8} {best.lambda_reg:<12.3f} {best.lambda_target:<14.3f} " + f"{best.final_val_acc:<10.3f} {best.final_lyapunov:<10.3f}") + + if best.final_val_acc > 0.5: # Only use successful runs for curve fitting + depths_list.append(depth) + optimal_lambda_regs.append(best.lambda_reg) + optimal_lambda_targets.append(best.lambda_target) + + # Fit adaptive curves + print("\n" + "=" * 80) + print("ADAPTIVE HYPERPARAMETER CURVES") + print("=" * 80) + + if len(depths_list) >= 3: + # Fit polynomial curves + depths_arr = np.array(depths_list) + lambda_regs_arr = np.array(optimal_lambda_regs) + lambda_targets_arr = np.array(optimal_lambda_targets) + + # Fit lambda_reg vs depth (expect increasing with depth) + try: + reg_coeffs = np.polyfit(depths_arr, lambda_regs_arr, deg=min(2, len(depths_arr) - 1)) + reg_poly = np.poly1d(reg_coeffs) + print(f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d² + {reg_coeffs[1]:.4f}·d + {reg_coeffs[2]:.4f}" + if len(reg_coeffs) == 3 else f"\nλ_reg(depth) ≈ {reg_coeffs[0]:.4f}·d + {reg_coeffs[1]:.4f}") + analysis["lambda_reg_curve"] = reg_coeffs.tolist() + except Exception as e: + print(f"Could not fit λ_reg curve: {e}") + + # Fit lambda_target vs depth (expect decreasing / more negative with depth) + try: + target_coeffs = np.polyfit(depths_arr, lambda_targets_arr, deg=min(2, len(depths_arr) - 1)) + target_poly = np.poly1d(target_coeffs) + print(f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d² + {target_coeffs[1]:.4f}·d + {target_coeffs[2]:.4f}" + if len(target_coeffs) == 3 else f"λ_target(depth) ≈ {target_coeffs[0]:.4f}·d + {target_coeffs[1]:.4f}") + analysis["lambda_target_curve"] = target_coeffs.tolist() + except Exception as e: + print(f"Could not fit λ_target curve: {e}") + + # Print recommendations + print("\n" + "-" * 80) + print("RECOMMENDED HYPERPARAMETERS BY DEPTH:") + print("-" * 80) + for d in [2, 4, 6, 8, 10, 12, 14, 16]: + rec_reg = max(0.01, reg_poly(d)) + rec_target = min(0.0, target_poly(d)) + print(f" Depth {d:2d}: λ_reg = {rec_reg:.3f}, λ_target = {rec_target:.3f}") + + else: + print("Not enough successful runs to fit curves") + + return analysis + + +def save_results(results: List[GridSearchResult], analysis: Dict, output_dir: str, config: Dict): + """Save grid search results.""" + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, "grid_search_results.json"), "w") as f: + json.dump(analysis, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def plot_grid_search(results: List[GridSearchResult], output_dir: str): + """Generate visualization of grid search results.""" + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not available, skipping plots") + return + + # Group by depth + 
by_depth = {} + for r in results: + if r.depth not in by_depth: + by_depth[r.depth] = [] + by_depth[r.depth].append(r) + + depths = sorted(by_depth.keys()) + + # Get unique lambda values + lambda_regs = sorted(set(r.lambda_reg for r in results)) + lambda_targets = sorted(set(r.lambda_target for r in results)) + + # Create heatmaps for each depth + n_depths = len(depths) + fig, axes = plt.subplots(2, (n_depths + 1) // 2, figsize=(5 * ((n_depths + 1) // 2), 10)) + axes = axes.flatten() + + for idx, depth in enumerate(depths): + ax = axes[idx] + depth_results = by_depth[depth] + + # Create accuracy matrix + acc_matrix = np.zeros((len(lambda_targets), len(lambda_regs))) + for r in depth_results: + i = lambda_targets.index(r.lambda_target) + j = lambda_regs.index(r.lambda_reg) + acc_matrix[i, j] = r.final_val_acc + + im = ax.imshow(acc_matrix, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto') + ax.set_xticks(range(len(lambda_regs))) + ax.set_xticklabels([f"{lr:.2f}" for lr in lambda_regs], rotation=45) + ax.set_yticks(range(len(lambda_targets))) + ax.set_yticklabels([f"{lt:.2f}" for lt in lambda_targets]) + ax.set_xlabel("λ_reg") + ax.set_ylabel("λ_target") + ax.set_title(f"Depth {depth}") + + # Mark best + best = max(depth_results, key=lambda r: r.final_val_acc) + bi = lambda_targets.index(best.lambda_target) + bj = lambda_regs.index(best.lambda_reg) + ax.scatter([bj], [bi], marker='*', s=200, c='blue', edgecolors='white', linewidths=2) + + # Add colorbar + plt.colorbar(im, ax=ax, label='Val Acc') + + # Hide unused subplots + for idx in range(len(depths), len(axes)): + axes[idx].axis('off') + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "grid_search_heatmaps.png"), dpi=150, bbox_inches='tight') + plt.close() + + # Plot optimal hyperparameters vs depth + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + optimal_regs = [] + optimal_targets = [] + optimal_accs = [] + for depth in depths: + best = max(by_depth[depth], key=lambda r: r.final_val_acc) + optimal_regs.append(best.lambda_reg) + optimal_targets.append(best.lambda_target) + optimal_accs.append(best.final_val_acc) + + axes[0].plot(depths, optimal_regs, 'o-', linewidth=2, markersize=8) + axes[0].set_xlabel("Network Depth") + axes[0].set_ylabel("Optimal λ_reg") + axes[0].set_title("Optimal Regularization Strength vs Depth") + axes[0].grid(True, alpha=0.3) + + axes[1].plot(depths, optimal_targets, 's-', linewidth=2, markersize=8, color='orange') + axes[1].set_xlabel("Network Depth") + axes[1].set_ylabel("Optimal λ_target") + axes[1].set_title("Optimal Target Lyapunov vs Depth") + axes[1].grid(True, alpha=0.3) + + axes[2].plot(depths, optimal_accs, '^-', linewidth=2, markersize=8, color='green') + axes[2].set_xlabel("Network Depth") + axes[2].set_ylabel("Best Validation Accuracy") + axes[2].set_title("Best Achievable Accuracy vs Depth") + axes[2].set_ylim(0, 1.05) + axes[2].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "optimal_hyperparameters.png"), dpi=150, bbox_inches='tight') + plt.close() + + print(f"Plots saved to {output_dir}") + + +def get_cifar10_loaders(batch_size=64, T=8, data_dir='./data'): + """Get CIFAR-10 dataloaders with rate encoding for SNN.""" + from torchvision import datasets, transforms + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform) + val_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform) + + # Rate encoding: convert images to spike 
sequences + class RateEncodedDataset(torch.utils.data.Dataset): + def __init__(self, dataset, T): + self.dataset = dataset + self.T = T + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + img, label = self.dataset[idx] + # img: (C, H, W) -> flatten to (C*H*W,) then expand to (T, D) + flat = img.view(-1) # (3072,) + # Rate encoding: probability of spike = pixel intensity + spikes = (torch.rand(self.T, flat.size(0)) < flat.unsqueeze(0)).float() + return spikes, label + + train_encoded = RateEncodedDataset(train_ds, T) + val_encoded = RateEncodedDataset(val_ds, T) + + train_loader = DataLoader(train_encoded, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_encoded, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_loader, val_loader, T, 3072, 10 # T, D, num_classes + + +def parse_args(): + p = argparse.ArgumentParser(description="Hyperparameter grid search for Lyapunov SNN") + + # Grid search parameters + p.add_argument("--depths", type=int, nargs="+", default=[4, 6, 8, 10], + help="Network depths to test") + p.add_argument("--lambda_regs", type=float, nargs="+", + default=[0.01, 0.05, 0.1, 0.2, 0.3], + help="Lambda_reg values to test") + p.add_argument("--lambda_targets", type=float, nargs="+", + default=[0.0, -0.05, -0.1, -0.2], + help="Lambda_target values to test") + + # Model parameters + p.add_argument("--hidden_dim", type=int, default=256) + p.add_argument("--epochs", type=int, default=15) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--seed", type=int, default=42) + + # Data + p.add_argument("--synthetic", action="store_true", help="Use synthetic data (default: CIFAR-10)") + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--T", type=int, default=8, help="Number of timesteps for rate encoding") + + # Output + p.add_argument("--out_dir", type=str, default="runs/grid_search") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--no-progress", action="store_true") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 80) + print("HYPERPARAMETER GRID SEARCH") + print("=" * 80) + print(f"Depths: {args.depths}") + print(f"λ_reg values: {args.lambda_regs}") + print(f"λ_target values: {args.lambda_targets}") + print(f"Total configurations: {len(args.depths) * len(args.lambda_regs) * len(args.lambda_targets)}") + print(f"Epochs per config: {args.epochs}") + print(f"Device: {device}") + print("=" * 80) + + # Load data + if args.synthetic: + print("\nUsing synthetic data") + train_loader, val_loader, T, D, C = create_synthetic_data( + seed=args.seed, batch_size=args.batch_size + ) + else: + print("\nUsing CIFAR-10 with rate encoding") + train_loader, val_loader, T, D, C = get_cifar10_loaders( + batch_size=args.batch_size, + T=args.T, + data_dir=args.data_dir + ) + + print(f"Data: T={T}, D={D}, classes={C}\n") + + # Run grid search + results = run_grid_search( + depths=args.depths, + lambda_regs=args.lambda_regs, + lambda_targets=args.lambda_targets, + train_loader=train_loader, + val_loader=val_loader, + input_dim=D, + num_classes=C, + hidden_dim=args.hidden_dim, + epochs=args.epochs, + lr=args.lr, + device=device, + seed=args.seed, + progress=not args.no_progress, + ) + + # Analyze results + analysis = analyze_results(results) + + # Save results + ts = time.strftime("%Y%m%d-%H%M%S") + 
+    output_dir = os.path.join(args.out_dir, ts)
+    save_results(results, analysis, output_dir, vars(args))
+
+    # Generate plots
+    plot_grid_search(results, output_dir)
+
+
+if __name__ == "__main__":
+    main()
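For reference, below is a minimal sketch (not part of the committed script) of how the fitted adaptive curves could be consumed downstream to pick hyperparameters for a new depth, mirroring the clamps the script prints in its "RECOMMENDED HYPERPARAMETERS BY DEPTH" section (λ_reg kept ≥ 0.01, λ_target kept ≤ 0.0). The helper name `recommend_hyperparameters` and the results path are illustrative; the JSON keys `lambda_reg_curve` and `lambda_target_curve` come from `analyze_results()` and are only written when enough runs converged for curve fitting.

```python
# Sketch: load the curves written by save_results() and evaluate them at a new depth.
# Assumes grid_search_results.json contains "lambda_reg_curve" / "lambda_target_curve".
import json
from typing import Tuple

import numpy as np


def recommend_hyperparameters(results_path: str, depth: int) -> Tuple[float, float]:
    """Evaluate the fitted polynomials at `depth`, applying the script's clamps."""
    with open(results_path) as f:
        analysis = json.load(f)
    reg_poly = np.poly1d(analysis["lambda_reg_curve"])
    target_poly = np.poly1d(analysis["lambda_target_curve"])
    lambda_reg = max(0.01, float(reg_poly(depth)))        # keep regularization strength positive
    lambda_target = min(0.0, float(target_poly(depth)))   # keep target Lyapunov exponent non-positive
    return lambda_reg, lambda_target


# Example (path is illustrative):
# lam_reg, lam_target = recommend_hyperparameters(
#     "runs/grid_search/<timestamp>/grid_search_results.json", depth=12
# )
```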
