| field | value | date |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| commit | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch) | |
| tree | 59a233959932ca0e4f12f196275e07fcf443b33f /files/experiments/posthoc_finetune.py | |
init commit
Diffstat (limited to 'files/experiments/posthoc_finetune.py')
| -rw-r--r-- | files/experiments/posthoc_finetune.py | 323 |
1 file changed, 323 insertions, 0 deletions
```diff
diff --git a/files/experiments/posthoc_finetune.py b/files/experiments/posthoc_finetune.py
new file mode 100644
index 0000000..3f3bf6c
--- /dev/null
+++ b/files/experiments/posthoc_finetune.py
@@ -0,0 +1,323 @@
```

```python
"""
Post-hoc Lyapunov Fine-tuning Experiment

Strategy:
1. Train network with vanilla (no Lyapunov) for N epochs
2. Then fine-tune with Lyapunov regularization for M epochs

This allows the network to learn task-relevant features first,
then stabilize dynamics without starting from chaotic initialization.
"""

import os
import sys
import json
import time
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from files.experiments.depth_scaling_benchmark import (
    SpikingVGG,
    get_dataset,
    train_epoch,
    evaluate,
    TrainingMetrics,
    compute_lyap_reg_loss,
)


def run_posthoc_experiment(
    dataset_name: str,
    depth_config: Tuple[int, int],
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_classes: int,
    in_channels: int,
    T: int,
    pretrain_epochs: int,
    finetune_epochs: int,
    lr: float,
    finetune_lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    reg_type: str = "extreme",
    lyap_threshold: float = 2.0,
    progress: bool = True,
) -> Dict:
    """Run post-hoc fine-tuning experiment."""
    torch.manual_seed(seed)

    num_stages, blocks_per_stage = depth_config
    total_depth = num_stages * blocks_per_stage

    print(f"\n{'='*60}")
    print(f"POST-HOC FINE-TUNING: Depth = {total_depth}")
    print(f"Pretrain: {pretrain_epochs} epochs (vanilla)")
    print(f"Finetune: {finetune_epochs} epochs (Lyapunov, reg_type={reg_type})")
    print(f"{'='*60}")

    model = SpikingVGG(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=64,
        num_stages=num_stages,
        blocks_per_stage=blocks_per_stage,
        T=T,
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters: {num_params:,}")

    criterion = nn.CrossEntropyLoss()

    # Phase 1: Vanilla pre-training
    print(f"\n--- Phase 1: Vanilla Pre-training ({pretrain_epochs} epochs) ---")
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)

    pretrain_history = []
    best_pretrain_acc = 0.0

    for epoch in range(1, pretrain_epochs + 1):
        t0 = time.time()

        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov=False,  # No Lyapunov during pre-training
            lambda_reg=0, lambda_target=0, lyap_eps=1e-4,
            progress=progress,
        )

        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        best_pretrain_acc = max(best_pretrain_acc, test_acc)

        metrics = TrainingMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            lr=scheduler.get_last_lr()[0],
            time_sec=dt,
        )
        pretrain_history.append(metrics)

        if epoch % 10 == 0 or epoch == pretrain_epochs:
            print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f}")

    print(f" Best pretrain acc: {best_pretrain_acc:.3f}")

    # Phase 2: Lyapunov fine-tuning
    print(f"\n--- Phase 2: Lyapunov Fine-tuning ({finetune_epochs} epochs) ---")
    print(f" reg_type={reg_type}, lambda_reg={lambda_reg}, threshold={lyap_threshold}")

    # Reset optimizer with lower learning rate for fine-tuning
    optimizer = optim.AdamW(model.parameters(), lr=finetune_lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=finetune_epochs)

    finetune_history = []
    best_finetune_acc = 0.0

    for epoch in range(1, finetune_epochs + 1):
        t0 = time.time()

        # Warmup lambda_reg over first 10 epochs of fine-tuning
        warmup_epochs = 10
        if epoch <= warmup_epochs:
            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
        else:
            current_lambda_reg = lambda_reg

        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov=True,
            lambda_reg=lambda_reg,
            lambda_target=lambda_target,
            lyap_eps=1e-4,
            progress=progress,
            reg_type=reg_type,
            current_lambda_reg=current_lambda_reg,
            lyap_threshold=lyap_threshold,
        )

        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        best_finetune_acc = max(best_finetune_acc, test_acc)

        metrics = TrainingMetrics(
            epoch=pretrain_epochs + epoch,  # Continue epoch numbering
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            lr=scheduler.get_last_lr()[0],
            time_sec=dt,
        )
        finetune_history.append(metrics)

        if epoch % 10 == 0 or epoch == finetune_epochs:
            lyap_str = f"λ={lyap:.3f}" if lyap else ""
            print(f" Epoch {pretrain_epochs + epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str}")

        if np.isnan(train_loss):
            print(f" DIVERGED at epoch {epoch}")
            break

    print(f" Best finetune acc: {best_finetune_acc:.3f}")
    print(f" Final λ: {finetune_history[-1].lyapunov:.3f}" if finetune_history[-1].lyapunov else "")

    return {
        "depth": total_depth,
        "pretrain_history": pretrain_history,
        "finetune_history": finetune_history,
        "best_pretrain_acc": best_pretrain_acc,
        "best_finetune_acc": best_finetune_acc,
    }


def main():
    parser = argparse.ArgumentParser(description="Post-hoc Lyapunov Fine-tuning")
    parser.add_argument("--dataset", type=str, default="cifar100",
                        choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
    parser.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16])
    parser.add_argument("--T", type=int, default=4)
    parser.add_argument("--pretrain_epochs", type=int, default=100)
    parser.add_argument("--finetune_epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--finetune_lr", type=float, default=1e-4)
    parser.add_argument("--lambda_reg", type=float, default=0.1)
    parser.add_argument("--lambda_target", type=float, default=-0.1)
    parser.add_argument("--reg_type", type=str, default="extreme")
    parser.add_argument("--lyap_threshold", type=float, default=2.0)
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--out_dir", type=str, default="runs/posthoc_finetune")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-progress", action="store_true")

    args = parser.parse_args()
    device = torch.device(args.device)

    print("=" * 80)
    print("POST-HOC LYAPUNOV FINE-TUNING EXPERIMENT")
    print("=" * 80)
    print(f"Dataset: {args.dataset}")
    print(f"Depths: {args.depths}")
    print(f"Pretrain: {args.pretrain_epochs} epochs (vanilla, lr={args.lr})")
    print(f"Finetune: {args.finetune_epochs} epochs (Lyapunov, lr={args.finetune_lr})")
    print(f"Lyapunov: reg_type={args.reg_type}, λ_reg={args.lambda_reg}, threshold={args.lyap_threshold}")
    print("=" * 80)

    # Load data
    train_loader, test_loader, num_classes, input_shape = get_dataset(
        args.dataset, args.data_dir, args.batch_size
    )
    in_channels = input_shape[0]

    # Convert depths to configs
    depth_configs = []
    for d in args.depths:
        if d <= 4:
            depth_configs.append((d, 1))
        else:
            depth_configs.append((4, d // 4))

    # Run experiments
    all_results = []
    for depth_config in depth_configs:
        result = run_posthoc_experiment(
            dataset_name=args.dataset,
            depth_config=depth_config,
            train_loader=train_loader,
            test_loader=test_loader,
            num_classes=num_classes,
            in_channels=in_channels,
            T=args.T,
            pretrain_epochs=args.pretrain_epochs,
            finetune_epochs=args.finetune_epochs,
            lr=args.lr,
            finetune_lr=args.finetune_lr,
            lambda_reg=args.lambda_reg,
            lambda_target=args.lambda_target,
            device=device,
            seed=args.seed,
            reg_type=args.reg_type,
            lyap_threshold=args.lyap_threshold,
            progress=not args.no_progress,
        )
        all_results.append(result)

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"{'Depth':<8} {'Pretrain Acc':<15} {'Finetune Acc':<15} {'Change':<10} {'Final λ':<10}")
    print("-" * 80)

    for r in all_results:
        pre_acc = r["best_pretrain_acc"]
        fine_acc = r["best_finetune_acc"]
        change = fine_acc - pre_acc
        final_lyap = r["finetune_history"][-1].lyapunov if r["finetune_history"] else None
        lyap_str = f"{final_lyap:.3f}" if final_lyap else "N/A"
        change_str = f"{change:+.3f}"

        print(f"{r['depth']:<8} {pre_acc:<15.3f} {fine_acc:<15.3f} {change_str:<10} {lyap_str:<10}")

    print("=" * 80)

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_file = os.path.join(args.out_dir, f"{args.dataset}_{ts}.json")

    serializable_results = []
    for r in all_results:
        sr = {
            "depth": r["depth"],
            "best_pretrain_acc": r["best_pretrain_acc"],
            "best_finetune_acc": r["best_finetune_acc"],
            "pretrain_history": [asdict(m) for m in r["pretrain_history"]],
            "finetune_history": [asdict(m) for m in r["finetune_history"]],
        }
        serializable_results.append(sr)

    with open(output_file, "w") as f:
        json.dump({"config": vars(args), "results": serializable_results}, f, indent=2)

    print(f"\nResults saved to {output_file}")


if __name__ == "__main__":
    main()
```
