From 6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 23 Mar 2026 18:21:26 -0500 Subject: Initial implementation: all models, methods, toy and CIFAR experiments Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress. --- experiments/plot_toy_final.py | 183 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 experiments/plot_toy_final.py (limited to 'experiments/plot_toy_final.py') diff --git a/experiments/plot_toy_final.py b/experiments/plot_toy_final.py new file mode 100644 index 0000000..2f7c109 --- /dev/null +++ b/experiments/plot_toy_final.py @@ -0,0 +1,183 @@ +"""Generate final toy LQ experiment plots from v2 results across 3 seeds.""" +import os +import json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +output_dir = 'report' +os.makedirs(output_dir, exist_ok=True) + +# Load all v2 results with term_grad_weight=1.0, fm=0.0 +seeds = [42, 123, 456] +all_data = [] +for seed in seeds: + path = f'results/toy_lq/toy_lq_v2_seed{seed}_lam0.1_sig0.1_tgw1.0_fm0.0.json' + if os.path.exists(path): + with open(path) as f: + all_data.append(json.load(f)) + +if not all_data: + print("No results found!") + exit() + +# Also load v1 baseline (no term_grad) for comparison +v1_path = 'results/toy_lq/toy_lq_seed42.json' +v1_data = None +if os.path.exists(v1_path): + with open(v1_path) as f: + v1_data = json.load(f) + +# Aggregate final per-layer results across seeds +methods = ['dfa', 'state', 'credit'] +colors = {'dfa': '#2196F3', 'state': '#FF9800', 'credit': '#4CAF50'} +labels = {'dfa': 'DFA', 'state': 'State Bridge', 'credit': 'Credit Bridge'} + +# Per-layer costate cosine +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +for ax, metric, title, ylabel in zip( + axes, + ['costate_cos', 'rho', 'nudge'], + ['Exact Costate Cosine', 'Perturbation Correlation (ρ)', 'Nudging Test'], + ['Cosine Similarity', 'Pearson Correlation', 'Loss Change (negative=good)'] +): + for method in methods: + key = f'{method}_{metric}' + values_per_seed = [] + for data in all_data: + pl = data['final_per_layer'] + if key in pl: + values_per_seed.append(pl[key]) + + if values_per_seed: + arr = np.array(values_per_seed) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + layers = np.arange(len(mean)) + ax.plot(layers, mean, 'o-', color=colors[method], label=labels[method], markersize=5) + ax.fill_between(layers, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Layer', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + if metric == 'costate_cos': + ax.set_ylim(-0.15, 1.05) + elif metric == 'rho': + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + elif metric == 'nudge': + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + +fig.suptitle('Toy LQ Sanity Check: Per-Layer Diagnostics (3 seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_per_layer_diagnostics.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_per_layer_diagnostics.png") + +# Training curves +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) +metric_keys = [ + ('costate_cos', 'Avg Costate Cosine', 'Cosine Similarity'), + ('rho', 'Avg Perturbation ρ', 'Pearson Correlation'), + ('nudge', 'Avg Nudging', 'Loss Change'), +] + +for ax, (metric, title, ylabel) in zip(axes, metric_keys): + for method in methods: + key = f'{method}_{metric}' + all_curves = [] + for data in all_data: + log = data['log'] + full_key = f'{method}_costate_cos' if metric == 'costate_cos' else f'{method}_{metric}' + if full_key in log: + all_curves.append(np.array(log[full_key])) + + if all_curves: + # All should have same length, use shortest + min_len = min(len(c) for c in all_curves) + arr = np.array([c[:min_len] for c in all_curves]) + steps = np.array(all_data[0]['log']['steps'][:min_len]) + mean = arr.mean(axis=0) + std = arr.std(axis=0) + ax.plot(steps, mean, '-', color=colors[method], label=labels[method]) + ax.fill_between(steps, mean - std, mean + std, alpha=0.15, color=colors[method]) + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel(ylabel, fontsize=12) + ax.set_title(title, fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + +fig.suptitle('Toy LQ: Training Curves (3 seeds)', fontsize=14, y=1.02) +fig.tight_layout() +fig.savefig(os.path.join(output_dir, 'toy_training_curves.png'), dpi=150, bbox_inches='tight') +plt.close(fig) +print("Saved toy_training_curves.png") + +# Compare v1 (no term grad) vs v2 (with term grad) for credit bridge +if v1_data: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + + # v1 credit bridge (no term grad matching) + v1_log = v1_data['log'] + ax.plot(v1_log['steps'], v1_log['credit_costate_cos'], + '--', color='red', label='Credit Bridge (w/o terminal grad)', alpha=0.8) + + # v2 credit bridge (with term grad) + v2_log = all_data[0]['log'] # seed 42 + ax.plot(v2_log['steps'], v2_log['credit_costate_cos'], + '-', color='green', label='Credit Bridge (w/ terminal grad)') + + # State bridge for reference + ax.plot(v2_log['steps'], v2_log['state_costate_cos'], + '-', color='orange', label='State Bridge') + + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Avg Costate Cosine', fontsize=12) + ax.set_title('Effect of Terminal Gradient Matching', fontsize=13) + ax.legend(fontsize=11) + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.05) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_term_grad_effect.png'), dpi=150) + plt.close(fig) + print("Saved toy_term_grad_effect.png") + +# Bridge residual (from v1 which has it) +if v1_data and v1_data['log'].get('bridge_residual'): + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.plot(v1_data['log']['steps'], v1_data['log']['bridge_residual'], '-', color='green') + ax.set_xlabel('Training Step', fontsize=12) + ax.set_ylabel('Bridge Residual', fontsize=12) + ax.set_title('Credit Bridge: Bridge Residual Over Training', fontsize=13) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150) + plt.close(fig) + print("Saved toy_bridge_residual.png") + +# Print summary table +print("\n" + "="*80) +print("TOY LQ FINAL RESULTS (3 seeds, 8000 steps)") +print("="*80) + +for method in methods: + cos_vals = [] + rho_vals = [] + nudge_vals = [] + for data in all_data: + pl = data['final_per_layer'] + cos_vals.append(np.mean(pl[f'{method}_costate_cos'])) + rho_vals.append(np.mean(pl[f'{method}_rho'])) + nudge_vals.append(np.mean(pl[f'{method}_nudge'])) + + cos_mean, cos_std = np.mean(cos_vals), np.std(cos_vals) + rho_mean, rho_std = np.mean(rho_vals), np.std(rho_vals) + nudge_mean, nudge_std = np.mean(nudge_vals), np.std(nudge_vals) + + print(f"{labels[method]:<20} Cosine: {cos_mean:.4f}±{cos_std:.4f} " + f"ρ: {rho_mean:.4f}±{rho_std:.4f} " + f"Nudge: {nudge_mean:.4f}±{nudge_std:.4f}") -- cgit v1.2.3