summaryrefslogtreecommitdiff
path: root/files/experiments/plot_depth_comparison.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
commitcd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree59a233959932ca0e4f12f196275e07fcf443b33f /files/experiments/plot_depth_comparison.py
init commit
Diffstat (limited to 'files/experiments/plot_depth_comparison.py')
-rw-r--r--files/experiments/plot_depth_comparison.py305
1 files changed, 305 insertions, 0 deletions
diff --git a/files/experiments/plot_depth_comparison.py b/files/experiments/plot_depth_comparison.py
new file mode 100644
index 0000000..2222b7b
--- /dev/null
+++ b/files/experiments/plot_depth_comparison.py
@@ -0,0 +1,305 @@
+"""
+Visualization for depth comparison experiments.
+
+Usage:
+ python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP
+"""
+
+import os
+import sys
+import json
+import argparse
+from typing import Dict, List
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+
+
def load_results(results_dir: str) -> Dict:
    """Read and deserialize ``results.json`` from *results_dir*.

    Raises FileNotFoundError if the file is missing (results are mandatory,
    unlike the optional config).
    """
    results_path = os.path.join(results_dir, "results.json")
    with open(results_path, "r") as fh:
        return json.load(fh)
+
+
def load_config(results_dir: str) -> Dict:
    """Return the parsed ``config.json`` from *results_dir*.

    The config is optional; an empty dict is returned when the file
    does not exist.
    """
    config_path = os.path.join(results_dir, "config.json")
    if not os.path.exists(config_path):
        return {}
    with open(config_path, "r") as fh:
        return json.load(fh)
+
+
def plot_training_curves(results: Dict, output_path: str):
    """
    Plot per-depth training curves for both methods.

    One row of four panels per depth, overlaying the vanilla and Lyapunov
    runs:
    - training loss (log y)
    - validation accuracy
    - Lyapunov exponent (only for epochs where it was recorded)
    - gradient norm (log y)
    """
    depths = sorted(int(d) for d in results["vanilla"])
    n_rows = len(depths)

    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 3 * n_rows))
    if n_rows == 1:
        # plt.subplots returns a 1-D axes array for a single row;
        # normalize so axes[row, col] indexing works below.
        axes = axes.reshape(1, -1)

    colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}
    labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"}

    for row, depth in enumerate(depths):
        for method in ("vanilla", "lyapunov"):
            history = results[method][str(depth)]
            epochs = [rec["epoch"] for rec in history]
            color = colors[method]
            label = labels[method]

            # Column 0: training loss on a log axis.
            ax = axes[row, 0]
            ax.plot(epochs, [rec["train_loss"] for rec in history],
                    color=color, label=label, linewidth=2)
            ax.set_ylabel("Train Loss")
            ax.set_title(f"Depth={depth}: Training Loss")
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)

            # Column 1: validation accuracy, fixed [0, 1] range.
            ax = axes[row, 1]
            ax.plot(epochs, [rec["val_acc"] for rec in history],
                    color=color, label=label, linewidth=2)
            ax.set_ylabel("Val Accuracy")
            ax.set_title(f"Depth={depth}: Validation Accuracy")
            ax.set_ylim(0, 1)
            ax.grid(True, alpha=0.3)

            # Column 2: Lyapunov exponent. Entries may be None on epochs
            # where the metric was not measured, so filter before plotting.
            lam_pairs = [(rec["epoch"], rec["lyapunov"])
                         for rec in history if rec["lyapunov"] is not None]
            ax = axes[row, 2]
            if lam_pairs:
                lam_epochs, lam_vals = zip(*lam_pairs)
                ax.plot(lam_epochs, lam_vals, color=color,
                        label=label, linewidth=2)
            ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
            ax.set_ylabel("Lyapunov λ")
            ax.set_title(f"Depth={depth}: Lyapunov Exponent")
            ax.grid(True, alpha=0.3)

            # Column 3: gradient norm on a log axis.
            ax = axes[row, 3]
            ax.plot(epochs, [rec["grad_norm"] for rec in history],
                    color=color, label=label, linewidth=2)
            ax.set_ylabel("Gradient Norm")
            ax.set_title(f"Depth={depth}: Gradient Norm")
            ax.set_yscale("log")
            ax.grid(True, alpha=0.3)

        # Legends only on the top row to avoid repeating them per depth.
        if row == 0:
            for ax in axes[row]:
                ax.legend(loc="upper right")

    # X-axis labels only on the bottom row.
    for ax in axes[-1]:
        ax.set_xlabel("Epoch")

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved training curves to {output_path}")
+
+
def plot_depth_summary(results: Dict, output_path: str):
    """
    Summarize final metrics as a function of network depth.

    Three panels:
    - final validation accuracy vs depth
    - final gradient norm vs depth (log y)
    - the Lyapunov method's final exponent vs depth, with the λ=0 target
      line and a shaded ±0.5 "stable" band
    """
    depths = sorted(int(d) for d in results["vanilla"])

    fig, (ax_acc, ax_grad, ax_lam) = plt.subplots(1, 3, figsize=(14, 4))

    colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}

    def final_entry(method, depth):
        # Last logged record of a run = final metrics.
        return results[method][str(depth)][-1]

    def nan_to_zero(value):
        # Diverged runs may log NaN; plot those as 0 so the line still draws.
        return value if not np.isnan(value) else 0

    van_acc = [nan_to_zero(final_entry("vanilla", d)["val_acc"]) for d in depths]
    lyap_acc = [nan_to_zero(final_entry("lyapunov", d)["val_acc"]) for d in depths]
    van_grad = [nan_to_zero(final_entry("vanilla", d)["grad_norm"]) for d in depths]
    lyap_grad = [nan_to_zero(final_entry("lyapunov", d)["grad_norm"]) for d in depths]
    # Lyapunov exponent may be absent (None) rather than NaN.
    lyap_lambda = []
    for d in depths:
        lam = final_entry("lyapunov", d)["lyapunov"]
        lyap_lambda.append(lam if lam is not None else 0)

    # Panel 1: final validation accuracy vs depth.
    ax_acc.plot(depths, van_acc, 'o-', color=colors["vanilla"],
                label="Vanilla", linewidth=2, markersize=8)
    ax_acc.plot(depths, lyap_acc, 's-', color=colors["lyapunov"],
                label="Lyapunov", linewidth=2, markersize=8)
    ax_acc.set_xlabel("Network Depth (# layers)")
    ax_acc.set_ylabel("Final Validation Accuracy")
    ax_acc.set_title("Accuracy vs Depth")
    ax_acc.legend()
    ax_acc.grid(True, alpha=0.3)
    # Headroom above the best curve; +0.05 keeps a visible band even at 0.
    ax_acc.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05)

    # Panel 2: final gradient norm vs depth (log scale).
    ax_grad.plot(depths, van_grad, 'o-', color=colors["vanilla"],
                 label="Vanilla", linewidth=2, markersize=8)
    ax_grad.plot(depths, lyap_grad, 's-', color=colors["lyapunov"],
                 label="Lyapunov", linewidth=2, markersize=8)
    ax_grad.set_xlabel("Network Depth (# layers)")
    ax_grad.set_ylabel("Final Gradient Norm")
    ax_grad.set_title("Gradient Stability vs Depth")
    ax_grad.legend()
    ax_grad.grid(True, alpha=0.3)
    ax_grad.set_yscale("log")

    # Panel 3: final Lyapunov exponent vs depth (Lyapunov method only).
    ax_lam.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"],
                linewidth=2, markersize=8)
    ax_lam.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)")
    ax_lam.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region")
    ax_lam.set_xlabel("Network Depth (# layers)")
    ax_lam.set_ylabel("Final Lyapunov Exponent")
    ax_lam.set_title("Lyapunov Exponent vs Depth")
    ax_lam.legend()
    ax_lam.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved depth summary to {output_path}")
+
+
def plot_stability_comparison(results: Dict, output_path: str):
    """
    Plot stability metrics comparison in a 2x2 grid.

    - (0,0) firing rate over training, one line per depth and method
      (deeper networks drawn more opaque)
    - (0,1) dead-neuron fraction over training, same encoding
    - (1,0) final validation accuracy per depth, grouped bars
    - (1,1) accuracy improvement of Lyapunov over vanilla, bars
      (green = improvement, red = regression)

    Fix over the original: the line alpha is clamped to 1.0. The previous
    ``0.3 + 0.1 * depths.index(depth)`` exceeds matplotlib's valid [0, 1]
    alpha range once more than eight depths are plotted, raising
    ValueError; ``enumerate`` also replaces the O(n) ``list.index`` lookup.
    """
    depths = sorted(int(d) for d in results["vanilla"])

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"}

    # Per-epoch curves: later (deeper) depths are drawn more opaque.
    for idx, depth in enumerate(depths):
        van_metrics = results["vanilla"][str(depth)]
        lyap_metrics = results["lyapunov"][str(depth)]

        van_epochs = [m["epoch"] for m in van_metrics]
        lyap_epochs = [m["epoch"] for m in lyap_metrics]

        # Clamp: matplotlib rejects alpha > 1, which 0.3 + 0.1*idx would
        # produce from the 8th depth onward.
        alpha = min(1.0, 0.3 + 0.1 * idx)

        # Firing rate
        van_fr = [m["firing_rate"] for m in van_metrics]
        lyap_fr = [m["firing_rate"] for m in lyap_metrics]
        axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"], alpha=alpha)
        axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"], alpha=alpha)

        # Dead neurons
        van_dead = [m["dead_neurons"] for m in van_metrics]
        lyap_dead = [m["dead_neurons"] for m in lyap_metrics]
        axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"], alpha=alpha)
        axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"], alpha=alpha)

    axes[0, 0].set_xlabel("Epoch")
    axes[0, 0].set_ylabel("Firing Rate")
    axes[0, 0].set_title("Firing Rate Over Training")
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].set_xlabel("Epoch")
    axes[0, 1].set_ylabel("Dead Neuron Fraction")
    axes[0, 1].set_title("Dead Neurons Over Training")
    axes[0, 1].grid(True, alpha=0.3)

    # Final metrics bar chart
    van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths]
    lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths]

    x = np.arange(len(depths))
    width = 0.35

    axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"])
    axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"])
    axes[1, 0].set_xlabel("Network Depth")
    axes[1, 0].set_ylabel("Final Validation Accuracy")
    axes[1, 0].set_title("Final Accuracy Comparison")
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(depths)
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3, axis='y')

    # Improvement percentage: positive bars green, negative red.
    improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)]
    colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements]

    axes[1, 1].bar(x, improvements, color=colors_bar)
    axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    axes[1, 1].set_xlabel("Network Depth")
    axes[1, 1].set_ylabel("Accuracy Improvement")
    axes[1, 1].set_title("Lyapunov Improvement over Vanilla")
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(depths)
    axes[1, 1].grid(True, alpha=0.3, axis='y')

    # Manual legend for the line plots (many lines share the two colors).
    custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2),
                    Line2D([0], [0], color=colors["lyapunov"], lw=2)]
    axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov'])
    axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov'])

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved stability comparison to {output_path}")
+
+
def main():
    """CLI entry point: load results/config and emit all comparison plots."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--results_dir", type=str, required=True,
                        help="Directory containing results.json")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory for plots (default: same as results_dir)")
    args = parser.parse_args()

    # Fall back to the results directory when no output dir is given.
    output_dir = args.output_dir if args.output_dir else args.results_dir

    print(f"Loading results from {args.results_dir}")
    results = load_results(args.results_dir)
    config = load_config(args.results_dir)

    print(f"Config: {config}")

    # Generate every plot from the same results payload.
    plot_jobs = [
        (plot_training_curves, "training_curves.png"),
        (plot_depth_summary, "depth_summary.png"),
        (plot_stability_comparison, "stability_comparison.png"),
    ]
    for plot_fn, filename in plot_jobs:
        plot_fn(results, os.path.join(output_dir, filename))

    print(f"\nAll plots saved to {output_dir}")


if __name__ == "__main__":
    main()