summaryrefslogtreecommitdiff
path: root/experiments/plot_results.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 18:21:26 -0500
commit6ed4fa50ddfa4c7957aaa909aaf72f0d7d317712 (patch)
treed7c63adcd19c4f5d46c8a937e5047fece55dea62 /experiments/plot_results.py
Initial implementation: all models, methods, toy and CIFAR experiments
Debug phase. Toy LQ experiments (3 seeds) complete with terminal gradient matching. Credit bridge matches state bridge on linear system (~0.94 cosine). CIFAR experiments in progress.
Diffstat (limited to 'experiments/plot_results.py')
-rw-r--r--experiments/plot_results.py327
1 files changed, 327 insertions, 0 deletions
diff --git a/experiments/plot_results.py b/experiments/plot_results.py
new file mode 100644
index 0000000..e3e2754
--- /dev/null
+++ b/experiments/plot_results.py
@@ -0,0 +1,327 @@
+"""Generate plots for toy LQ and CIFAR-10 experiments."""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+
+
+def plot_toy_results(results_dir='results/toy_lq', output_dir='report'):
+ """Plot toy LQ experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Collect results across seeds
+ files = [f for f in os.listdir(results_dir) if f.startswith('toy_lq_seed') and f.endswith('.json')]
+ if not files:
+ print(f"No toy results found in {results_dir}")
+ return
+
+ all_data = []
+ for f in sorted(files):
+ with open(os.path.join(results_dir, f)) as fp:
+ all_data.append(json.load(fp))
+
+ # Use the last result for per-layer plots (or average if multiple seeds)
+ data = all_data[-1]
+ per_layer = data['final_per_layer']
+ log_data = data['log']
+
+ num_layers = len(per_layer['dfa_costate_cos'])
+ layers = list(range(num_layers))
+
+ # 1. Per-layer costate cosine
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(layers, per_layer['dfa_costate_cos'], 'o-', label='DFA', color='blue')
+ ax.plot(layers, per_layer['state_costate_cos'], 's-', label='State Bridge', color='orange')
+ ax.plot(layers, per_layer['credit_costate_cos'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine Similarity with Exact Costate')
+ ax.set_title('Exact Costate Cosine (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.2, 1.05)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_costate_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer perturbation correlation
+ num_rho_layers = len(per_layer['dfa_rho'])
+ rho_layers = list(range(num_rho_layers))
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_rho'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_rho'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_rho'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # 3. Per-layer nudging test
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(rho_layers, per_layer['dfa_nudge'], 'o-', label='DFA', color='blue')
+ ax.plot(rho_layers, per_layer['state_nudge'], 's-', label='State Bridge', color='orange')
+ ax.plot(rho_layers, per_layer['credit_nudge'], '^-', label='Credit Bridge', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test (Toy LQ)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # 4. Bridge residual over training
+ if log_data['bridge_residual']:
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Bridge Residual Over Training (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual.png'), dpi=150)
+ plt.close(fig)
+
+ # 5. Training curves (costate cosine over time)
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+ for ax, key, title in zip(axes,
+ ['dfa_costate_cos', 'state_costate_cos', 'credit_costate_cos'],
+ ['DFA', 'State Bridge', 'Credit Bridge']):
+ ax.plot(log_data['steps'], log_data[key], '-')
+ ax.set_xlabel('Training Step')
+ ax.set_ylabel('Avg Costate Cosine')
+ ax.set_title(f'{title} - Costate Cosine Over Training')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_cosine_training.png'), dpi=150)
+ plt.close(fig)
+
+ # 6. Per-layer bridge residual
+ if per_layer.get('bridge_residual'):
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ br_layers = list(range(len(per_layer['bridge_residual'])))
+ ax.plot(br_layers, per_layer['bridge_residual'], '^-', color='green')
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Bridge Residual')
+ ax.set_title('Per-Layer Bridge Residual (Toy LQ)')
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'toy_bridge_residual_per_layer.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"Toy LQ plots saved to {output_dir}/")
+
+
+def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'):
+ """Plot CIFAR-10 experiment results."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ if not os.path.exists(results_path):
+ print(f"No CIFAR results found at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ colors = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'}
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ # 1. Accuracy curves (mean ± std across seeds)
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+ for method in methods:
+ train_accs = []
+ test_accs = []
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ train_accs.append(log['train_acc'])
+ test_accs.append(log['test_acc'])
+
+ if train_accs:
+ train_arr = np.array(train_accs)
+ test_arr = np.array(test_accs)
+ epochs = np.arange(1, train_arr.shape[1] + 1)
+
+ mean_train = train_arr.mean(0)
+ std_train = train_arr.std(0)
+ mean_test = test_arr.mean(0)
+ std_test = test_arr.std(0)
+
+ axes[0].plot(epochs, mean_train, '-', color=colors[method], label=labels[method])
+ axes[0].fill_between(epochs, mean_train - std_train, mean_train + std_train,
+ alpha=0.15, color=colors[method])
+ axes[1].plot(epochs, mean_test, '-', color=colors[method], label=labels[method])
+ axes[1].fill_between(epochs, mean_test - std_test, mean_test + std_test,
+ alpha=0.15, color=colors[method])
+
+ axes[0].set_xlabel('Epoch')
+ axes[0].set_ylabel('Train Accuracy')
+ axes[0].set_title('Train Accuracy')
+ axes[0].legend()
+ axes[0].grid(True, alpha=0.3)
+ axes[1].set_xlabel('Epoch')
+ axes[1].set_ylabel('Test Accuracy')
+ axes[1].set_title('Test Accuracy')
+ axes[1].legend()
+ axes[1].grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_accuracy.png'), dpi=150)
+ plt.close(fig)
+
+ # 2. Per-layer diagnostics (from last seed)
+ last_seed = seeds[-1]
+
+ # BP cosine per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'bp_cosine' in diag:
+ layers = list(range(len(diag['bp_cosine'])))
+ ax.plot(layers, diag['bp_cosine'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Cosine with BP Gradient')
+ ax.set_title('Offline BP Cosine (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_bp_cosine.png'), dpi=150)
+ plt.close(fig)
+
+ # Perturbation rho per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ layers = list(range(len(diag['perturbation_rho'])))
+ ax.plot(layers, diag['perturbation_rho'], 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation Correlation (rho)')
+ ax.set_title('Local Perturbation Correlation (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_perturbation_rho.png'), dpi=150)
+ plt.close(fig)
+
+ # Nudging test per layer (eta=0.01)
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'diagnostics' in data[last_seed][method]:
+ diag = data[last_seed][method]['diagnostics']
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ nud = diag['nudging']['0.01']
+ layers = list(range(len(nud)))
+ ax.plot(layers, nud, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Nudge Delta (negative = good)')
+ ax.set_title('Nudging Test eta=0.01 (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_nudging.png'), dpi=150)
+ plt.close(fig)
+
+ # Feature drift per layer
+ fig, ax = plt.subplots(1, 1, figsize=(10, 6))
+ for method in methods:
+ if method in data[last_seed] and 'drift' in data[last_seed][method]:
+ drift = data[last_seed][method]['drift']
+ # Extract per-block drift (only block weights)
+ block_drifts = []
+ for l in range(12):
+ key = f'blocks.{l}.w1.weight'
+ if key in drift:
+ block_drifts.append(drift[key])
+ if block_drifts:
+ ax.plot(range(len(block_drifts)), block_drifts, 'o-', color=colors[method], label=labels[method])
+ ax.set_xlabel('Block')
+ ax.set_ylabel('Feature Drift (||W_final - W_init||/||W_init||)')
+ ax.set_title('Feature Drift (CIFAR-10)')
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, 'cifar_feature_drift.png'), dpi=150)
+ plt.close(fig)
+
+ print(f"CIFAR-10 plots saved to {output_dir}/")
+
+
+def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'):
+ """Print summary table of results."""
+ if not os.path.exists(results_path):
+ print(f"No results at {results_path}")
+ return
+
+ with open(results_path) as f:
+ data = json.load(f)
+
+ config = data.pop('config', {})
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ seeds = [k for k in data.keys() if k != 'config']
+
+ print("\n" + "="*80)
+ print("SUMMARY TABLE")
+ print("="*80)
+ print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}")
+ print("-"*80)
+
+ for method in methods:
+ test_accs = []
+ avg_rhos = []
+ avg_nudges = []
+ avg_bp_cos = []
+
+ for seed in seeds:
+ if method in data[seed]:
+ log = data[seed][method]['log']
+ test_accs.append(log['test_acc'][-1])
+
+ if 'diagnostics' in data[seed][method]:
+ diag = data[seed][method]['diagnostics']
+ if 'perturbation_rho' in diag:
+ avg_rhos.append(np.mean(diag['perturbation_rho']))
+ if 'nudging' in diag and '0.01' in diag['nudging']:
+ avg_nudges.append(np.mean(diag['nudging']['0.01']))
+ if 'bp_cosine' in diag:
+ avg_bp_cos.append(np.mean(diag['bp_cosine']))
+
+ ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A"
+ rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A"
+ nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A"
+ bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A"
+
+ print(f"{labels[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}")
+
+ print("="*80)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--toy_dir', type=str, default='results/toy_lq')
+ parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json')
+ parser.add_argument('--output_dir', type=str, default='report')
+ args = parser.parse_args()
+
+ plot_toy_results(args.toy_dir, args.output_dir)
+ plot_cifar_results(args.cifar_path, args.output_dir)
+ print_summary_table(args.cifar_path)