""" Post-hoc Lyapunov Fine-tuning Experiment Strategy: 1. Train network with vanilla (no Lyapunov) for N epochs 2. Then fine-tune with Lyapunov regularization for M epochs This allows the network to learn task-relevant features first, then stabilize dynamics without starting from chaotic initialization. """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional, Tuple _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from tqdm.auto import tqdm from files.experiments.depth_scaling_benchmark import ( SpikingVGG, get_dataset, train_epoch, evaluate, TrainingMetrics, compute_lyap_reg_loss, ) def run_posthoc_experiment( dataset_name: str, depth_config: Tuple[int, int], train_loader: DataLoader, test_loader: DataLoader, num_classes: int, in_channels: int, T: int, pretrain_epochs: int, finetune_epochs: int, lr: float, finetune_lr: float, lambda_reg: float, lambda_target: float, device: torch.device, seed: int, reg_type: str = "extreme", lyap_threshold: float = 2.0, progress: bool = True, ) -> Dict: """Run post-hoc fine-tuning experiment.""" torch.manual_seed(seed) num_stages, blocks_per_stage = depth_config total_depth = num_stages * blocks_per_stage print(f"\n{'='*60}") print(f"POST-HOC FINE-TUNING: Depth = {total_depth}") print(f"Pretrain: {pretrain_epochs} epochs (vanilla)") print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})") print(f"{'='*60}") model = SpikingVGG( in_channels=in_channels, num_classes=num_classes, base_channels=64, num_stages=num_stages, blocks_per_stage=blocks_per_stage, T=T, ).to(device) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Parameters: {num_params:,}") criterion = nn.CrossEntropyLoss() # Phase 1: Vanilla pre-training print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---") optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs) pretrain_history = [] best_pretrain_acc = 0.0 for epoch in range(1, pretrain_epochs + 1): t0 = time.time() train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( model, train_loader, optimizer, criterion, device, use_lyapunov=False, # No Lyapunov during pre-training lambda_reg=0, lambda_target=0, lyap_eps=1e-4, progress=progress, ) test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) scheduler.step() dt = time.time() - t0 best_pretrain_acc = max(best_pretrain_acc, test_acc) metrics = TrainingMetrics( epoch=epoch, train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc, lyapunov=lyap, grad_norm=grad_norm, grad_max_sv=grad_max_sv, grad_min_sv=grad_min_sv, grad_condition=grad_cond, lr=scheduler.get_last_lr()[0], time_sec=dt, ) pretrain_history.append(metrics) if epoch % 10 == 0 or epoch == pretrain_epochs: print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}") print(f" Best pretrain acc: {best_pretrain_acc:.3f}") # Phase 2: Lyapunov fine-tuning print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---") print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}") # Reset optimizer with lower learning rate for fine-tuning optimizer = 
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)

    finetune_history = []
    best_finetune_acc = 0.0

    for epoch in range(1, finetune_epochs + 1):
        t0 = time.time()

        # Warm up lambda_reg over the first 10 epochs of fine-tuning
        warmup_epochs = 10
        if epoch <= warmup_epochs:
            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
        else:
            current_lambda_reg = lambda_reg

        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model,
            train_loader,
            optimizer,
            criterion,
            device,
            use_lyapunov=True,
            lambda_reg=lambda_reg,
            lambda_target=lambda_target,
            lyap_eps=1e-4,
            progress=progress,
            reg_type=reg_type,
            current_lambda_reg=current_lambda_reg,
            lyap_threshold=lyap_threshold,
        )
        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()
        dt = time.time() - t0

        best_finetune_acc = max(best_finetune_acc, test_acc)

        metrics = TrainingMetrics(
            epoch=pretrain_epochs + epoch,  # Continue epoch numbering from pre-training
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            lr=scheduler.get_last_lr()[0],
            time_sec=dt,
        )
        finetune_history.append(metrics)

        if epoch % 10 == 0 or epoch == finetune_epochs:
            lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
            print(f"  Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")

        if np.isnan(train_loss):
            print(f"  DIVERGED at epoch {epoch}")
            break

    print(f"  Best finetune acc: {best_finetune_acc:.3f}")
    final_lyap = finetune_history[-1].lyapunov if finetune_history else None
    if final_lyap is not None:
        print(f"  Final λ: {final_lyap:.3f}")

    return {
        "depth": total_depth,
        "pretrain_history": pretrain_history,
        "finetune_history": finetune_history,
        "best_pretrain_acc": best_pretrain_acc,
        "best_finetune_acc": best_finetune_acc,
    }


def main():
    parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning")
    parser.add_argument("--dataset", type=str, default="cifar100",
                        choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
    parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16])
    parser.add_argument("--T", type=int, default=4)
    parser.add_argument("--pretrain_epochs", type=int, default=100)
    parser.add_argument("--finetune_epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--finetune_lr", type=float, default=1e-4)
    parser.add_argument("--lambda_reg", type=float, default=0.1)
    parser.add_argument("--lambda_target", type=float, default=-0.1)
    parser.add_argument("--reg_type", type=str, default="extreme")
    parser.add_argument("--lyap_threshold", type=float, default=2.0)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-progress", action="store_true")
    args = parser.parse_args()

    device = torch.device(args.device)

    print("=" * 80)
    print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT")
    print("=" * 80)
    print(f"Dataset: {args.dataset}")
    print(f"Depths: {args.depths}")
    print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})")
    print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})")
    print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}")
    print("=" * 80)

    # Load data
    train_loader, test_loader, num_classes, input_shape = get_dataset(
        args.dataset, args.data_dir, args.batch_size
    )
    in_channels = input_shape[0]

    # Convert requested total depths into (num_stages, blocks_per_stage) configs
    depth_configs = []
    for d in args.depths:
        if d <= 4:
            depth_configs.append((d, 1))
        else:
            depth_configs.append((4, d // 4))

    # Run experiments
    all_results = []
    for depth_config in depth_configs:
        result = run_posthoc_experiment(
            dataset_name=args.dataset,
            depth_config=depth_config,
            train_loader=train_loader,
            test_loader=test_loader,
            num_classes=num_classes,
            in_channels=in_channels,
            T=args.T,
            pretrain_epochs=args.pretrain_epochs,
            finetune_epochs=args.finetune_epochs,
            lr=args.lr,
            finetune_lr=args.finetune_lr,
            lambda_reg=args.lambda_reg,
            lambda_target=args.lambda_target,
            device=device,
            seed=args.seed,
            reg_type=args.reg_type,
            lyap_threshold=args.lyap_threshold,
            progress=not args.no_progress,
        )
        all_results.append(result)

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}")
    print("-" * 80)
    for r in all_results:
        pre_acc = r["best_pretrain_acc"]
        fine_acc = r["best_finetune_acc"]
        change = fine_acc - pre_acc
        final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None
        lyap_str = f"{final_lyap:.3f}" if final_lyap is not None else "N/A"
        change_str = f"{change:+.3f}"
        print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}")
    print("=" * 80)

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json")

    serializable_results = []
    for r in all_results:
        sr = {
            "depth": r["depth"],
            "best_pretrain_acc": r["best_pretrain_acc"],
            "best_finetune_acc": r["best_finetune_acc"],
            "pretrain_history": [asdict(m) for m in r["pretrain_history"]],
            "finetune_history": [asdict(m) for m in r["finetune_history"]],
        }
        serializable_results.append(sr)

    with open(output_file, "w") as f:
        json.dump({"config": vars(args), "results": serializable_results}, f, indent=2)
    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    main()
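
# Example invocation (illustrative only): the script filename below is a
# placeholder, not part of the original file; every flag corresponds to an
# argument defined in main() above, and the values shown are the argparse
# defaults.
#
#   python posthoc_lyapunov_finetune.py \
#       --dataset cifar100 --depths 4 8 12 16 --T 4 \
#       --pretrain_epochs 100 --finetune_epochs 50 --batch_size 128 \
#       --lr 1e-3 --finetune_lr 1e-4 \
#       --lambda_reg 0.1 --lambda_target -0.1 \
#       --reg_type extreme --lyap_threshold 2.0 \
#       --out_dir runs/posthoc_finetune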