| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/cifar10_conv_experiment.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/experiments/cifar10_conv_experiment.py')
| -rw-r--r-- | files/experiments/cifar10_conv_experiment.py | 448 |
1 files changed, 448 insertions, 0 deletions
diff --git a/files/experiments/cifar10_conv_experiment.py b/files/experiments/cifar10_conv_experiment.py
new file mode 100644
index 0000000..a582f9f
--- /dev/null
+++ b/files/experiments/cifar10_conv_experiment.py
@@ -0,0 +1,448 @@
+"""
+CIFAR-10 Conv-SNN Experiment with Lyapunov Regularization.
+
+Uses a proper convolutional architecture that preserves spatial structure.
+Tests whether Lyapunov regularization helps train deeper Conv-SNNs.
+
+Architecture:
+    Image (3,32,32) → Rate Encoding → Conv-LIF-Pool layers → FC → Output
+
+Usage:
+    python files/experiments/cifar10_conv_experiment.py --model simple --T 25
+    python files/experiments/cifar10_conv_experiment.py --model vgg --T 50 --lyapunov
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+from files.models.conv_snn import create_conv_snn
+
+
+@dataclass
+class EpochMetrics:
+    epoch: int
+    train_loss: float
+    train_acc: float
+    val_loss: float
+    val_acc: float
+    lyapunov: Optional[float]
+    grad_norm: float
+    time_sec: float
+
+
+def get_cifar10_loaders(
+    data_dir: str = './data',
+    batch_size: int = 128,
+    num_workers: int = 4,
+) -> Tuple[DataLoader, DataLoader]:
+    """
+    Get CIFAR-10 dataloaders for rate encoding.
+
+    Images are kept in [0, 1]; no mean/std normalization is applied.
+    """
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        # Note: For rate encoding, we keep values in [0, 1]
+        # No normalization to negative values
+    ])
+
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+    ])
+
+    train_dataset = datasets.CIFAR10(
+        root=data_dir, train=True, download=True, transform=transform_train
+    )
+    test_dataset = datasets.CIFAR10(
+        root=data_dir, train=False, download=True, transform=transform_test
+    )
+
+    train_loader = DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True,
+        num_workers=num_workers, pin_memory=True
+    )
+    test_loader = DataLoader(
+        test_dataset, batch_size=batch_size, shuffle=False,
+        num_workers=num_workers, pin_memory=True
+    )
+
+    return train_loader, test_loader
+
+
+def train_epoch(
+    model: nn.Module,
+    loader: DataLoader,
+    optimizer: optim.Optimizer,
+    ce_loss: nn.Module,
+    device: torch.device,
+    use_lyapunov: bool,
+    lambda_reg: float,
+    lambda_target: float,
+    lyap_eps: float,
+    progress: bool = True,
+) -> Tuple[float, float, Optional[float], float]:
+    """Train one epoch."""
+    model.train()
+    total_loss = 0.0
+    total_correct = 0
+    total_samples = 0
+    lyap_vals = []
+    grad_norms = []
+
+    iterator = tqdm(loader, desc="train", leave=False) if progress else loader
+
+    for x, y in iterator:
+        x, y = x.to(device), y.to(device)  # x: (B, 3, 32, 32)
+
+        optimizer.zero_grad()
+
+        logits, lyap_est, _ = model(
+            x,
+            compute_lyapunov=use_lyapunov,
+            lyap_eps=lyap_eps,
+        )
+
+        ce = ce_loss(logits, y)
+
+        if use_lyapunov and lyap_est is not None:
+            reg = (lyap_est - lambda_target) ** 2
+            loss = ce + lambda_reg * reg
+            lyap_vals.append(lyap_est.item())
+        else:
+            loss = ce
+
+        if torch.isnan(loss):
+            return float('nan'), 0.0, None, float('nan')
+
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
+        optimizer.step()
+
+        grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5
+        grad_norms.append(grad_norm)
+
+        total_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        total_correct += (preds == y).sum().item()
+        total_samples += x.size(0)
+
+        if progress:
+            iterator.set_postfix({
+                "loss": f"{loss.item():.3f}",
+                "acc": f"{total_correct/total_samples:.3f}",
+            })
+
+    return (
+        total_loss / total_samples,
+        total_correct / total_samples,
+        np.mean(lyap_vals) if lyap_vals else None,
+        np.mean(grad_norms),
+    )
+
+
+@torch.no_grad()
+def evaluate(
+    model: nn.Module,
+    loader: DataLoader,
+    ce_loss: nn.Module,
+    device: torch.device,
+    progress: bool = True,
+) -> Tuple[float, float]:
+    """Evaluate on test set."""
+    model.eval()
+    total_loss = 0.0
+    total_correct = 0
+    total_samples = 0
+
+    iterator = tqdm(loader, desc="eval", leave=False) if progress else loader
+
+    for x, y in iterator:
+        x, y = x.to(device), y.to(device)
+        logits, _, _ = model(x, compute_lyapunov=False)
+
+        loss = ce_loss(logits, y)
+        total_loss += loss.item() * x.size(0)
+        preds = logits.argmax(dim=1)
+        total_correct += (preds == y).sum().item()
+        total_samples += x.size(0)
+
+    return total_loss / total_samples, total_correct / total_samples
+
+
+def run_experiment(
+    model_type: str,
+    channels: List[int],
+    T: int,
+    use_lyapunov: bool,
+    train_loader: DataLoader,
+    test_loader: DataLoader,
+    epochs: int,
+    lr: float,
+    lambda_reg: float,
+    lambda_target: float,
+    lyap_eps: float,
+    device: torch.device,
+    seed: int,
+    progress: bool = True,
+) -> List[EpochMetrics]:
+    """Run single experiment."""
+    torch.manual_seed(seed)
+
+    model = create_conv_snn(
+        model_type=model_type,
+        in_channels=3,
+        num_classes=10,
+        channels=channels,
+        T=T,
+        encoding='rate',
+    ).to(device)
+
+    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(f" Model: {model_type}, params: {num_params:,}")
+
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+    ce_loss = nn.CrossEntropyLoss()
+
+    metrics_history = []
+    best_acc = 0.0
+
+    for epoch in range(1, epochs + 1):
+        t0 = time.time()
+
+        train_loss, train_acc, lyap, grad_norm = train_epoch(
+            model, train_loader, optimizer, ce_loss, device,
+            use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress
+        )
+
+        test_loss, test_acc = evaluate(model, test_loader, ce_loss, device, progress)
+        scheduler.step()
+
+        dt = time.time() - t0
+        best_acc = max(best_acc, test_acc)
+
+        metrics = EpochMetrics(
+            epoch=epoch,
+            train_loss=train_loss,
+            train_acc=train_acc,
+            val_loss=test_loss,
+            val_acc=test_acc,
+            lyapunov=lyap,
+            grad_norm=grad_norm,
+            time_sec=dt,
+        )
+        metrics_history.append(metrics)
+
+        lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
+        print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} ({dt:.1f}s)")
+
+        if np.isnan(train_loss):
+            print(" Training diverged!")
+            break
+
+    print(f" Best test accuracy: {best_acc:.3f}")
+    return metrics_history
+
+
+def run_comparison(
+    model_type: str,
+    channels_configs: List[List[int]],
+    T: int,
+    train_loader: DataLoader,
+    test_loader: DataLoader,
+    epochs: int,
+    lr: float,
+    lambda_reg: float,
+    lambda_target: float,
+    device: torch.device,
+    seed: int,
+    progress: bool,
+) -> Dict:
+    """Compare vanilla vs Lyapunov across different depths."""
+    results = {"vanilla": {}, "lyapunov": {}}
+
+    for channels in channels_configs:
+        depth = len(channels)
+        print(f"\n{'='*60}")
+        print(f"Depth = {depth} conv layers, channels = {channels}")
+        print(f"{'='*60}")
+
+        for use_lyap in [False, True]:
+            method = "lyapunov" if use_lyap else "vanilla"
+            print(f"\n Training {method.upper()}...")
+
+            metrics = run_experiment(
+                model_type=model_type,
+                channels=channels,
+                T=T,
+                use_lyapunov=use_lyap,
+                train_loader=train_loader,
+                test_loader=test_loader,
+                epochs=epochs,
+                lr=lr,
+                lambda_reg=lambda_reg,
+                lambda_target=lambda_target,
+                lyap_eps=1e-4,
+                device=device,
+                seed=seed,
+                progress=progress,
+            )
+
+            results[method][depth] = metrics
+
+    return results
+
+
+def print_summary(results: Dict):
+    """Print comparison summary."""
+    print("\n" + "=" * 70)
+    print("SUMMARY: CIFAR-10 Conv-SNN Results")
+    print("=" * 70)
+    print(f"{'Depth':<8} {'Vanilla':<15} {'Lyapunov':<15} {'Improvement':<15}")
+    print("-" * 70)
+
+    depths = sorted(results["vanilla"].keys())
+    for depth in depths:
+        van = results["vanilla"][depth][-1]
+        lyap = results["lyapunov"][depth][-1]
+
+        van_acc = van.val_acc if not np.isnan(van.train_loss) else 0.0
+        lyap_acc = lyap.val_acc if not np.isnan(lyap.train_loss) else 0.0
+
+        diff = lyap_acc - van_acc
+        diff_str = f"+{diff:.3f}" if diff > 0 else f"{diff:.3f}"
+
+        van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED"
+        lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED"
+
+        print(f"{depth:<8} {van_str:<15} {lyap_str:<15} {diff_str:<15}")
+
+    print("=" * 70)
+
+
+def save_results(results: Dict, output_dir: str, config: Dict):
+    """Save results."""
+    os.makedirs(output_dir, exist_ok=True)
+
+    serializable = {}
+    for method, depth_results in results.items():
+        serializable[method] = {}
+        for depth, metrics_list in depth_results.items():
+            serializable[method][str(depth)] = [asdict(m) for m in metrics_list]
+
+    with open(os.path.join(output_dir, "results.json"), "w") as f:
+        json.dump(serializable, f, indent=2)
+
+    with open(os.path.join(output_dir, "config.json"), "w") as f:
+        json.dump(config, f, indent=2)
+
+    print(f"\nResults saved to {output_dir}")
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+
+    # Model
+    p.add_argument("--model", type=str, default="simple", choices=["simple", "vgg"])
+    p.add_argument("--channels", type=int, nargs="+", default=None,
+                   help="Channel sizes (default: test multiple depths)")
+    p.add_argument("--T", type=int, default=25, help="Timesteps")
+
+    # Training
+    p.add_argument("--epochs", type=int, default=50)
+    p.add_argument("--batch_size", type=int, default=128)
+    p.add_argument("--lr", type=float, default=1e-3)
+
+    # Lyapunov
+    p.add_argument("--lambda_reg", type=float, default=0.3)
+    p.add_argument("--lambda_target", type=float, default=-0.1)
+
+    # Other
+    p.add_argument("--data_dir", type=str, default="./data")
+    p.add_argument("--out_dir", type=str, default="runs/cifar10_conv")
+    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    p.add_argument("--seed", type=int, default=42)
+    p.add_argument("--no-progress", action="store_true")
+
+    return p.parse_args()
+
+
+def main():
+    args = parse_args()
+    device = torch.device(args.device)
+
+    print("=" * 70)
+    print("CIFAR-10 Conv-SNN Experiment")
+    print("=" * 70)
+    print(f"Model: {args.model}")
+    print(f"Timesteps: {args.T}")
+    print(f"Epochs: {args.epochs}")
+    print(f"Device: {device}")
+    print("=" * 70)
+
+    # Load data
+    print("\nLoading CIFAR-10...")
+    train_loader, test_loader = get_cifar10_loaders(
+        data_dir=args.data_dir,
+        batch_size=args.batch_size,
+    )
+    print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
+
+    # Define depth configurations to test
+    if args.channels:
+        channels_configs = [args.channels]
+    else:
+        # Test increasing depths
+        channels_configs = [
+            [64, 128],            # 2 conv layers (shallow)
+            [64, 128, 256],       # 3 conv layers
+            [64, 128, 256, 512],  # 4 conv layers (deep)
+        ]
+
+    # Run comparison
+    results = run_comparison(
+        model_type=args.model,
+        channels_configs=channels_configs,
+        T=args.T,
+        train_loader=train_loader,
+        test_loader=test_loader,
+        epochs=args.epochs,
+        lr=args.lr,
+        lambda_reg=args.lambda_reg,
+        lambda_target=args.lambda_target,
+        device=device,
+        seed=args.seed,
+        progress=not args.no_progress,
+    )
+
+    # Summary
+    print_summary(results)
+
+    # Save
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    output_dir = os.path.join(args.out_dir, ts)
+    save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+    main()
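For quick inspection after a run, the results.json written by save_results() above can be read back directly. The snippet below is a minimal sketch, not part of the commit: the path under runs/cifar10_conv/ is an example, and the reader relies only on the structure the script defines (method → depth → list of per-epoch EpochMetrics dicts).

```python
# Sketch: summarize a results.json produced by save_results() (path is an example).
import json
import sys

path = sys.argv[1] if len(sys.argv) > 1 else "runs/cifar10_conv/results.json"
with open(path) as f:
    results = json.load(f)

# results[method][depth] is a list of per-epoch dicts (see EpochMetrics).
for method, by_depth in results.items():
    for depth, history in sorted(by_depth.items(), key=lambda kv: int(kv[0])):
        best_val = max(m["val_acc"] for m in history)
        final_val = history[-1]["val_acc"]
        print(f"{method:10s} depth={depth}: best val_acc={best_val:.3f}, final={final_val:.3f}")
```

The other EpochMetrics fields (train_loss, lyapunov, grad_norm, time_sec) are available from the same per-epoch dicts.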
