""" Visualization for depth comparison experiments. Usage: python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP """ import os import sys import json import argparse from typing import Dict, List import numpy as np import matplotlib.pyplot as plt from matplotlib.lines import Line2D def load_results(results_dir: str) -> Dict: """Load results from JSON file.""" with open(os.path.join(results_dir, "results.json"), "r") as f: return json.load(f) def load_config(results_dir: str) -> Dict: """Load config from JSON file.""" config_path = os.path.join(results_dir, "config.json") if os.path.exists(config_path): with open(config_path, "r") as f: return json.load(f) return {} def plot_training_curves(results: Dict, output_path: str): """ Plot training curves for each depth. Creates a figure with subplots for each depth showing: - Training loss - Validation accuracy - Lyapunov exponent (if available) - Gradient norm """ depths = sorted([int(d) for d in results["vanilla"].keys()]) n_depths = len(depths) fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths)) if n_depths == 1: axes = axes.reshape(1, -1) colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"} for i, depth in enumerate(depths): for method in ["vanilla", "lyapunov"]: metrics = results[method][str(depth)] epochs = [m["epoch"] for m in metrics] # Training Loss train_loss = [m["train_loss"] for m in metrics] axes[i, 0].plot(epochs, train_loss, color=colors[method], label=labels[method], linewidth=2) axes[i, 0].set_ylabel("Train Loss") axes[i, 0].set_title(f"Depth={depth}: Training Loss") axes[i, 0].set_yscale("log") axes[i, 0].grid(True, alpha=0.3) # Validation Accuracy val_acc = [m["val_acc"] for m in metrics] axes[i, 1].plot(epochs, val_acc, color=colors[method], label=labels[method], linewidth=2) axes[i, 1].set_ylabel("Val Accuracy") axes[i, 1].set_title(f"Depth={depth}: Validation Accuracy") axes[i, 1].set_ylim(0, 1) axes[i, 1].grid(True, alpha=0.3) # Lyapunov Exponent lyap = [m["lyapunov"] for m in metrics if m["lyapunov"] is not None] lyap_epochs = [m["epoch"] for m in metrics if m["lyapunov"] is not None] if lyap: axes[i, 2].plot(lyap_epochs, lyap, color=colors[method], label=labels[method], linewidth=2) axes[i, 2].axhline(y=0, color='gray', linestyle='--', alpha=0.5) axes[i, 2].set_ylabel("Lyapunov λ") axes[i, 2].set_title(f"Depth={depth}: Lyapunov Exponent") axes[i, 2].grid(True, alpha=0.3) # Gradient Norm grad_norm = [m["grad_norm"] for m in metrics] axes[i, 3].plot(epochs, grad_norm, color=colors[method], label=labels[method], linewidth=2) axes[i, 3].set_ylabel("Gradient Norm") axes[i, 3].set_title(f"Depth={depth}: Gradient Norm") axes[i, 3].set_yscale("log") axes[i, 3].grid(True, alpha=0.3) # Add legend to first row if i == 0: for ax in axes[i]: ax.legend(loc="upper right") # Set x-labels on bottom row for ax in axes[-1]: ax.set_xlabel("Epoch") plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close() print(f"Saved training curves to {output_path}") def plot_depth_summary(results: Dict, output_path: str): """ Plot summary comparing methods across depths. Creates a figure showing: - Final validation accuracy vs depth - Final gradient norm vs depth - Final Lyapunov exponent vs depth """ depths = sorted([int(d) for d in results["vanilla"].keys()]) fig, axes = plt.subplots(1, 3, figsize=(14, 4)) colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} markers = {"vanilla": "o", "lyapunov": "s"} # Collect final metrics van_acc = [] lyap_acc = [] van_grad = [] lyap_grad = [] lyap_lambda = [] for depth in depths: van_metrics = results["vanilla"][str(depth)][-1] lyap_metrics = results["lyapunov"][str(depth)][-1] van_acc.append(van_metrics["val_acc"] if not np.isnan(van_metrics["val_acc"]) else 0) lyap_acc.append(lyap_metrics["val_acc"] if not np.isnan(lyap_metrics["val_acc"]) else 0) van_grad.append(van_metrics["grad_norm"] if not np.isnan(van_metrics["grad_norm"]) else 0) lyap_grad.append(lyap_metrics["grad_norm"] if not np.isnan(lyap_metrics["grad_norm"]) else 0) if lyap_metrics["lyapunov"] is not None: lyap_lambda.append(lyap_metrics["lyapunov"]) else: lyap_lambda.append(0) # Plot 1: Validation Accuracy vs Depth ax = axes[0] ax.plot(depths, van_acc, 'o-', color=colors["vanilla"], label="Vanilla", linewidth=2, markersize=8) ax.plot(depths, lyap_acc, 's-', color=colors["lyapunov"], label="Lyapunov", linewidth=2, markersize=8) ax.set_xlabel("Network Depth (# layers)") ax.set_ylabel("Final Validation Accuracy") ax.set_title("Accuracy vs Depth") ax.legend() ax.grid(True, alpha=0.3) ax.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05) # Plot 2: Gradient Norm vs Depth ax = axes[1] ax.plot(depths, van_grad, 'o-', color=colors["vanilla"], label="Vanilla", linewidth=2, markersize=8) ax.plot(depths, lyap_grad, 's-', color=colors["lyapunov"], label="Lyapunov", linewidth=2, markersize=8) ax.set_xlabel("Network Depth (# layers)") ax.set_ylabel("Final Gradient Norm") ax.set_title("Gradient Stability vs Depth") ax.legend() ax.grid(True, alpha=0.3) ax.set_yscale("log") # Plot 3: Lyapunov Exponent vs Depth ax = axes[2] ax.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"], linewidth=2, markersize=8) ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)") ax.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region") ax.set_xlabel("Network Depth (# layers)") ax.set_ylabel("Final Lyapunov Exponent") ax.set_title("Lyapunov Exponent vs Depth") ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close() print(f"Saved depth summary to {output_path}") def plot_stability_comparison(results: Dict, output_path: str): """ Plot stability metrics comparison. """ depths = sorted([int(d) for d in results["vanilla"].keys()]) fig, axes = plt.subplots(2, 2, figsize=(12, 10)) colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} # Collect metrics over training for depth in depths: van_metrics = results["vanilla"][str(depth)] lyap_metrics = results["lyapunov"][str(depth)] van_epochs = [m["epoch"] for m in van_metrics] lyap_epochs = [m["epoch"] for m in lyap_metrics] # Firing rate van_fr = [m["firing_rate"] for m in van_metrics] lyap_fr = [m["firing_rate"] for m in lyap_metrics] axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"], alpha=0.3 + 0.1 * depths.index(depth)) axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"], alpha=0.3 + 0.1 * depths.index(depth)) # Dead neurons van_dead = [m["dead_neurons"] for m in van_metrics] lyap_dead = [m["dead_neurons"] for m in lyap_metrics] axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"], alpha=0.3 + 0.1 * depths.index(depth)) axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"], alpha=0.3 + 0.1 * depths.index(depth)) axes[0, 0].set_xlabel("Epoch") axes[0, 0].set_ylabel("Firing Rate") axes[0, 0].set_title("Firing Rate Over Training") axes[0, 0].grid(True, alpha=0.3) axes[0, 1].set_xlabel("Epoch") axes[0, 1].set_ylabel("Dead Neuron Fraction") axes[0, 1].set_title("Dead Neurons Over Training") axes[0, 1].grid(True, alpha=0.3) # Final metrics bar chart van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths] lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths] x = np.arange(len(depths)) width = 0.35 axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"]) axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"]) axes[1, 0].set_xlabel("Network Depth") axes[1, 0].set_ylabel("Final Validation Accuracy") axes[1, 0].set_title("Final Accuracy Comparison") axes[1, 0].set_xticks(x) axes[1, 0].set_xticklabels(depths) axes[1, 0].legend() axes[1, 0].grid(True, alpha=0.3, axis='y') # Improvement percentage improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)] colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements] axes[1, 1].bar(x, improvements, color=colors_bar) axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5) axes[1, 1].set_xlabel("Network Depth") axes[1, 1].set_ylabel("Accuracy Improvement") axes[1, 1].set_title("Lyapunov Improvement over Vanilla") axes[1, 1].set_xticks(x) axes[1, 1].set_xticklabels(depths) axes[1, 1].grid(True, alpha=0.3, axis='y') # Add legend for line plots custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2), Line2D([0], [0], color=colors["lyapunov"], lw=2)] axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov']) axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov']) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close() print(f"Saved stability comparison to {output_path}") def main(): parser = argparse.ArgumentParser() parser.add_argument("--results_dir", type=str, required=True, help="Directory containing results.json") parser.add_argument("--output_dir", type=str, default=None, help="Output directory for plots (default: same as results_dir)") args = parser.parse_args() output_dir = args.output_dir or args.results_dir print(f"Loading results from {args.results_dir}") results = load_results(args.results_dir) config = load_config(args.results_dir) print(f"Config: {config}") # Generate plots plot_training_curves(results, os.path.join(output_dir, "training_curves.png")) plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png")) plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png")) print(f"\nAll plots saved to {output_dir}") if __name__ == "__main__": main()