diff options
Diffstat (limited to 'experiments/plot_results.py')
| -rw-r--r-- | experiments/plot_results.py | 327 |
1 files changed, 327 insertions, 0 deletions
"""Generate plots for toy LQ and CIFAR-10 experiments."""
import os
import sys
import json
import argparse
import re
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


# (key prefix, matplotlib format string, legend label, colour) for each
# method recorded in the toy LQ result files.
_TOY_SERIES = (
    ('dfa', 'o-', 'DFA', 'blue'),
    ('state', 's-', 'State Bridge', 'orange'),
    ('credit', '^-', 'Credit Bridge', 'green'),
)

# Methods looked up in the CIFAR-10 result file, with their plot styling.
_CIFAR_METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
_CIFAR_COLORS = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'}
_CIFAR_LABELS = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}

# Matches the per-block weight keys used by the feature-drift plot.
_BLOCK_DRIFT_KEY = re.compile(r'^blocks\.(0|[1-9]\d*)\.w1\.weight$')


def _seed_sort_key(filename):
    """Return a sort key that orders ``toy_lq_seed<N>...`` files by numeric N.

    Files without a numeric seed sort after all numbered ones; the filename
    itself breaks ties so the ordering is deterministic.
    """
    match = re.match(r'toy_lq_seed(\d+)', filename)
    seed = int(match.group(1)) if match else float('inf')
    return (seed, filename)


def _save_figure(fig, output_dir, filename):
    """Apply tight layout, write *fig* to ``output_dir/filename`` and close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, filename), dpi=150)
    plt.close(fig)


def _legend_if_labeled(ax):
    """Draw a legend only when *ax* has labelled artists (avoids a warning)."""
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend()


def _plot_toy_per_layer(per_layer, suffix, output_dir, filename, *,
                        ylabel, title, zero_line=False, ylim=None):
    """Plot ``<method>_<suffix>`` per layer for the three toy LQ methods."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    for prefix, fmt, label, color in _TOY_SERIES:
        values = per_layer[f'{prefix}_{suffix}']
        # Each series gets its own x range so differing lengths cannot clash.
        ax.plot(list(range(len(values))), values, fmt, label=label, color=color)
    ax.set_xlabel('Layer')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    if ylim is not None:
        ax.set_ylim(*ylim)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    _save_figure(fig, output_dir, filename)


def _plot_single_series(x, y, fmt, output_dir, filename, *, xlabel, ylabel, title):
    """Plot one green series with a grid and save it."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    ax.plot(x, y, fmt, color='green')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    _save_figure(fig, output_dir, filename)


def plot_toy_results(results_dir='results/toy_lq', output_dir='report'):
    """Plot toy LQ experiment results.

    Looks in *results_dir* for files named ``toy_lq_seed*.json`` and plots the
    one with the highest numeric seed. From that file it reads
    ``final_per_layer`` (per-layer metrics) and ``log`` (metrics over
    training) and writes ``toy_*.png`` figures into *output_dir*.

    Prints a message and returns without plotting when *results_dir* is
    missing or contains no matching files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # A missing directory is treated like an empty one so that the remaining
    # steps of the script (CIFAR plots, summary table) still run.
    candidates = []
    if os.path.isdir(results_dir):
        candidates = [f for f in os.listdir(results_dir)
                      if f.startswith('toy_lq_seed') and f.endswith('.json')]
    if not candidates:
        print(f"No toy results found in {results_dir}")
        return

    # Only the last seed is plotted; order numerically so seed10 follows seed9.
    latest = max(candidates, key=_seed_sort_key)
    with open(os.path.join(results_dir, latest), encoding='utf-8') as fp:
        data = json.load(fp)
    per_layer = data['final_per_layer']
    log_data = data['log']

    # 1. Per-layer costate cosine
    _plot_toy_per_layer(
        per_layer, 'costate_cos', output_dir, 'toy_costate_cosine.png',
        ylabel='Cosine Similarity with Exact Costate',
        title='Exact Costate Cosine (Toy LQ)',
        ylim=(-0.2, 1.05),
    )

    # 2. Per-layer perturbation correlation
    _plot_toy_per_layer(
        per_layer, 'rho', output_dir, 'toy_perturbation_rho.png',
        ylabel='Perturbation Correlation (rho)',
        title='Local Perturbation Correlation (Toy LQ)',
        zero_line=True,
    )

    # 3. Per-layer nudging test
    _plot_toy_per_layer(
        per_layer, 'nudge', output_dir, 'toy_nudging.png',
        ylabel='Nudge Delta (negative = good)',
        title='Nudging Test (Toy LQ)',
        zero_line=True,
    )

    # 4. Bridge residual over training (skipped when absent or empty)
    residual_log = log_data.get('bridge_residual')
    if residual_log:
        _plot_single_series(
            log_data['steps'], residual_log, '-', output_dir,
            'toy_bridge_residual.png',
            xlabel='Training Step', ylabel='Bridge Residual',
            title='Bridge Residual Over Training (Toy LQ)',
        )

    # 5. Training curves (costate cosine over time), one panel per method
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (prefix, _fmt, label, _color) in zip(axes, _TOY_SERIES):
        ax.plot(log_data['steps'], log_data[f'{prefix}_costate_cos'], '-')
        ax.set_xlabel('Training Step')
        ax.set_ylabel('Avg Costate Cosine')
        ax.set_title(f'{label} - Costate Cosine Over Training')
        ax.grid(True, alpha=0.3)
    _save_figure(fig, output_dir, 'toy_cosine_training.png')

    # 6. Per-layer bridge residual (skipped when absent or empty)
    residual_layers = per_layer.get('bridge_residual')
    if residual_layers:
        _plot_single_series(
            list(range(len(residual_layers))), residual_layers, '^-', output_dir,
            'toy_bridge_residual_per_layer.png',
            xlabel='Layer', ylabel='Bridge Residual',
            title='Per-Layer Bridge Residual (Toy LQ)',
        )

    print(f"Toy LQ plots saved to {output_dir}/")


def _stack_curves(curves):
    """Stack per-seed curves into a 2-D float array (seeds x epochs).

    Missing or empty curves are dropped and the rest are truncated to the
    shortest one, so seeds that logged different numbers of epochs can still
    be averaged. Returns ``None`` when no usable curve remains.
    """
    usable = [curve for curve in curves if curve]
    if not usable:
        return None
    shortest = min(len(curve) for curve in usable)
    return np.array([curve[:shortest] for curve in usable], dtype=float)


def _block_drifts(drift):
    """Return the ``blocks.<i>.w1.weight`` drift values ordered by block index."""
    indexed = []
    for key, value in drift.items():
        match = _BLOCK_DRIFT_KEY.match(key)
        if match:
            indexed.append((int(match.group(1)), value))
    indexed.sort(key=lambda pair: pair[0])
    return [value for _index, value in indexed]


def _plot_cifar_per_method(seed_results, extract, output_dir, filename, *,
                           xlabel, ylabel, title, zero_line=False):
    """Plot one series per CIFAR method for a single seed.

    *extract* maps a method's result dict to the sequence to plot, or to
    ``None`` when that method has nothing to show.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    for method in _CIFAR_METHODS:
        values = extract(seed_results.get(method, {}))
        if values is None:
            continue
        ax.plot(list(range(len(values))), values, 'o-',
                color=_CIFAR_COLORS[method], label=_CIFAR_LABELS[method])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    _legend_if_labeled(ax)
    ax.grid(True, alpha=0.3)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    _save_figure(fig, output_dir, filename)


def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'):
    """Plot CIFAR-10 experiment results.

    Reads the JSON file at *results_path*, whose top-level keys other than
    ``config`` are treated as seeds mapping method names to results. Writes
    an accuracy figure (mean +/- std across seeds) and per-layer diagnostic
    and feature-drift figures (taken from the last seed) into *output_dir*.

    Prints a message and returns without plotting when the file is missing
    or contains no seed entries.
    """
    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(results_path):
        print(f"No CIFAR results found at {results_path}")
        return

    with open(results_path, encoding='utf-8') as f:
        data = json.load(f)

    data.pop('config', None)
    seeds = list(data)
    if not seeds:
        # Only a config entry was present; there is nothing to plot.
        print(f"No CIFAR results found at {results_path}")
        return

    # 1. Accuracy curves (mean ± std across seeds)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for method in _CIFAR_METHODS:
        logs = [data[seed][method].get('log', {})
                for seed in seeds if method in data[seed]]
        for ax, key in zip(axes, ('train_acc', 'test_acc')):
            stacked = _stack_curves([log.get(key) for log in logs])
            if stacked is None:
                continue
            epochs = np.arange(1, stacked.shape[1] + 1)
            mean = stacked.mean(0)
            std = stacked.std(0)
            ax.plot(epochs, mean, '-', color=_CIFAR_COLORS[method], label=_CIFAR_LABELS[method])
            ax.fill_between(epochs, mean - std, mean + std,
                            alpha=0.15, color=_CIFAR_COLORS[method])

    for ax, split in zip(axes, ('Train', 'Test')):
        ax.set_xlabel('Epoch')
        ax.set_ylabel(f'{split} Accuracy')
        ax.set_title(f'{split} Accuracy')
        _legend_if_labeled(ax)
        ax.grid(True, alpha=0.3)
    _save_figure(fig, output_dir, 'cifar_accuracy.png')

    # 2. Per-layer diagnostics (from last seed)
    last_results = data[seeds[-1]]

    # BP cosine per layer
    _plot_cifar_per_method(
        last_results,
        lambda res: res.get('diagnostics', {}).get('bp_cosine'),
        output_dir, 'cifar_bp_cosine.png',
        xlabel='Layer', ylabel='Cosine with BP Gradient',
        title='Offline BP Cosine (CIFAR-10)',
    )

    # Perturbation rho per layer
    _plot_cifar_per_method(
        last_results,
        lambda res: res.get('diagnostics', {}).get('perturbation_rho'),
        output_dir, 'cifar_perturbation_rho.png',
        xlabel='Layer', ylabel='Perturbation Correlation (rho)',
        title='Local Perturbation Correlation (CIFAR-10)',
        zero_line=True,
    )

    # Nudging test per layer (eta=0.01)
    _plot_cifar_per_method(
        last_results,
        lambda res: res.get('diagnostics', {}).get('nudging', {}).get('0.01'),
        output_dir, 'cifar_nudging.png',
        xlabel='Layer', ylabel='Nudge Delta (negative = good)',
        title='Nudging Test eta=0.01 (CIFAR-10)',
        zero_line=True,
    )

    # Feature drift per block (only the blocks.<i>.w1.weight entries)
    _plot_cifar_per_method(
        last_results,
        lambda res: _block_drifts(res.get('drift', {})) or None,
        output_dir, 'cifar_feature_drift.png',
        xlabel='Block', ylabel='Feature Drift (||W_final - W_init||/||W_init||)',
        title='Feature Drift (CIFAR-10)',
    )

    print(f"CIFAR-10 plots saved to {output_dir}/")


def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'):
    """Print summary table of results.

    For each method, reports the final test accuracy (mean ± std across
    seeds) and the across-seed mean of the per-seed average perturbation rho,
    nudging delta at eta=0.01 and BP cosine. A column shows ``N/A`` when no
    seed provides that value.

    Prints a message and returns when *results_path* does not exist.
    """
    if not os.path.exists(results_path):
        print(f"No results at {results_path}")
        return

    with open(results_path, encoding='utf-8') as f:
        data = json.load(f)

    data.pop('config', None)
    seeds = list(data)

    print("\n" + "="*80)
    print("SUMMARY TABLE")
    print("="*80)
    print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}")
    print("-"*80)

    for method in _CIFAR_METHODS:
        test_accs = []
        avg_rhos = []
        avg_nudges = []
        avg_bp_cos = []

        for seed in seeds:
            result = data[seed].get(method)
            if result is None:
                continue

            # Missing or empty series are skipped rather than raising or
            # contributing NaN to the averages.
            accs = result.get('log', {}).get('test_acc')
            if accs:
                test_accs.append(accs[-1])

            diag = result.get('diagnostics', {})
            rho_values = diag.get('perturbation_rho')
            if rho_values:
                avg_rhos.append(np.mean(rho_values))
            nudge_values = diag.get('nudging', {}).get('0.01')
            if nudge_values:
                avg_nudges.append(np.mean(nudge_values))
            cos_values = diag.get('bp_cosine')
            if cos_values:
                avg_bp_cos.append(np.mean(cos_values))

        ta = f"{np.mean(test_accs):.4f}±{np.std(test_accs):.4f}" if test_accs else "N/A"
        rho = f"{np.mean(avg_rhos):.4f}" if avg_rhos else "N/A"
        nud = f"{np.mean(avg_nudges):.4f}" if avg_nudges else "N/A"
        bpc = f"{np.mean(avg_bp_cos):.4f}" if avg_bp_cos else "N/A"

        print(f"{_CIFAR_LABELS[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}")

    print("="*80)


def main():
    """Parse command-line arguments, then generate all plots and the summary."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--toy_dir', type=str, default='results/toy_lq')
    parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json')
    parser.add_argument('--output_dir', type=str, default='report')
    args = parser.parse_args()

    plot_toy_results(args.toy_dir, args.output_dir)
    plot_cifar_results(args.cifar_path, args.output_dir)
    print_summary_table(args.cifar_path)


if __name__ == '__main__':
    main()
