"""Generate plots for toy LQ and CIFAR-10 experiments."""
import os
import sys
import json
import argparse

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt


# Toy LQ methods: (key prefix used in the result JSON, legend label, color, line format).
_TOY_METHODS = (
    ('dfa', 'DFA', 'blue', 'o-'),
    ('state', 'State Bridge', 'orange', 's-'),
    ('credit', 'Credit Bridge', 'green', '^-'),
)
_TOY_FILE_PREFIX = 'toy_lq_seed'

# CIFAR methods in plotting / table order, with their colors and legend labels.
_CIFAR_METHODS = ('bp', 'dfa', 'state_bridge', 'credit_bridge')
_CIFAR_COLORS = {'bp': 'red', 'dfa': 'blue', 'state_bridge': 'orange', 'credit_bridge': 'green'}
_CIFAR_LABELS = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _finish_axes(ax, xlabel, ylabel, title, *, zero_line=False):
    """Apply labels, title, grid, an optional y=0 line, and a legend if any series is labelled."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # legend() with no labelled artists only emits a warning, so skip it in that case.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True, alpha=0.3)
    if zero_line:
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)


def _save_figure(fig, output_dir, filename):
    """Lay out *fig*, save it to output_dir/filename at 150 dpi, and always close it."""
    try:
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, filename), dpi=150)
    finally:
        # Close even if saving fails so figures don't accumulate in pyplot's registry.
        plt.close(fig)


def _nested(mapping, *keys):
    """Return mapping[k1][k2]...; None if any level is missing or is not a dict."""
    node = mapping
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node


# ---------------------------------------------------------------------------
# Toy LQ
# ---------------------------------------------------------------------------

def _toy_seed_key(filename):
    """Sort key for toy_lq_seed<N>.json: numeric N compared as integers, non-numeric names first."""
    seed = os.path.splitext(filename)[0][len(_TOY_FILE_PREFIX):]
    if seed.isdecimal():
        return (1, int(seed), filename)
    return (0, 0, filename)


def _plot_toy_per_layer(per_layer, suffix, ylabel, title, output_dir, filename,
                        *, zero_line=False, ylim=None):
    """Plot the per-layer metric ``<method>_<suffix>`` for every toy method on one figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    for prefix, label, color, fmt in _TOY_METHODS:
        values = per_layer.get(f'{prefix}_{suffix}')
        if values is None:
            continue
        # Each series gets its own x-range, so differing layer counts cannot crash plot().
        ax.plot(range(len(values)), values, fmt, label=label, color=color)
    _finish_axes(ax, 'Layer', ylabel, title, zero_line=zero_line)
    if ylim is not None:
        ax.set_ylim(*ylim)
    _save_figure(fig, output_dir, filename)


def plot_toy_results(results_dir='results/toy_lq', output_dir='report'):
    """Plot toy LQ experiment results.

    Reads ``toy_lq_seed<N>.json`` files from *results_dir* and plots the one
    with the largest integer seed N (seeds are not averaged). Reads the
    ``final_per_layer`` and ``log`` entries of that file and writes PNG figures
    into *output_dir*, which is created if needed. Prints a message and
    returns when no result files are found.
    """
    os.makedirs(output_dir, exist_ok=True)

    files = []
    if os.path.isdir(results_dir):
        files = [f for f in os.listdir(results_dir)
                 if f.startswith(_TOY_FILE_PREFIX) and f.endswith('.json')]
    if not files:
        print(f"No toy results found in {results_dir}")
        return

    # Numeric ordering so that e.g. seed10 is treated as later than seed2.
    latest = max(files, key=_toy_seed_key)
    with open(os.path.join(results_dir, latest)) as fp:
        data = json.load(fp)

    per_layer = data['final_per_layer']
    log_data = data['log']

    # 1. Per-layer costate cosine
    _plot_toy_per_layer(per_layer, 'costate_cos', 'Cosine Similarity with Exact Costate',
                        'Exact Costate Cosine (Toy LQ)', output_dir, 'toy_costate_cosine.png',
                        ylim=(-0.2, 1.05))

    # 2. Per-layer perturbation correlation
    _plot_toy_per_layer(per_layer, 'rho', 'Perturbation Correlation (rho)',
                        'Local Perturbation Correlation (Toy LQ)', output_dir,
                        'toy_perturbation_rho.png', zero_line=True)

    # 3. Per-layer nudging test
    _plot_toy_per_layer(per_layer, 'nudge', 'Nudge Delta (negative = good)',
                        'Nudging Test (Toy LQ)', output_dir, 'toy_nudging.png',
                        zero_line=True)

    # 4. Bridge residual over training (optional entry in the log)
    if log_data.get('bridge_residual'):
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(log_data['steps'], log_data['bridge_residual'], '-', color='green')
        _finish_axes(ax, 'Training Step', 'Bridge Residual',
                     'Bridge Residual Over Training (Toy LQ)')
        _save_figure(fig, output_dir, 'toy_bridge_residual.png')

    # 5. Training curves (costate cosine over time), one panel per method
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    for ax, (prefix, label, _color, _fmt) in zip(axes, _TOY_METHODS):
        series = log_data.get(f'{prefix}_costate_cos')
        if series is not None:
            ax.plot(log_data['steps'], series, '-')
        _finish_axes(ax, 'Training Step', 'Avg Costate Cosine',
                     f'{label} - Costate Cosine Over Training')
    _save_figure(fig, output_dir, 'toy_cosine_training.png')

    # 6. Per-layer bridge residual (optional entry)
    residual = per_layer.get('bridge_residual')
    if residual:
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(range(len(residual)), residual, '^-', color='green')
        _finish_axes(ax, 'Layer', 'Bridge Residual', 'Per-Layer Bridge Residual (Toy LQ)')
        _save_figure(fig, output_dir, 'toy_bridge_residual_per_layer.png')

    print(f"Toy LQ plots saved to {output_dir}/")


# ---------------------------------------------------------------------------
# CIFAR-10
# ---------------------------------------------------------------------------

def _load_cifar_results(results_path):
    """Load a CIFAR results JSON and return ``(data, seeds)``.

    The top-level ``config`` entry is dropped; every remaining top-level key
    is treated as a seed, and ``seeds`` keeps the file's key order.
    """
    with open(results_path) as f:
        data = json.load(f)
    data.pop('config', None)  # run configuration, not a seed
    return data, list(data)


def _plot_mean_std(ax, runs, color, label):
    """Plot the across-run mean of per-epoch curves with a +/- 1 std band.

    Runs are truncated to the shortest one so seeds with differing epoch
    counts can still be stacked into one array. Nothing is drawn when any
    run is empty.
    """
    length = min((len(run) for run in runs), default=0)
    if length == 0:
        return
    arr = np.array([run[:length] for run in runs], dtype=float)
    epochs = np.arange(1, length + 1)
    mean, std = arr.mean(axis=0), arr.std(axis=0)
    ax.plot(epochs, mean, '-', color=color, label=label)
    ax.fill_between(epochs, mean - std, mean + std, alpha=0.15, color=color)


def _layer_series(values):
    """Pair a per-layer list with layer indices 0..L-1; None if missing or empty."""
    if not values:
        return None
    return list(range(len(values))), values


def _block_drift_series(method_data):
    """Collect ``blocks.<i>.w1.weight`` drift values as (block indices, values), sorted by i."""
    drift = method_data.get('drift')
    if not isinstance(drift, dict):
        return None
    points = []
    for key, value in drift.items():
        parts = key.split('.')
        if (len(parts) == 4 and parts[0] == 'blocks' and parts[1].isdecimal()
                and parts[2] == 'w1' and parts[3] == 'weight'):
            points.append((int(parts[1]), value))
    if not points:
        return None
    points.sort(key=lambda point: point[0])
    return [index for index, _ in points], [value for _, value in points]


def _plot_cifar_per_layer(seed_data, extract, xlabel, ylabel, title, output_dir, filename,
                          *, zero_line=False):
    """Plot one series per CIFAR method present in *seed_data*.

    *extract* maps a method's result dict to ``(xs, ys)``, or None to skip it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for method in _CIFAR_METHODS:
        if method not in seed_data:
            continue
        series = extract(seed_data[method])
        if series is None:
            continue
        xs, ys = series
        ax.plot(xs, ys, 'o-', color=_CIFAR_COLORS[method], label=_CIFAR_LABELS[method])
    _finish_axes(ax, xlabel, ylabel, title, zero_line=zero_line)
    _save_figure(fig, output_dir, filename)


def plot_cifar_results(results_path='results/cifar10/cifar_results_cifar10.json', output_dir='report'):
    """Plot CIFAR-10 experiment results.

    Accuracy curves are averaged over all seeds in *results_path*. Per-layer
    diagnostics and feature drift come from the last seed in file order.
    PNG figures are written into *output_dir*, which is created if needed.
    Prints a message and returns when the file is missing or has no seeds.
    """
    os.makedirs(output_dir, exist_ok=True)
    if not os.path.exists(results_path):
        print(f"No CIFAR results found at {results_path}")
        return

    data, seeds = _load_cifar_results(results_path)
    if not seeds:
        print(f"No seed results found in {results_path}")
        return

    # 1. Accuracy curves (mean +/- std across seeds)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for method in _CIFAR_METHODS:
        logs = [data[seed][method]['log'] for seed in seeds if method in data[seed]]
        if not logs:
            continue
        color, label = _CIFAR_COLORS[method], _CIFAR_LABELS[method]
        # Train and test curves get independent epoch axes in case their lengths differ.
        _plot_mean_std(axes[0], [log['train_acc'] for log in logs], color, label)
        _plot_mean_std(axes[1], [log['test_acc'] for log in logs], color, label)
    _finish_axes(axes[0], 'Epoch', 'Train Accuracy', 'Train Accuracy')
    _finish_axes(axes[1], 'Epoch', 'Test Accuracy', 'Test Accuracy')
    _save_figure(fig, output_dir, 'cifar_accuracy.png')

    # 2. Per-layer diagnostics (from the last seed in file order)
    last = data[seeds[-1]]

    _plot_cifar_per_layer(
        last, lambda m: _layer_series(_nested(m, 'diagnostics', 'bp_cosine')),
        'Layer', 'Cosine with BP Gradient', 'Offline BP Cosine (CIFAR-10)',
        output_dir, 'cifar_bp_cosine.png')

    _plot_cifar_per_layer(
        last, lambda m: _layer_series(_nested(m, 'diagnostics', 'perturbation_rho')),
        'Layer', 'Perturbation Correlation (rho)', 'Local Perturbation Correlation (CIFAR-10)',
        output_dir, 'cifar_perturbation_rho.png', zero_line=True)

    # Nudging test per layer (eta=0.01)
    _plot_cifar_per_layer(
        last, lambda m: _layer_series(_nested(m, 'diagnostics', 'nudging', '0.01')),
        'Layer', 'Nudge Delta (negative = good)', 'Nudging Test eta=0.01 (CIFAR-10)',
        output_dir, 'cifar_nudging.png', zero_line=True)

    # Feature drift per block (only the blocks.<i>.w1.weight entries)
    _plot_cifar_per_layer(
        last, _block_drift_series,
        'Block', 'Feature Drift (||W_final - W_init||/||W_init||)', 'Feature Drift (CIFAR-10)',
        output_dir, 'cifar_feature_drift.png')

    print(f"CIFAR-10 plots saved to {output_dir}/")


def _run_means(runs, *keys):
    """Per-run mean of ``diagnostics[keys...]`` for every run where that entry is non-empty."""
    means = []
    for run in runs:
        values = _nested(run, 'diagnostics', *keys)
        if values:
            means.append(np.mean(values))
    return means


def _format_mean(values):
    """Format the mean of *values* to 4 decimals, or "N/A" when empty."""
    return f"{np.mean(values):.4f}" if values else "N/A"


def print_summary_table(results_path='results/cifar10/cifar_results_cifar10.json'):
    """Print summary table of results.

    For each method, the columns are:
    - final test accuracy as mean +/- std over seeds;
    - the across-seed mean of the per-seed average perturbation rho;
    - the same for the eta=0.01 nudge delta;
    - the same for BP cosine.
    """
    if not os.path.exists(results_path):
        print(f"No results at {results_path}")
        return

    data, seeds = _load_cifar_results(results_path)

    print("\n" + "=" * 80)
    print("SUMMARY TABLE")
    print("=" * 80)
    print(f"{'Method':<20} {'Test Acc':<15} {'Avg rho':<15} {'Avg Nudge(0.01)':<15} {'Avg BP Cos':<15}")
    print("-" * 80)

    for method in _CIFAR_METHODS:
        runs = [data[seed][method] for seed in seeds if method in data[seed]]
        # Skip runs whose test-accuracy log is empty instead of raising IndexError.
        final_accs = [run['log']['test_acc'][-1] for run in runs if run['log']['test_acc']]
        ta = f"{np.mean(final_accs):.4f}±{np.std(final_accs):.4f}" if final_accs else "N/A"
        rho = _format_mean(_run_means(runs, 'perturbation_rho'))
        nud = _format_mean(_run_means(runs, 'nudging', '0.01'))
        bpc = _format_mean(_run_means(runs, 'bp_cosine'))
        print(f"{_CIFAR_LABELS[method]:<20} {ta:<15} {rho:<15} {nud:<15} {bpc:<15}")

    print("=" * 80)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--toy_dir', type=str, default='results/toy_lq',
                        help='Directory containing toy_lq_seed<N>.json files.')
    parser.add_argument('--cifar_path', type=str, default='results/cifar10/cifar_results_cifar10.json',
                        help='Path to the CIFAR results JSON.')
    parser.add_argument('--output_dir', type=str, default='report',
                        help='Directory where PNG figures are written.')
    args = parser.parse_args()

    plot_toy_results(args.toy_dir, args.output_dir)
    plot_cifar_results(args.cifar_path, args.output_dir)
    print_summary_table(args.cifar_path)