diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-22 16:28:45 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-22 16:28:45 -0600 |
| commit | bd4d87fdc313e757a90d4c0c4c7a02126b092a38 (patch) | |
| tree | 65bfc5984ca87973abe8175d880354dbb35022fa | |
| parent | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (diff) | |
Plots all depth-scaling experiment variants (squared, hinge, asymmetric,
extreme, stable_init, target1, weak_reg) at depth=12 with settings annotations.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
| -rw-r--r-- | files/analysis/plot_experiment_comparison.py | 275 | ||||
| -rw-r--r-- | runs/experiment_comparison.png | bin | 0 -> 405007 bytes |
2 files changed, 275 insertions, 0 deletions
"""
Plot comparison of all depth-scaling experiments showing:
  - Task loss (CE only, estimated by subtracting reg loss from total)
  - Lyapunov regularization loss
  - λ(t) (Lyapunov exponent over time)
  - Gradient norm

Each experiment variant is shown as a separate line, with experiment settings in the legend.
"""

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np

RUNS_DIR = Path(__file__).parent.parent.parent / "runs"

# Sub-directories of RUNS_DIR that are scanned for experiments, in plotting
# order: the baseline ("depth_scaling") first, then its variants.
EXPERIMENT_DIRS = (
    "depth_scaling",
    "depth_scaling_asymm",
    "depth_scaling_extreme",
    "depth_scaling_hinge",
    "depth_scaling_stable_init",
    "depth_scaling_target1",
    "depth_scaling_weak_reg",
)


@dataclass
class ExperimentConfig:
    """Configuration for an experiment, as read from its ``config.json``."""
    name: str  # display name (used as the legend label)
    path: Path  # path to results dir (holds config.json / results.json)
    dataset: str
    reg_type: str
    lambda_reg: float
    lambda_target: float
    warmup_epochs: int
    stable_init: bool
    epochs: int
    depths: List[int]


def compute_reg_loss(lyap_value: Optional[float], reg_type: str, lambda_target: float,
                     lyap_threshold: float = 2.0) -> float:
    """Compute the regularization loss value (before multiplying by lambda_reg).

    NOTE(review): this is a re-computation from the logged exponent and is
    presumably meant to mirror the penalty used at training time -- confirm
    against the training code.

    Args:
        lyap_value: Logged Lyapunov exponent, or None when none was recorded.
        reg_type: One of "squared", "hinge", "asymmetric", "extreme". Any
            other value falls back to the "squared" formula.
        lambda_target: Target exponent (used by the "squared" formula only).
        lyap_threshold: Exponent above which the "extreme" penalty starts.

    Returns:
        The unweighted penalty; 0.0 when ``lyap_value`` is None.
    """
    if lyap_value is None:
        return 0.0

    if reg_type == "hinge":
        # Only positive exponents are penalized.
        return max(0, lyap_value) ** 2

    if reg_type == "asymmetric":
        # Positive exponents are penalized at full weight; exponents below
        # -1.0 are penalized at a tenth of that weight.
        chaos = max(0, lyap_value) ** 2
        collapse = 0.1 * max(0, -lyap_value - 1.0) ** 2
        return chaos + collapse

    if reg_type == "extreme":
        # Only the part of the exponent above the threshold is penalized.
        return max(0, lyap_value - lyap_threshold) ** 2

    # "squared" and any unrecognized reg_type: squared distance to the target.
    return (lyap_value - lambda_target) ** 2


def get_effective_lambda(epoch: int, lambda_reg: float, warmup_epochs: int) -> float:
    """Get the effective lambda_reg at a given epoch (accounts for warmup).

    While ``epoch <= warmup_epochs`` the weight ramps linearly as
    ``lambda_reg * epoch / warmup_epochs``; afterwards, or when no warmup is
    configured, it is ``lambda_reg``.
    """
    if warmup_epochs > 0 and epoch <= warmup_epochs:
        return lambda_reg * (epoch / warmup_epochs)
    return lambda_reg


def load_experiment(exp_dir: Path) -> Optional[ExperimentConfig]:
    """Load experiment config from a directory.

    Args:
        exp_dir: Directory expected to contain ``config.json`` and
            ``results.json``.

    Returns:
        The parsed configuration, or None when either file is missing.

    Raises:
        KeyError: If ``config.json`` lacks one of the required keys
            ("dataset", "lambda_reg", "lambda_target", "epochs", "depths").
    """
    config_path = exp_dir / "config.json"
    results_path = exp_dir / "results.json"

    if not config_path.exists() or not results_path.exists():
        return None

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    reg_type = config.get("reg_type", "squared")
    warmup = config.get("warmup_epochs", 0)
    stable_init = config.get("stable_init", False)

    # Build display name with settings
    parts = [
        f"{reg_type}",
        f"λ_reg={config['lambda_reg']}",
        f"λ_target={config['lambda_target']}",
    ]
    if warmup > 0:
        parts.append(f"warmup={warmup}")
    if stable_init:
        parts.append("stable_init")

    return ExperimentConfig(
        name=", ".join(parts),
        path=exp_dir,
        dataset=config["dataset"],
        reg_type=reg_type,
        lambda_reg=config["lambda_reg"],
        lambda_target=config["lambda_target"],
        warmup_epochs=warmup,
        stable_init=stable_init,
        epochs=config["epochs"],
        depths=config["depths"],
    )


def _nan_if_none(value: Optional[float]) -> float:
    """Map a missing (None) measurement to NaN so arrays stay numeric."""
    return np.nan if value is None else value


def load_results(exp: ExperimentConfig, depth: int) -> Optional[Dict]:
    """Load results for a specific depth from an experiment.

    Reads ``results.json`` from ``exp.path`` and walks the per-epoch entries
    stored under ``data["lyapunov"][str(depth)]``. The walk stops at the first
    entry whose ``train_loss`` is None or NaN; later entries are ignored.

    Returns:
        None when there is no usable entry for ``depth``; otherwise a dict of
        equally long NumPy arrays:
          - "epochs": epoch numbers
          - "task_loss": ``train_loss`` minus the estimated reg loss
          - "lyap_reg_loss": effective lambda times the recomputed penalty
          - "lyap_values": logged exponents (NaN where missing)
          - "grad_norms": logged gradient norms (NaN where missing)
    """
    results_path = exp.path / "results.json"
    with open(results_path, encoding="utf-8") as f:
        data = json.load(f)

    epochs_data = data.get("lyapunov", {}).get(str(depth))
    if not epochs_data:
        return None

    epochs = []
    lyap_values = []
    grad_norms = []
    lyap_reg_losses = []
    task_losses = []

    for entry in epochs_data:
        train_loss = entry["train_loss"]
        if train_loss is None or np.isnan(train_loss):
            # No usable loss from here on: keep only the epochs before it.
            break

        epoch = entry["epoch"]
        lyap = entry.get("lyapunov")

        # Compute effective lambda and reg loss (0.0 when lyap is None)
        eff_lambda = get_effective_lambda(epoch, exp.lambda_reg, exp.warmup_epochs)
        lyap_loss = eff_lambda * compute_reg_loss(lyap, exp.reg_type, exp.lambda_target)

        epochs.append(epoch)
        lyap_values.append(_nan_if_none(lyap))
        grad_norms.append(_nan_if_none(entry.get("grad_norm")))
        lyap_reg_losses.append(lyap_loss)
        # Estimate task loss = total - reg
        task_losses.append(train_loss - lyap_loss)

    if not epochs:
        return None

    return {
        "epochs": np.array(epochs),
        "task_loss": np.array(task_losses),
        "lyap_reg_loss": np.array(lyap_reg_losses),
        "lyap_values": np.array(lyap_values, dtype=float),
        # dtype=float: a None entry would otherwise produce an object array
        # that matplotlib cannot plot.
        "grad_norms": np.array(grad_norms, dtype=float),
    }


def find_experiments() -> List[ExperimentConfig]:
    """Find all depth-scaling experiments.

    Scans every ``cifar*`` entry below each directory in ``EXPERIMENT_DIRS``
    (relative to ``RUNS_DIR``); missing directories are skipped. Each
    experiment's display name is prefixed with its dataset.
    """
    experiments = []

    for dir_name in EXPERIMENT_DIRS:
        variant_path = RUNS_DIR / dir_name
        if not variant_path.is_dir():
            continue
        for subdir in sorted(variant_path.glob("cifar*")):
            exp = load_experiment(subdir)
            if exp:
                exp.name = f"{exp.dataset}: {exp.name}"
                experiments.append(exp)

    return experiments


def plot_comparison(depth: int = 12, save_path: Optional[str] = None):
    """Create 4-panel comparison plot for all experiments at a given depth.

    Args:
        depth: Network depth whose results are plotted.
        save_path: Output image path; defaults to
            ``RUNS_DIR / "experiment_comparison.png"``.

    Nothing is written when no experiment is found or when none of them has
    data for ``depth``.
    """
    # Imported here so the loading/loss helpers above can be used without a
    # plotting stack installed.
    import matplotlib.pyplot as plt

    experiments = find_experiments()

    if not experiments:
        print("No experiments found!")
        return

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f"Experiment Comparison (depth={depth})", fontsize=14, fontweight="bold")

    # (axis, key in the load_results() dict, title, y-label) for each panel
    panels = [
        (axes[0, 0], "task_loss", "Task Loss (CE)", "Loss"),
        (axes[0, 1], "lyap_reg_loss", "Lyapunov Regularization Loss", "Loss"),
        (axes[1, 0], "lyap_values", r"$\lambda(t)$ (Lyapunov Exponent)", r"$\lambda$"),
        (axes[1, 1], "grad_norms", "Gradient Norm", r"$\|\nabla\|_2$"),
    ]

    colors = plt.cm.tab10(np.linspace(0, 1, len(experiments)))

    num_plotted = 0
    for color, exp in zip(colors, experiments):
        results = load_results(exp, depth)
        if results is None:
            print(f" Skipping {exp.name}: no data for depth={depth}")
            continue

        for ax, key, _title, _ylabel in panels:
            ax.plot(results["epochs"], results[key],
                    color=color, label=exp.name, linewidth=1.5, alpha=0.85)
        num_plotted += 1

    if num_plotted == 0:
        # Saving four empty panels would only look like a successful run.
        print(f"No experiment has data for depth={depth}; nothing to plot.")
        plt.close(fig)
        return

    # Configure axes
    for ax, _key, title, ylabel in panels:
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    # Reference line at lambda = 0 on the exponent panel
    axes[1, 0].axhline(y=0, color="black", linestyle="--", alpha=0.5, linewidth=0.8)

    # Add legend below the plots (every panel carries the same labels)
    handles, labels = axes[0, 0].get_legend_handles_labels()
    if handles:
        fig.legend(handles, labels, loc="lower center", ncol=2,
                   fontsize=8.5, bbox_to_anchor=(0.5, -0.02),
                   frameon=True, fancybox=True)

    fig.tight_layout(rect=[0, 0.06, 1, 0.96])

    if save_path is None:
        save_path = str(RUNS_DIR / "experiment_comparison.png")

    # Save/close this specific figure rather than pyplot's "current" one.
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"Saved plot to {save_path}")
    plt.close(fig)


if __name__ == "__main__":
    plot_comparison(depth=12)
--git a/runs/experiment_comparison.png b/runs/experiment_comparison.png Binary files differnew file mode 100644 index 0000000..973ff57 --- /dev/null +++ b/runs/experiment_comparison.png |
