Diffstat (limited to 'files/experiments/posthoc_finetune.py')
-rw-r--r--  files/experiments/posthoc_finetune.py  323
1 file changed, 323 insertions, 0 deletions
diff --git a/files/experiments/posthoc_finetune.py b/files/experiments/posthoc_finetune.py
new file mode 100644
index 0000000..3f3bf6c
--- /dev/null
+++ b/files/experiments/posthoc_finetune.py
@@ -0,0 +1,323 @@
+"""
+Post-hoc Lyapunov Fine-tuning Experiment
+
+Strategy:
+1. Train the network with vanilla training (no Lyapunov regularization) for N epochs.
+2. Fine-tune with Lyapunov regularization for M epochs.
+
+This allows the network to learn task-relevant features first, then
+stabilize its dynamics without starting from a chaotic initialization.
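+
+Example invocation (a sketch using the CLI flags defined in main() below; the
+relative --data_dir and --out_dir defaults assume the script is run from the
+repository root):
+
+    python files/experiments/posthoc_finetune.py --dataset cifar10 --depths 4 8 \
+        --pretrain_epochs 100 --finetune_epochs 50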
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+
+from files.experiments.depth_scaling_benchmark import (
+    SpikingVGG,
+    get_dataset,
+    train_epoch,
+    evaluate,
+    TrainingMetrics,
+    compute_lyap_reg_loss,
+)
+
+
+def run_posthoc_experiment(
+    dataset_name: str,
+    depth_config: Tuple[int, int],
+    train_loader: DataLoader,
+    test_loader: DataLoader,
+    num_classes: int,
+    in_channels: int,
+    T: int,
+    pretrain_epochs: int,
+    finetune_epochs: int,
+    lr: float,
+    finetune_lr: float,
+    lambda_reg: float,
+    lambda_target: float,
+    device: torch.device,
+    seed: int,
+    reg_type: str = "extreme",
+    lyap_threshold: float = 2.0,
+    progress: bool = True,
+) -> Dict:
+ """Run post-hoc fine-tuning experiment."""
+ torch.manual_seed(seed)
+
+ num_stages, blocks_per_stage = depth_config
+ total_depth = num_stages * blocks_per_stage
+
+ print(f"\n{'='*60}")
+ print(f"POST-HOC FINE-TUNING: Depth = {total_depth}")
+ print(f"Pretrain: {pretrain_epochs} epochs (vanilla)")
+ print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})")
+ print(f"{'='*60}")
+
+ model = SpikingVGG(
+ in_channels=in_channels,
+ num_classes=num_classes,
+ base_channels=64,
+ num_stages=num_stages,
+ blocks_per_stage=blocks_per_stage,
+ T=T,
+ ).to(device)
+
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Parameters: {num_params:,}")
+
+ criterion = nn.CrossEntropyLoss()
+
+ # Phase 1: Vanilla pre-training
+ print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---")
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)
+
+ pretrain_history = []
+ best_pretrain_acc = 0.0
+
+ for epoch in range(1, pretrain_epochs + 1):
+ t0 = time.time()
+
+ train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+ model, train_loader, optimizer, criterion, device,
+ use_lyapunov=False, # No Lyapunov during pre-training
+ lambda_reg=0, lambda_target=0, lyap_eps=1e-4,
+ progress=progress,
+ )
+
+ test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+ scheduler.step()
+
+ dt = time.time() - t0
+ best_pretrain_acc = max(best_pretrain_acc, test_acc)
+
+ metrics = TrainingMetrics(
+ epoch=epoch,
+ train_loss=train_loss,
+ train_acc=train_acc,
+ test_loss=test_loss,
+ test_acc=test_acc,
+ lyapunov=lyap,
+ grad_norm=grad_norm,
+ grad_max_sv=grad_max_sv,
+ grad_min_sv=grad_min_sv,
+ grad_condition=grad_cond,
+ lr=scheduler.get_last_lr()[0],
+ time_sec=dt,
+ )
+ pretrain_history.append(metrics)
+
+ if epoch % 10 == 0 or epoch == pretrain_epochs:
+ print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}")
+
+ print(f" Best pretrain acc: {best_pretrain_acc:.3f}")
+
+ # Phase 2: Lyapunov fine-tuning
+ print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---")
+ print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}")
+
+ # Reset optimizer with lower learning rate for fine-tuning
+ optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)
+
+ finetune_history = []
+ best_finetune_acc = 0.0
+
+ for epoch in range(1, finetune_epochs + 1):
+ t0 = time.time()
+
+ # Warmup lambda_reg over first 10 epochs of fine-tuning
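+        # (linear ramp: current_lambda_reg rises from lambda_reg/10 at epoch 1 to the
+        # full lambda_reg by epoch 10, easing the regularizer onto the pretrained weights)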
+        warmup_epochs = 10
+        if epoch <= warmup_epochs:
+            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
+        else:
+            current_lambda_reg = lambda_reg
+
+        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
+            model, train_loader, optimizer, criterion, device,
+            use_lyapunov=True,
+            lambda_reg=lambda_reg,
+            lambda_target=lambda_target,
+            lyap_eps=1e-4,
+            progress=progress,
+            reg_type=reg_type,
+            current_lambda_reg=current_lambda_reg,
+            lyap_threshold=lyap_threshold,
+        )
+
+        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
+        scheduler.step()
+
+        dt = time.time() - t0
+        best_finetune_acc = max(best_finetune_acc, test_acc)
+
+        metrics = TrainingMetrics(
+            epoch=pretrain_epochs + epoch,  # Continue epoch numbering
+            train_loss=train_loss,
+            train_acc=train_acc,
+            test_loss=test_loss,
+            test_acc=test_acc,
+            lyapunov=lyap,
+            grad_norm=grad_norm,
+            grad_max_sv=grad_max_sv,
+            grad_min_sv=grad_min_sv,
+            grad_condition=grad_cond,
+            lr=scheduler.get_last_lr()[0],
+            time_sec=dt,
+        )
+        finetune_history.append(metrics)
+
+        if epoch % 10 == 0 or epoch == finetune_epochs:
+ lyap_str = f"λ={lyap:.3f}" if lyap else ""
+ print(f" Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")
+
+ if np.isnan(train_loss):
+ print(f" DIVERGED at epoch {epoch}")
+ break
+
+ print(f" Best finetune acc: {best_finetune_acc:.3f}")
+ print(f" Final λ: {finetune_history[-1].lyapunov:.3f}" if finetune_history[-1].lyapunov else "")
+
+    return {
+        "depth": total_depth,
+        "pretrain_history": pretrain_history,
+        "finetune_history": finetune_history,
+        "best_pretrain_acc": best_pretrain_acc,
+        "best_finetune_acc": best_finetune_acc,
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning")
+    parser.add_argument("--dataset", type=str, default="cifar100",
+                        choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
+    parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16])
+    parser.add_argument("--T", type=int, default=4)
+    parser.add_argument("--pretrain_epochs", type=int, default=100)
+    parser.add_argument("--finetune_epochs", type=int, default=50)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--lr", type=float, default=1e-3)
+    parser.add_argument("--finetune_lr", type=float, default=1e-4)
+    parser.add_argument("--lambda_reg", type=float, default=0.1)
+    parser.add_argument("--lambda_target", type=float, default=-0.1)
+    parser.add_argument("--reg_type", type=str, default="extreme")
+    parser.add_argument("--lyap_threshold", type=float, default=2.0)
+    parser.add_argument("--data_dir", type=str, default="./data")
+    parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--no-progress", action="store_true")
+
+    args = parser.parse_args()
+    device = torch.device(args.device)
+
+    print("=" * 80)
+    print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT")
+    print("=" * 80)
+    print(f"Dataset: {args.dataset}")
+    print(f"Depths: {args.depths}")
+    print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})")
+    print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})")
+    print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}")
+    print("=" * 80)
+
+    # Load data
+    train_loader, test_loader, num_classes, input_shape = get_dataset(
+        args.dataset, args.data_dir, args.batch_size
+    )
+    in_channels = input_shape[0]
+
+    # Convert depths to configs
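+    # Each config is (num_stages, blocks_per_stage); depths above 4 are assumed to be
+    # multiples of 4, so that num_stages * blocks_per_stage reproduces the requested depth.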
+    depth_configs = []
+    for d in args.depths:
+        if d <= 4:
+            depth_configs.append((d, 1))
+        else:
+            depth_configs.append((4, d // 4))
+
+    # Run experiments
+    all_results = []
+    for depth_config in depth_configs:
+        result = run_posthoc_experiment(
+            dataset_name=args.dataset,
+            depth_config=depth_config,
+            train_loader=train_loader,
+            test_loader=test_loader,
+            num_classes=num_classes,
+            in_channels=in_channels,
+            T=args.T,
+            pretrain_epochs=args.pretrain_epochs,
+            finetune_epochs=args.finetune_epochs,
+            lr=args.lr,
+            finetune_lr=args.finetune_lr,
+            lambda_reg=args.lambda_reg,
+            lambda_target=args.lambda_target,
+            device=device,
+            seed=args.seed,
+            reg_type=args.reg_type,
+            lyap_threshold=args.lyap_threshold,
+            progress=not args.no_progress,
+        )
+        all_results.append(result)
+
+    # Summary
+    print("\n" + "=" * 80)
+    print("SUMMARY")
+    print("=" * 80)
+    print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}")
+    print("-" * 80)
+
+    for r in all_results:
+        pre_acc = r["best_pretrain_acc"]
+        fine_acc = r["best_finetune_acc"]
+        change = fine_acc - pre_acc
+        final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None
+ lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A"
+        change_str = f"{change:+.3f}"
+
+        print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}")
+
+    print("=" * 80)
+
+    # Save results
+    os.makedirs(args.out_dir, exist_ok=True)
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json")
+
+    serializable_results = []
+    for r in all_results:
+        sr = {
+            "depth": r["depth"],
+            "best_pretrain_acc": r["best_pretrain_acc"],
+            "best_finetune_acc": r["best_finetune_acc"],
+            "pretrain_history": [asdict(m) for m in r["pretrain_history"]],
+            "finetune_history": [asdict(m) for m in r["finetune_history"]],
+        }
+        serializable_results.append(sr)
+
+    with open(output_file, "w") as f:
+        json.dump({"config": vars(args), "results": serializable_results}, f, indent=2)
+
+    print(f"\nResults saved to {output_file}")
+
+
+if __name__ == "__main__":
+    main()