summaryrefslogtreecommitdiff
path: root/experiments/plot_toy_final.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/plot_toy_final.py')
-rw-r--r--experiments/plot_toy_final.py183
1 files changed, 183 insertions, 0 deletions
diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py
new file mode 100644
index 0000000..2f7c109
--- /dev/null
+++ b/experiments/plot_toy_final.py
@@ -0,0 +1,183 @@
+"""Generate final toy LQ experiment plots from v2 results across 3 seeds."""
+import os
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Directory that receives every generated figure.
output_dir = 'report'

# Seeds whose v2 result files (term_grad_weight=1.0, fm=0.0) are aggregated.
seeds = [42, 123, 456]
v2_path_template = 'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json'

# v1 baseline (no term_grad) used for comparison; it is a seed-42 run.
v1_seed = 42
v1_path = 'results/toy_lq/toy_lq_seed42.json'

methods = ['dfa', 'state', 'credit']
colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'}
labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'}

# One entry per diagnostic panel:
# (metric suffix, per-layer title, per-layer y-label, curve title, curve y-label)
panels = [
    ('costate_cos', 'Exact Costate Cosine', 'Cosine Similarity',
     'Avg Costate Cosine', 'Cosine Similarity'),
    ('rho', 'Perturbation Correlation (ρ)', 'Pearson Correlation',
     'Avg Perturbation ρ', 'Pearson Correlation'),
    ('nudge', 'Nudging Test', 'Loss Change (negative=good)',
     'Avg Nudging', 'Loss Change'),
]


# ---------------------------------------------------------------------------
# Loading and aggregation helpers (no plotting, importable on their own)
# ---------------------------------------------------------------------------

def load_json(path):
    """Return the parsed JSON document at *path*, or None if the file is absent."""
    if not os.path.exists(path):
        return None
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def load_v2_runs(seed_list=None):
    """Load the v2 result file of every seed that exists on disk.

    Returns a dict mapping seed -> parsed result, in the order of *seed_list*
    (defaults to the module-level ``seeds``). Missing files are reported and
    skipped instead of being dropped silently, so the caller always knows
    which seed each run belongs to.
    """
    runs_by_seed = {}
    for seed in (seeds if seed_list is None else seed_list):
        path = v2_path_template.format(seed=seed)
        data = load_json(path)
        if data is None:
            print(f"Warning: no results for seed {seed} ({path}); skipping")
        else:
            runs_by_seed[seed] = data
    return runs_by_seed


def seed_label(n_seeds):
    """Return the '<n> seed(s)' fragment used in titles and the summary header."""
    return f"{n_seeds} seed" if n_seeds == 1 else f"{n_seeds} seeds"


def aggregate_per_layer(runs, key):
    """Aggregate ``run['final_per_layer'][key]`` across runs.

    Returns ``(mean, std)`` arrays taken over the run axis, or None when no
    run contains *key*.
    """
    rows = [run['final_per_layer'][key] for run in runs
            if key in run['final_per_layer']]
    if not rows:
        return None
    arr = np.array(rows)
    return arr.mean(axis=0), arr.std(axis=0)


def aggregate_curves(runs, key):
    """Aggregate the logged curve ``run['log'][key]`` across runs.

    Curves are truncated to the shortest one before averaging. The step axis
    is taken from the first run that actually contributed a curve, so it is
    always consistent with the averaged data.

    Returns ``(steps, mean, std)``, or None when no run logged *key*.
    """
    logs = [run['log'] for run in runs if key in run['log']]
    if not logs:
        return None
    min_len = min(len(log[key]) for log in logs)
    arr = np.array([np.array(log[key])[:min_len] for log in logs])
    steps = np.array(logs[0]['steps'][:min_len])
    return steps, arr.mean(axis=0), arr.std(axis=0)


def summarize_method(runs, method):
    """Return ``{metric: (mean, std)}`` of the layer-averaged final values.

    For each metric the per-layer values of every run are averaged to one
    number per run; mean and std are then taken across runs. A metric that no
    run provides yields ``(nan, nan)`` instead of raising ``KeyError``.
    """
    summary = {}
    for metric, *_ in panels:
        key = f'{method}_{metric}'
        per_run = [np.mean(run['final_per_layer'][key]) for run in runs
                   if key in run['final_per_layer']]
        if per_run:
            summary[metric] = (float(np.mean(per_run)), float(np.std(per_run)))
        else:
            summary[metric] = (float('nan'), float('nan'))
    return summary


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def save_figure(fig, filename, **savefig_kwargs):
    """Lay out *fig*, write it to ``output_dir/filename`` at 150 dpi and close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), dpi=150, **savefig_kwargs)
    plt.close(fig)
    print(f"Saved {filename}")


def plot_mean_band(ax, x, mean, std, method, fmt, **plot_kwargs):
    """Draw the mean line of *method* plus a +/- one-std shaded band on *ax*."""
    ax.plot(x, mean, fmt, color=colors[method], label=labels[method], **plot_kwargs)
    ax.fill_between(x, mean - std, mean + std, alpha=0.15, color=colors[method])


def plot_per_layer_diagnostics(runs):
    """Plot the final per-layer value of each metric, averaged across runs."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (metric, title, ylabel, _, _) in zip(axes, panels):
        plotted = False
        for method in methods:
            stats = aggregate_per_layer(runs, f'{method}_{metric}')
            if stats is None:
                continue
            mean, std = stats
            plot_mean_band(ax, np.arange(len(mean)), mean, std, method, 'o-',
                           markersize=5)
            plotted = True

        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13)
        if plotted:
            # An empty legend only triggers a matplotlib warning.
            ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        if metric == 'costate_cos':
            ax.set_ylim(-0.15, 1.05)
        elif metric in ('rho', 'nudge'):
            # Zero reference line for the signed metrics.
            ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    fig.suptitle(
        f'Toy LQ Sanity Check: Per-Layer Diagnostics ({seed_label(len(runs))})',
        fontsize=14, y=1.02)
    save_figure(fig, 'toy_per_layer_diagnostics.png', bbox_inches='tight')


def plot_training_curves(runs):
    """Plot each logged metric against the training step, averaged across runs."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for ax, (metric, _, _, title, ylabel) in zip(axes, panels):
        plotted = False
        for method in methods:
            stats = aggregate_curves(runs, f'{method}_{metric}')
            if stats is None:
                continue
            steps, mean, std = stats
            plot_mean_band(ax, steps, mean, std, method, '-')
            plotted = True

        ax.set_xlabel('Training Step', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=13)
        if plotted:
            ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)

    fig.suptitle(f'Toy LQ: Training Curves ({seed_label(len(runs))})',
                 fontsize=14, y=1.02)
    save_figure(fig, 'toy_training_curves.png', bbox_inches='tight')


def plot_term_grad_effect(v1_data, runs_by_seed):
    """Compare the v1 credit bridge (no term grad) with a v2 run (with term grad).

    The v2 run of the same seed as the v1 baseline is used when available.
    Otherwise the first loaded run is used, and both the console output and
    the legend say which seed it is, so the figure is never mislabelled.
    *runs_by_seed* must be non-empty.
    """
    if v1_seed in runs_by_seed:
        ref_seed = v1_seed
        suffix = ''
    else:
        ref_seed = next(iter(runs_by_seed))
        suffix = f' [seed {ref_seed}]'
        print(f"Warning: no v2 run for seed {v1_seed}; "
              f"comparing the v1 baseline against v2 seed {ref_seed}")

    v1_log = v1_data['log']
    v2_log = runs_by_seed[ref_seed]['log']

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    # v1 credit bridge (no term grad matching)
    ax.plot(v1_log['steps'], v1_log['credit_costate_cos'],
            '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8)
    # v2 credit bridge (with term grad)
    ax.plot(v2_log['steps'], v2_log['credit_costate_cos'],
            '-', color='green', label='Credit Bridge (w/ terminal grad)' + suffix)
    # State bridge for reference
    ax.plot(v2_log['steps'], v2_log['state_costate_cos'],
            '-', color='orange', label='State Bridge' + suffix)

    ax.set_xlabel('Training Step', fontsize=12)
    ax.set_ylabel('Avg Costate Cosine', fontsize=12)
    ax.set_title('Effect of Terminal Gradient Matching', fontsize=13)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.1, 1.05)
    save_figure(fig, 'toy_term_grad_effect.png')


def plot_bridge_residual(v1_data):
    """Plot the bridge residual logged by the v1 run, if it logged one."""
    residual = v1_data['log'].get('bridge_residual')
    if not residual:
        return

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    ax.plot(v1_data['log']['steps'], residual, '-', color='green')
    ax.set_xlabel('Training Step', fontsize=12)
    ax.set_ylabel('Bridge Residual', fontsize=12)
    ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13)
    ax.grid(True, alpha=0.3)
    save_figure(fig, 'toy_bridge_residual.png')


def print_summary(runs):
    """Print one line per method with mean±std of the layer-averaged metrics."""
    print("\n" + "="*80)
    # NOTE(review): "8000 steps" is not derived from the result files —
    # confirm it matches the training configuration of the loaded runs.
    print(f"TOY LQ FINAL RESULTS ({seed_label(len(runs))}, 8000 steps)")
    print("="*80)

    for method in methods:
        summary = summarize_method(runs, method)
        cos_mean, cos_std = summary['costate_cos']
        rho_mean, rho_std = summary['rho']
        nudge_mean, nudge_std = summary['nudge']

        print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} "
              f"ρ: {rho_mean:.4f}±{rho_std:.4f} "
              f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}")


def main():
    """Generate every figure and the summary table; return a process exit code."""
    os.makedirs(output_dir, exist_ok=True)

    runs_by_seed = load_v2_runs()
    if not runs_by_seed:
        print("No results found!")
        return 1
    runs = list(runs_by_seed.values())

    plot_per_layer_diagnostics(runs)
    plot_training_curves(runs)

    # The v1 baseline is optional; its plots are skipped when it is missing.
    v1_data = load_json(v1_path)
    if v1_data:
        plot_term_grad_effect(v1_data, runs_by_seed)
        plot_bridge_residual(v1_data)

    print_summary(runs)
    return 0


if __name__ == '__main__':
    # SystemExit works without the interactive-only ``exit`` helper from ``site``.
    raise SystemExit(main())