""" Plot comparison of all depth-scaling experiments showing: - Task loss (CE only, estimated by subtracting reg loss from total) - Lyapunov regularization loss - λ(t) (Lyapunov exponent over time) - Gradient norm Each experiment variant is shown as a separate line, with experiment settings in the legend. """ import json import numpy as np import matplotlib.pyplot as plt from pathlib import Path from dataclasses import dataclass from typing import List, Optional, Dict RUNS_DIR = Path(__file__).parent.parent.parent / "runs" @dataclass class ExperimentConfig: """Configuration for an experiment.""" name: str # display name path: Path # path to results dir dataset: str reg_type: str lambda_reg: float lambda_target: float warmup_epochs: int stable_init: bool epochs: int depths: List[int] def compute_reg_loss(lyap_value: float, reg_type: str, lambda_target: float, lyap_threshold: float = 2.0) -> float: """Compute the regularization loss value (before multiplying by lambda_reg).""" if lyap_value is None: return 0.0 if reg_type == "squared": return (lyap_value - lambda_target) ** 2 elif reg_type == "hinge": excess = max(0, lyap_value) return excess ** 2 elif reg_type == "asymmetric": chaos = max(0, lyap_value) ** 2 collapse = 0.1 * max(0, -lyap_value - 1.0) ** 2 return chaos + collapse elif reg_type == "extreme": excess = max(0, lyap_value - lyap_threshold) return excess ** 2 else: return (lyap_value - lambda_target) ** 2 def get_effective_lambda(epoch: int, lambda_reg: float, warmup_epochs: int) -> float: """Get the effective lambda_reg at a given epoch (accounts for warmup).""" if warmup_epochs > 0 and epoch <= warmup_epochs: return lambda_reg * (epoch / warmup_epochs) return lambda_reg def load_experiment(exp_dir: Path) -> Optional[ExperimentConfig]: """Load experiment config from a directory.""" config_path = exp_dir / "config.json" results_path = exp_dir / "results.json" if not config_path.exists() or not results_path.exists(): return None with open(config_path) as f: config = json.load(f) reg_type = config.get("reg_type", "squared") warmup = config.get("warmup_epochs", 0) stable_init = config.get("stable_init", False) # Build display name with settings parts = [f"{reg_type}"] parts.append(f"λ_reg={config['lambda_reg']}") parts.append(f"λ_target={config['lambda_target']}") if warmup > 0: parts.append(f"warmup={warmup}") if stable_init: parts.append("stable_init") name = ", ".join(parts) return ExperimentConfig( name=name, path=exp_dir, dataset=config["dataset"], reg_type=reg_type, lambda_reg=config["lambda_reg"], lambda_target=config["lambda_target"], warmup_epochs=warmup, stable_init=stable_init, epochs=config["epochs"], depths=config["depths"], ) def load_results(exp: ExperimentConfig, depth: int) -> Optional[Dict]: """Load results for a specific depth from an experiment.""" results_path = exp.path / "results.json" with open(results_path) as f: data = json.load(f) depth_key = str(depth) if "lyapunov" not in data or depth_key not in data["lyapunov"]: return None epochs_data = data["lyapunov"][depth_key] epochs = [] train_losses = [] lyap_values = [] grad_norms = [] lyap_reg_losses = [] task_losses = [] for entry in epochs_data: epoch = entry["epoch"] train_loss = entry["train_loss"] lyap = entry["lyapunov"] grad_norm = entry["grad_norm"] if train_loss is None or np.isnan(train_loss): break epochs.append(epoch) train_losses.append(train_loss) lyap_values.append(lyap) grad_norms.append(grad_norm) # Compute effective lambda and reg loss eff_lambda = get_effective_lambda(epoch, exp.lambda_reg, 
exp.warmup_epochs) reg_value = compute_reg_loss(lyap, exp.reg_type, exp.lambda_target) lyap_loss = eff_lambda * reg_value lyap_reg_losses.append(lyap_loss) # Estimate task loss = total - reg task_loss = train_loss - lyap_loss if lyap is not None else train_loss task_losses.append(task_loss) if not epochs: return None return { "epochs": np.array(epochs), "task_loss": np.array(task_losses), "lyap_reg_loss": np.array(lyap_reg_losses), "lyap_values": np.array([v if v is not None else np.nan for v in lyap_values]), "grad_norms": np.array(grad_norms), } def find_experiments() -> List[ExperimentConfig]: """Find all depth-scaling experiments.""" experiments = [] # depth_scaling (baseline squared reg) for subdir in sorted(RUNS_DIR.glob("depth_scaling/cifar*")): exp = load_experiment(subdir) if exp: exp.name = f"{exp.dataset}: {exp.name}" experiments.append(exp) # depth_scaling variants variant_dirs = [ "depth_scaling_asymm", "depth_scaling_extreme", "depth_scaling_hinge", "depth_scaling_stable_init", "depth_scaling_target1", "depth_scaling_weak_reg", ] for variant in variant_dirs: variant_path = RUNS_DIR / variant if variant_path.exists(): for subdir in sorted(variant_path.glob("cifar*")): exp = load_experiment(subdir) if exp: exp.name = f"{exp.dataset}: {exp.name}" experiments.append(exp) return experiments def plot_comparison(depth: int = 12, save_path: Optional[str] = None): """Create 4-panel comparison plot for all experiments at a given depth.""" experiments = find_experiments() if not experiments: print("No experiments found!") return fig, axes = plt.subplots(2, 2, figsize=(16, 12)) fig.suptitle(f"Experiment Comparison (depth={depth})", fontsize=14, fontweight="bold") colors = plt.cm.tab10(np.linspace(0, 1, len(experiments))) plotted = [] for i, exp in enumerate(experiments): results = load_results(exp, depth) if results is None: print(f" Skipping {exp.name}: no data for depth={depth}") continue color = colors[i] label = exp.name # Task Loss axes[0, 0].plot(results["epochs"], results["task_loss"], color=color, label=label, linewidth=1.5, alpha=0.85) # Lyapunov Reg Loss axes[0, 1].plot(results["epochs"], results["lyap_reg_loss"], color=color, label=label, linewidth=1.5, alpha=0.85) # λ(t) axes[1, 0].plot(results["epochs"], results["lyap_values"], color=color, label=label, linewidth=1.5, alpha=0.85) # Gradient Norm axes[1, 1].plot(results["epochs"], results["grad_norms"], color=color, label=label, linewidth=1.5, alpha=0.85) plotted.append(label) # Configure axes axes[0, 0].set_title("Task Loss (CE)") axes[0, 0].set_xlabel("Epoch") axes[0, 0].set_ylabel("Loss") axes[0, 0].grid(True, alpha=0.3) axes[0, 1].set_title("Lyapunov Regularization Loss") axes[0, 1].set_xlabel("Epoch") axes[0, 1].set_ylabel("Loss") axes[0, 1].grid(True, alpha=0.3) axes[1, 0].set_title(r"$\lambda(t)$ (Lyapunov Exponent)") axes[1, 0].set_xlabel("Epoch") axes[1, 0].set_ylabel(r"$\lambda$") axes[1, 0].axhline(y=0, color="black", linestyle="--", alpha=0.5, linewidth=0.8) axes[1, 0].grid(True, alpha=0.3) axes[1, 1].set_title("Gradient Norm") axes[1, 1].set_xlabel("Epoch") axes[1, 1].set_ylabel(r"$\|\nabla\|_2$") axes[1, 1].grid(True, alpha=0.3) # Add legend below the plots handles, labels = axes[0, 0].get_legend_handles_labels() if handles: fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=8.5, bbox_to_anchor=(0.5, -0.02), frameon=True, fancybox=True) plt.tight_layout(rect=[0, 0.06, 1, 0.96]) if save_path is None: save_path = str(RUNS_DIR / "experiment_comparison.png") plt.savefig(save_path, dpi=150, 
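

# A small optional helper, not part of the original script: render the same
# 4-panel comparison once per depth. A minimal sketch, assuming the depths
# listed here exist in each experiment's results.json (4 and 8 are
# hypothetical; 12 matches the default below). Experiments without data for
# a depth are skipped inside plot_comparison via load_results.
def plot_all_depths(depths=(4, 8, 12)):
    """Render one comparison figure per depth (illustrative only)."""
    for d in depths:
        out = RUNS_DIR / f"experiment_comparison_depth{d}.png"
        plot_comparison(depth=d, save_path=str(out))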
bbox_inches="tight") print(f"Saved plot to {save_path}") plt.close() if __name__ == "__main__": plot_comparison(depth=12)