summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: YurenHao0426 <blackhao0426@gmail.com>, 2026-01-22 16:28:45 -0600
committer: YurenHao0426 <blackhao0426@gmail.com>, 2026-01-22 16:28:45 -0600
commit: bd4d87fdc313e757a90d4c0c4c7a02126b092a38 (patch)
tree: 65bfc5984ca87973abe8175d880354dbb35022fa
parent: 00cf667cee7ffacb144d5805fc7e0ef443f3583a (diff)
Add experiment comparison plot showing task loss, Lyapunov loss, λ(t), and gradient norm (HEAD, main)
Plots all depth-scaling experiment variants (squared, hinge, asymmetric, extreme, stable_init, target1, weak_reg) at depth=12 with settings annotations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
-rw-r--r-- files/analysis/plot_experiment_comparison.py | 275
-rw-r--r-- runs/experiment_comparison.png | bin 0 -> 405007 bytes
2 files changed, 275 insertions, 0 deletions
diff --git a/files/analysis/plot_experiment_comparison.py b/files/analysis/plot_experiment_comparison.py
new file mode 100644
index 0000000..bb9990b
--- /dev/null
+++ b/files/analysis/plot_experiment_comparison.py
@@ -0,0 +1,275 @@
+"""
+Plot comparison of all depth-scaling experiments showing:
+ - Task loss (CE only, estimated by subtracting reg loss from total)
+ - Lyapunov regularization loss
+ - λ(t) (Lyapunov exponent over time)
+ - Gradient norm
+
+Each experiment variant is shown as a separate line, with experiment settings in the legend.
+"""
+
+import json
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+from dataclasses import dataclass
+from typing import List, Optional, Dict
+
+# Directory that holds the experiment outputs: the "runs" folder located three
+# parent levels above this file (this file lives in files/analysis/).
+RUNS_DIR = Path(__file__).parent.parent.parent / "runs"
+
+
+@dataclass
+class ExperimentConfig:
+    """Settings of a single experiment run, populated from its ``config.json``.
+
+    ``load_experiment`` builds instances of this class; the plotting code uses
+    them to locate ``results.json`` and to recompute the regularization loss.
+    """
+    name: str  # display name, used as the legend label
+    path: Path  # path to results dir (holds config.json and results.json)
+    dataset: str  # "dataset" entry of config.json
+    reg_type: str  # penalty shape: "squared", "hinge", "asymmetric" or "extreme"
+    lambda_reg: float  # weight that multiplies the regularization loss
+    lambda_target: float  # target exponent used by the "squared" penalty
+    warmup_epochs: int  # epochs over which lambda_reg is ramped up linearly; 0 disables warmup
+    stable_init: bool  # "stable_init" flag from config.json; shown in the display name when set
+    epochs: int  # "epochs" entry of config.json; presumably the planned number of training epochs
+    depths: List[int]  # "depths" entry of config.json; presumably the network depths that were run
+
+
+def compute_reg_loss(lyap_value: float, reg_type: str, lambda_target: float,
+                     lyap_threshold: float = 2.0) -> float:
+    """Return the unweighted Lyapunov penalty for one Lyapunov value.
+
+    The result is the regularization term *before* it is multiplied by
+    ``lambda_reg``.
+
+    Args:
+        lyap_value: Lyapunov exponent; ``None`` yields a penalty of ``0.0``.
+        reg_type: Penalty shape. ``"hinge"`` penalizes positive values,
+            ``"asymmetric"`` penalizes positive values plus (with weight 0.1)
+            values below -1, ``"extreme"`` penalizes values above
+            ``lyap_threshold``. ``"squared"`` and any other name use the
+            squared distance to ``lambda_target``.
+        lambda_target: Target exponent for the squared penalty.
+        lyap_threshold: Onset of the ``"extreme"`` penalty.
+
+    Returns:
+        The penalty value.
+    """
+    if lyap_value is None:
+        return 0.0
+
+    if reg_type == "hinge":
+        return max(0, lyap_value) ** 2
+
+    if reg_type == "asymmetric":
+        chaos_penalty = max(0, lyap_value) ** 2
+        collapse_penalty = 0.1 * max(0, -lyap_value - 1.0) ** 2
+        return chaos_penalty + collapse_penalty
+
+    if reg_type == "extreme":
+        return max(0, lyap_value - lyap_threshold) ** 2
+
+    # "squared" and every unrecognized reg_type share the same formula.
+    deviation = lyap_value - lambda_target
+    return deviation ** 2
+
+
+def get_effective_lambda(epoch: int, lambda_reg: float, warmup_epochs: int) -> float:
+    """Return the regularization weight in effect at ``epoch``.
+
+    During warmup (``epoch <= warmup_epochs`` with ``warmup_epochs > 0``) the
+    weight grows linearly as ``lambda_reg * epoch / warmup_epochs``; otherwise
+    the full ``lambda_reg`` applies.
+    """
+    in_warmup = warmup_epochs > 0 and epoch <= warmup_epochs
+    if not in_warmup:
+        return lambda_reg
+
+    ramp = epoch / warmup_epochs
+    return lambda_reg * ramp
+
+
+def load_experiment(exp_dir: Path) -> Optional[ExperimentConfig]:
+    """Load the settings of the experiment stored in ``exp_dir``.
+
+    The directory must contain both ``config.json`` and ``results.json``;
+    only ``config.json`` is read here. The display name is assembled from the
+    regularization settings, e.g.
+    ``"squared, λ_reg=0.1, λ_target=0.0, warmup=5, stable_init"`` (the warmup
+    and stable_init parts appear only when they are set).
+
+    Args:
+        exp_dir: Directory of a single experiment run.
+
+    Returns:
+        The parsed ``ExperimentConfig``, or ``None`` if either file is missing
+        or ``config.json`` lacks a required key (a message is printed in the
+        latter case so the run is skipped visibly rather than aborting the
+        whole comparison).
+    """
+    config_path = exp_dir / "config.json"
+    results_path = exp_dir / "results.json"
+
+    if not config_path.exists() or not results_path.exists():
+        return None
+
+    # Read as UTF-8 explicitly instead of depending on the platform locale.
+    with open(config_path, encoding="utf-8") as f:
+        config = json.load(f)
+
+    required_keys = ("dataset", "lambda_reg", "lambda_target", "epochs", "depths")
+    missing = [key for key in required_keys if key not in config]
+    if missing:
+        print(f" Skipping {exp_dir}: config.json is missing {', '.join(missing)}")
+        return None
+
+    reg_type = config.get("reg_type", "squared")
+    warmup = config.get("warmup_epochs", 0)
+    stable_init = config.get("stable_init", False)
+    lambda_reg = config["lambda_reg"]
+    lambda_target = config["lambda_target"]
+
+    # Build display name with settings
+    parts = [
+        f"{reg_type}",
+        f"λ_reg={lambda_reg}",
+        f"λ_target={lambda_target}",
+    ]
+    if warmup > 0:
+        parts.append(f"warmup={warmup}")
+    if stable_init:
+        parts.append("stable_init")
+
+    return ExperimentConfig(
+        name=", ".join(parts),
+        path=exp_dir,
+        dataset=config["dataset"],
+        reg_type=reg_type,
+        lambda_reg=lambda_reg,
+        lambda_target=lambda_target,
+        warmup_epochs=warmup,
+        stable_init=stable_init,
+        epochs=config["epochs"],
+        depths=config["depths"],
+    )
+
+
+def _to_float_array(values) -> np.ndarray:
+    """Convert a list of numbers to a float array, mapping ``None`` to NaN."""
+    return np.array([np.nan if v is None else v for v in values], dtype=float)
+
+
+def load_results(exp: ExperimentConfig, depth: int) -> Optional[Dict]:
+    """Load the per-epoch curves of one depth from an experiment's results.
+
+    Reads ``results.json`` in ``exp.path`` and walks the entries stored under
+    ``data["lyapunov"][str(depth)]``. The regularization loss is not read from
+    the file; it is recomputed from the logged Lyapunov value and the
+    experiment settings, and the task loss is estimated as
+    ``train_loss - regularization loss``.
+
+    Reading stops at the first entry whose ``train_loss`` is ``None`` or NaN;
+    later entries are ignored.
+
+    Args:
+        exp: Experiment whose results are read.
+        depth: Network depth to select.
+
+    Returns:
+        A dict of equally long NumPy arrays with keys ``"epochs"``,
+        ``"task_loss"``, ``"lyap_reg_loss"``, ``"lyap_values"`` and
+        ``"grad_norms"``. Missing (``None``) Lyapunov values and gradient
+        norms are stored as NaN so they show up as gaps when plotted.
+        ``None`` is returned if the file has no data for ``depth`` or no
+        usable epoch.
+    """
+    results_path = exp.path / "results.json"
+    # Read as UTF-8 explicitly instead of depending on the platform locale.
+    with open(results_path, encoding="utf-8") as f:
+        data = json.load(f)
+
+    epochs_data = data.get("lyapunov", {}).get(str(depth))
+    if epochs_data is None:
+        return None
+
+    epochs = []
+    lyap_values = []
+    grad_norms = []
+    lyap_reg_losses = []
+    task_losses = []
+
+    for entry in epochs_data:
+        train_loss = entry["train_loss"]
+        if train_loss is None or np.isnan(train_loss):
+            break
+
+        epoch = entry["epoch"]
+        # Both values may be absent or null for an epoch; treat them alike.
+        lyap = entry.get("lyapunov")
+        grad_norm = entry.get("grad_norm")
+
+        # Compute effective lambda and reg loss
+        if lyap is None:
+            lyap_loss = 0.0
+        else:
+            eff_lambda = get_effective_lambda(epoch, exp.lambda_reg, exp.warmup_epochs)
+            lyap_loss = eff_lambda * compute_reg_loss(lyap, exp.reg_type, exp.lambda_target)
+
+        epochs.append(epoch)
+        lyap_values.append(lyap)
+        grad_norms.append(grad_norm)
+        lyap_reg_losses.append(lyap_loss)
+        # Estimate task loss = total - reg
+        task_losses.append(train_loss - lyap_loss)
+
+    if not epochs:
+        return None
+
+    return {
+        "epochs": np.array(epochs),
+        "task_loss": np.array(task_losses),
+        "lyap_reg_loss": np.array(lyap_reg_losses),
+        "lyap_values": _to_float_array(lyap_values),
+        "grad_norms": _to_float_array(grad_norms),
+    }
+
+
+def find_experiments() -> List[ExperimentConfig]:
+    """Collect every depth-scaling experiment found under ``RUNS_DIR``.
+
+    Looks at the ``cifar*`` sub-directories of the baseline ``depth_scaling``
+    folder first, then at those of each variant folder that exists, in the
+    order listed below. Directories that ``load_experiment`` cannot load are
+    skipped. Each returned experiment has its dataset prepended to its
+    display name.
+    """
+    variant_names = (
+        "depth_scaling_asymm",
+        "depth_scaling_extreme",
+        "depth_scaling_hinge",
+        "depth_scaling_stable_init",
+        "depth_scaling_target1",
+        "depth_scaling_weak_reg",
+    )
+
+    # depth_scaling (baseline squared reg) first, then the variants
+    run_dirs = sorted(RUNS_DIR.glob("depth_scaling/cifar*"))
+    for variant_name in variant_names:
+        variant_root = RUNS_DIR / variant_name
+        if variant_root.exists():
+            run_dirs.extend(sorted(variant_root.glob("cifar*")))
+
+    found = []
+    for run_dir in run_dirs:
+        exp = load_experiment(run_dir)
+        if not exp:
+            continue
+        exp.name = f"{exp.dataset}: {exp.name}"
+        found.append(exp)
+
+    return found
+
+
+def plot_comparison(depth: int = 12, save_path: Optional[str] = None):
+    """Create 4-panel comparison plot for all experiments at a given depth.
+
+    The panels show the estimated task loss, the Lyapunov regularization
+    loss, the Lyapunov exponent and the gradient norm, each against the
+    epoch, with one line per experiment and a shared legend below the panels.
+
+    Args:
+        depth: Network depth whose curves are plotted.
+        save_path: Output image path; defaults to
+            ``RUNS_DIR / "experiment_comparison.png"``.
+
+    Nothing is written when no experiments are found or when none of them
+    has data for ``depth``; a message is printed instead.
+    """
+    experiments = find_experiments()
+
+    if not experiments:
+        print("No experiments found!")
+        return
+
+    # Load everything up front so that no figure is created (or saved) when
+    # none of the experiments has data for the requested depth.
+    runs = []
+    for exp in experiments:
+        results = load_results(exp, depth)
+        if results is None:
+            print(f" Skipping {exp.name}: no data for depth={depth}")
+            continue
+        runs.append((exp.name, results))
+
+    if not runs:
+        print(f"No experiments have data for depth={depth}!")
+        return
+
+    if save_path is None:
+        save_path = str(RUNS_DIR / "experiment_comparison.png")
+
+    # tab10 is a qualitative palette with a fixed number of colours; index it
+    # by integer and switch the line style whenever the palette wraps around,
+    # so that more runs than colours stay distinguishable.
+    palette = plt.cm.tab10
+    line_styles = ("-", "--", ":", "-.")
+
+    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
+    try:
+        fig.suptitle(f"Experiment Comparison (depth={depth})", fontsize=14, fontweight="bold")
+
+        # (axis, key in the results dict, panel title, y-axis label)
+        panels = [
+            (axes[0, 0], "task_loss", "Task Loss (CE)", "Loss"),
+            (axes[0, 1], "lyap_reg_loss", "Lyapunov Regularization Loss", "Loss"),
+            (axes[1, 0], "lyap_values", r"$\lambda(t)$ (Lyapunov Exponent)", r"$\lambda$"),
+            (axes[1, 1], "grad_norms", "Gradient Norm", r"$\|\nabla\|_2$"),
+        ]
+
+        for i, (label, results) in enumerate(runs):
+            color = palette(i % palette.N)
+            line_style = line_styles[(i // palette.N) % len(line_styles)]
+            for ax, key, _title, _ylabel in panels:
+                ax.plot(results["epochs"], results[key],
+                        color=color, linestyle=line_style, label=label,
+                        linewidth=1.5, alpha=0.85)
+
+        # Configure axes
+        for ax, _key, title, ylabel in panels:
+            ax.set_title(title)
+            ax.set_xlabel("Epoch")
+            ax.set_ylabel(ylabel)
+            ax.grid(True, alpha=0.3)
+
+        # Reference line at lambda = 0 in the Lyapunov exponent panel.
+        axes[1, 0].axhline(y=0, color="black", linestyle="--", alpha=0.5, linewidth=0.8)
+
+        # Add legend below the plots; every panel carries the same labels, so
+        # the handles of the first panel are enough.
+        handles, labels = axes[0, 0].get_legend_handles_labels()
+        if handles:
+            fig.legend(handles, labels, loc="lower center", ncol=2,
+                       fontsize=8.5, bbox_to_anchor=(0.5, -0.02),
+                       frameon=True, fancybox=True)
+
+        fig.tight_layout(rect=[0, 0.06, 1, 0.96])
+
+        fig.savefig(save_path, dpi=150, bbox_inches="tight")
+        print(f"Saved plot to {save_path}")
+    finally:
+        # Close this specific figure, also when plotting or saving fails.
+        plt.close(fig)
+
+
+# Script entry point: plot the comparison for depth 12 and save it to the
+# default location (RUNS_DIR / "experiment_comparison.png").
+if __name__ == "__main__":
+    plot_comparison(depth=12)
diff --git a/runs/experiment_comparison.png b/runs/experiment_comparison.png
new file mode 100644
index 0000000..973ff57
--- /dev/null
+++ b/runs/experiment_comparison.png
Binary files differ