From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 13 Jan 2026 23:49:05 -0600 Subject: init commit --- files/experiments/plot_depth_comparison.py | 305 +++++++++++++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 files/experiments/plot_depth_comparison.py (limited to 'files/experiments/plot_depth_comparison.py') diff --git a/files/experiments/plot_depth_comparison.py b/files/experiments/plot_depth_comparison.py new file mode 100644 index 0000000..2222b7b --- /dev/null +++ b/files/experiments/plot_depth_comparison.py @@ -0,0 +1,305 @@ +""" +Visualization for depth comparison experiments. + +Usage: + python files/experiments/plot_depth_comparison.py --results_dir runs/depth_comparison/TIMESTAMP +""" + +import os +import sys +import json +import argparse +from typing import Dict, List + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D + + +def load_results(results_dir: str) -> Dict: + """Load results from JSON file.""" + with open(os.path.join(results_dir, "results.json"), "r") as f: + return json.load(f) + + +def load_config(results_dir: str) -> Dict: + """Load config from JSON file.""" + config_path = os.path.join(results_dir, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + return json.load(f) + return {} + + +def plot_training_curves(results: Dict, output_path: str): + """ + Plot training curves for each depth. + + Creates a figure with subplots for each depth showing: + - Training loss + - Validation accuracy + - Lyapunov exponent (if available) + - Gradient norm + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + n_depths = len(depths) + + fig, axes = plt.subplots(n_depths, 4, figsize=(16, 3 * n_depths)) + if n_depths == 1: + axes = axes.reshape(1, -1) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + labels = {"vanilla": "Vanilla", "lyapunov": "Lyapunov"} + + for i, depth in enumerate(depths): + for method in ["vanilla", "lyapunov"]: + metrics = results[method][str(depth)] + epochs = [m["epoch"] for m in metrics] + + # Training Loss + train_loss = [m["train_loss"] for m in metrics] + axes[i, 0].plot(epochs, train_loss, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 0].set_ylabel("Train Loss") + axes[i, 0].set_title(f"Depth={depth}: Training Loss") + axes[i, 0].set_yscale("log") + axes[i, 0].grid(True, alpha=0.3) + + # Validation Accuracy + val_acc = [m["val_acc"] for m in metrics] + axes[i, 1].plot(epochs, val_acc, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 1].set_ylabel("Val Accuracy") + axes[i, 1].set_title(f"Depth={depth}: Validation Accuracy") + axes[i, 1].set_ylim(0, 1) + axes[i, 1].grid(True, alpha=0.3) + + # Lyapunov Exponent + lyap = [m["lyapunov"] for m in metrics if m["lyapunov"] is not None] + lyap_epochs = [m["epoch"] for m in metrics if m["lyapunov"] is not None] + if lyap: + axes[i, 2].plot(lyap_epochs, lyap, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 2].axhline(y=0, color='gray', linestyle='--', alpha=0.5) + axes[i, 2].set_ylabel("Lyapunov λ") + axes[i, 2].set_title(f"Depth={depth}: Lyapunov Exponent") + axes[i, 2].grid(True, alpha=0.3) + + # Gradient Norm + grad_norm = [m["grad_norm"] for m in metrics] + axes[i, 3].plot(epochs, grad_norm, color=colors[method], + label=labels[method], linewidth=2) + axes[i, 3].set_ylabel("Gradient Norm") + axes[i, 3].set_title(f"Depth={depth}: Gradient Norm") + axes[i, 3].set_yscale("log") + axes[i, 3].grid(True, alpha=0.3) + + # Add legend to first row + if i == 0: + for ax in axes[i]: + ax.legend(loc="upper right") + + # Set x-labels on bottom row + for ax in axes[-1]: + ax.set_xlabel("Epoch") + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved training curves to {output_path}") + + +def plot_depth_summary(results: Dict, output_path: str): + """ + Plot summary comparing methods across depths. + + Creates a figure showing: + - Final validation accuracy vs depth + - Final gradient norm vs depth + - Final Lyapunov exponent vs depth + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + + fig, axes = plt.subplots(1, 3, figsize=(14, 4)) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + markers = {"vanilla": "o", "lyapunov": "s"} + + # Collect final metrics + van_acc = [] + lyap_acc = [] + van_grad = [] + lyap_grad = [] + lyap_lambda = [] + + for depth in depths: + van_metrics = results["vanilla"][str(depth)][-1] + lyap_metrics = results["lyapunov"][str(depth)][-1] + + van_acc.append(van_metrics["val_acc"] if not np.isnan(van_metrics["val_acc"]) else 0) + lyap_acc.append(lyap_metrics["val_acc"] if not np.isnan(lyap_metrics["val_acc"]) else 0) + + van_grad.append(van_metrics["grad_norm"] if not np.isnan(van_metrics["grad_norm"]) else 0) + lyap_grad.append(lyap_metrics["grad_norm"] if not np.isnan(lyap_metrics["grad_norm"]) else 0) + + if lyap_metrics["lyapunov"] is not None: + lyap_lambda.append(lyap_metrics["lyapunov"]) + else: + lyap_lambda.append(0) + + # Plot 1: Validation Accuracy vs Depth + ax = axes[0] + ax.plot(depths, van_acc, 'o-', color=colors["vanilla"], + label="Vanilla", linewidth=2, markersize=8) + ax.plot(depths, lyap_acc, 's-', color=colors["lyapunov"], + label="Lyapunov", linewidth=2, markersize=8) + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Validation Accuracy") + ax.set_title("Accuracy vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_ylim(0, max(max(van_acc), max(lyap_acc)) * 1.1 + 0.05) + + # Plot 2: Gradient Norm vs Depth + ax = axes[1] + ax.plot(depths, van_grad, 'o-', color=colors["vanilla"], + label="Vanilla", linewidth=2, markersize=8) + ax.plot(depths, lyap_grad, 's-', color=colors["lyapunov"], + label="Lyapunov", linewidth=2, markersize=8) + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Gradient Norm") + ax.set_title("Gradient Stability vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_yscale("log") + + # Plot 3: Lyapunov Exponent vs Depth + ax = axes[2] + ax.plot(depths, lyap_lambda, 's-', color=colors["lyapunov"], + linewidth=2, markersize=8) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5, label="Target (λ=0)") + ax.fill_between(depths, -0.5, 0.5, alpha=0.2, color='green', label="Stable region") + ax.set_xlabel("Network Depth (# layers)") + ax.set_ylabel("Final Lyapunov Exponent") + ax.set_title("Lyapunov Exponent vs Depth") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved depth summary to {output_path}") + + +def plot_stability_comparison(results: Dict, output_path: str): + """ + Plot stability metrics comparison. + """ + depths = sorted([int(d) for d in results["vanilla"].keys()]) + + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + + colors = {"vanilla": "#E74C3C", "lyapunov": "#3498DB"} + + # Collect metrics over training + for depth in depths: + van_metrics = results["vanilla"][str(depth)] + lyap_metrics = results["lyapunov"][str(depth)] + + van_epochs = [m["epoch"] for m in van_metrics] + lyap_epochs = [m["epoch"] for m in lyap_metrics] + + # Firing rate + van_fr = [m["firing_rate"] for m in van_metrics] + lyap_fr = [m["firing_rate"] for m in lyap_metrics] + axes[0, 0].plot(van_epochs, van_fr, color=colors["vanilla"], + alpha=0.3 + 0.1 * depths.index(depth)) + axes[0, 0].plot(lyap_epochs, lyap_fr, color=colors["lyapunov"], + alpha=0.3 + 0.1 * depths.index(depth)) + + # Dead neurons + van_dead = [m["dead_neurons"] for m in van_metrics] + lyap_dead = [m["dead_neurons"] for m in lyap_metrics] + axes[0, 1].plot(van_epochs, van_dead, color=colors["vanilla"], + alpha=0.3 + 0.1 * depths.index(depth)) + axes[0, 1].plot(lyap_epochs, lyap_dead, color=colors["lyapunov"], + alpha=0.3 + 0.1 * depths.index(depth)) + + axes[0, 0].set_xlabel("Epoch") + axes[0, 0].set_ylabel("Firing Rate") + axes[0, 0].set_title("Firing Rate Over Training") + axes[0, 0].grid(True, alpha=0.3) + + axes[0, 1].set_xlabel("Epoch") + axes[0, 1].set_ylabel("Dead Neuron Fraction") + axes[0, 1].set_title("Dead Neurons Over Training") + axes[0, 1].grid(True, alpha=0.3) + + # Final metrics bar chart + van_final_acc = [results["vanilla"][str(d)][-1]["val_acc"] for d in depths] + lyap_final_acc = [results["lyapunov"][str(d)][-1]["val_acc"] for d in depths] + + x = np.arange(len(depths)) + width = 0.35 + + axes[1, 0].bar(x - width/2, van_final_acc, width, label='Vanilla', color=colors["vanilla"]) + axes[1, 0].bar(x + width/2, lyap_final_acc, width, label='Lyapunov', color=colors["lyapunov"]) + axes[1, 0].set_xlabel("Network Depth") + axes[1, 0].set_ylabel("Final Validation Accuracy") + axes[1, 0].set_title("Final Accuracy Comparison") + axes[1, 0].set_xticks(x) + axes[1, 0].set_xticklabels(depths) + axes[1, 0].legend() + axes[1, 0].grid(True, alpha=0.3, axis='y') + + # Improvement percentage + improvements = [(l - v) for v, l in zip(van_final_acc, lyap_final_acc)] + colors_bar = ['#27AE60' if imp > 0 else '#E74C3C' for imp in improvements] + + axes[1, 1].bar(x, improvements, color=colors_bar) + axes[1, 1].axhline(y=0, color='black', linestyle='-', linewidth=0.5) + axes[1, 1].set_xlabel("Network Depth") + axes[1, 1].set_ylabel("Accuracy Improvement") + axes[1, 1].set_title("Lyapunov Improvement over Vanilla") + axes[1, 1].set_xticks(x) + axes[1, 1].set_xticklabels(depths) + axes[1, 1].grid(True, alpha=0.3, axis='y') + + # Add legend for line plots + custom_lines = [Line2D([0], [0], color=colors["vanilla"], lw=2), + Line2D([0], [0], color=colors["lyapunov"], lw=2)] + axes[0, 0].legend(custom_lines, ['Vanilla', 'Lyapunov']) + axes[0, 1].legend(custom_lines, ['Vanilla', 'Lyapunov']) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved stability comparison to {output_path}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--results_dir", type=str, required=True, + help="Directory containing results.json") + parser.add_argument("--output_dir", type=str, default=None, + help="Output directory for plots (default: same as results_dir)") + args = parser.parse_args() + + output_dir = args.output_dir or args.results_dir + + print(f"Loading results from {args.results_dir}") + results = load_results(args.results_dir) + config = load_config(args.results_dir) + + print(f"Config: {config}") + + # Generate plots + plot_training_curves(results, os.path.join(output_dir, "training_curves.png")) + plot_depth_summary(results, os.path.join(output_dir, "depth_summary.png")) + plot_stability_comparison(results, os.path.join(output_dir, "stability_comparison.png")) + + print(f"\nAll plots saved to {output_dir}") + + +if __name__ == "__main__": + main() -- cgit v1.2.3