diff options
| -rw-r--r-- | experiments/resmlp_frozen_blocks_baseline.py | 14 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_multiseed.log | 363 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed0/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed1/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed2/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed3/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed4/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed5/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed6/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed7/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed8/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/fa_dfa_d512_L2_seed9/results_cifar10.json | 749 | ||||
| -rw-r--r-- | results/frozen_d512_baselines.log | 111 |
13 files changed, 7972 insertions, 6 deletions
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py index c330be2..6040bd4 100644 --- a/experiments/resmlp_frozen_blocks_baseline.py +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -137,6 +137,7 @@ def main(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--num_blocks', type=int, default=4) args = parser.parse_args() dev = torch.device('cuda:0') @@ -156,10 +157,11 @@ def main(): results['bp_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True) - # Condition 2: BP frozen-blocks (num_blocks=4 frozen) - print(f"\n=== BP frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 2: BP frozen-blocks (blocks frozen at random init) + L = args.num_blocks + print(f"\n=== BP frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen') @@ -175,10 +177,10 @@ def main(): results['dfa_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) - # Condition 4: DFA frozen-blocks (num_blocks=4 frozen) - print(f"\n=== DFA frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 4: DFA frozen-blocks (blocks frozen at random init) + print(f"\n=== DFA frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') diff --git a/results/fa_dfa_d512_L2_multiseed.log b/results/fa_dfa_d512_L2_multiseed.log new file mode 100644 index 0000000..896ebce --- /dev/null +++ b/results/fa_dfa_d512_L2_multiseed.log @@ -0,0 +1,363 @@ +=== FA d=512 L=2 multi-seed scan (seeds 0-9) === +Start: Sun Apr 26 06:37:27 AM CDT 2026 + seed=0 (Sun Apr 26 06:37:27 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 0 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0629, train=0.2387, test=0.2555 + [DFA] Epoch 10: loss=2.0211, train=0.2588, test=0.2879 + [DFA] Epoch 20: loss=2.0143, train=0.2640, test=0.2912 + [DFA] Epoch 30: loss=2.0069, train=0.2703, test=0.2887 + [DFA] Epoch 40: loss=2.0030, train=0.2735, test=0.2812 + [DFA] Epoch 50: loss=2.0021, train=0.2733, test=0.2967 + [DFA] Epoch 60: loss=1.9979, train=0.2782, test=0.2824 + [DFA] Epoch 70: loss=1.9977, train=0.2765, test=0.3007 + [DFA] Epoch 80: loss=1.9978, train=0.2766, test=0.2988 + [DFA] Epoch 90: loss=1.9943, train=0.2764, test=0.3031 + [DFA] Epoch 100: loss=1.9955, train=0.2783, test=0.3001 + Final test acc: 0.3001 + +--- FA --- + [FA] Epoch 1: loss=2.0742, train=0.2441, test=0.2891 + [FA] Epoch 10: loss=1.8376, train=0.3416, test=0.3680 + [FA] Epoch 20: loss=1.8270, train=0.3454, test=0.3785 + [FA] Epoch 30: loss=1.8043, train=0.3507, test=0.3758 + [FA] Epoch 40: loss=1.7738, train=0.3664, test=0.3886 + [FA] Epoch 50: loss=1.7620, train=0.3699, test=0.3825 + [FA] Epoch 60: loss=1.7467, train=0.3785, test=0.3881 + [FA] Epoch 70: loss=1.7352, train=0.3830, test=0.3928 + [FA] Epoch 80: loss=1.7339, train=0.3832, test=0.3959 + [FA] Epoch 90: loss=1.7269, train=0.3870, test=0.3930 + [FA] Epoch 100: loss=1.7298, train=0.3865, test=0.3938 + Final test acc: 0.3938 + +All results saved to results/fa_dfa_d512_L2_seed0/results_cifar10.json + seed=1 (Sun Apr 26 06:49:50 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 1 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0591, train=0.2454, test=0.2655 + [DFA] Epoch 10: loss=2.0095, train=0.2684, test=0.2903 + [DFA] Epoch 20: loss=2.0058, train=0.2681, test=0.2783 + [DFA] Epoch 30: loss=2.0053, train=0.2721, test=0.2762 + [DFA] Epoch 40: loss=2.0044, train=0.2712, test=0.2959 + [DFA] Epoch 50: loss=2.0005, train=0.2729, test=0.2901 + [DFA] Epoch 60: loss=1.9964, train=0.2742, test=0.2919 + [DFA] Epoch 70: loss=1.9982, train=0.2780, test=0.2913 + [DFA] Epoch 80: loss=1.9966, train=0.2753, test=0.2950 + [DFA] Epoch 90: loss=1.9938, train=0.2775, test=0.2936 + [DFA] Epoch 100: loss=1.9950, train=0.2769, test=0.2978 + Final test acc: 0.2978 + +--- FA --- + [FA] Epoch 1: loss=2.0552, train=0.2546, test=0.2978 + [FA] Epoch 10: loss=1.8246, train=0.3458, test=0.3722 + [FA] Epoch 20: loss=1.7888, train=0.3580, test=0.3600 + [FA] Epoch 30: loss=1.7674, train=0.3671, test=0.3397 + [FA] Epoch 40: loss=1.7640, train=0.3680, test=0.3496 + [FA] Epoch 50: loss=1.7560, train=0.3707, test=0.3419 + [FA] Epoch 60: loss=1.7543, train=0.3738, test=0.3465 + [FA] Epoch 70: loss=1.7564, train=0.3744, test=0.3425 + [FA] Epoch 80: loss=1.7502, train=0.3758, test=0.3419 + [FA] Epoch 90: loss=1.7486, train=0.3771, test=0.3436 + [FA] Epoch 100: loss=1.7471, train=0.3765, test=0.3471 + Final test acc: 0.3471 + +All results saved to results/fa_dfa_d512_L2_seed1/results_cifar10.json + seed=2 (Sun Apr 26 07:02:14 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 2 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0584, train=0.2404, test=0.2780 + [DFA] Epoch 10: loss=2.0235, train=0.2615, test=0.2647 + [DFA] Epoch 20: loss=2.0273, train=0.2582, test=0.2770 + [DFA] Epoch 30: loss=2.0230, train=0.2623, test=0.2923 + [DFA] Epoch 40: loss=2.0245, train=0.2655, test=0.2919 + [DFA] Epoch 50: loss=2.0236, train=0.2650, test=0.2856 + [DFA] Epoch 60: loss=2.0224, train=0.2710, test=0.2876 + [DFA] Epoch 70: loss=2.0200, train=0.2690, test=0.2933 + [DFA] Epoch 80: loss=2.0180, train=0.2696, test=0.2968 + [DFA] Epoch 90: loss=2.0171, train=0.2729, test=0.2971 + [DFA] Epoch 100: loss=2.0162, train=0.2734, test=0.2968 + Final test acc: 0.2968 + +--- FA --- + [FA] Epoch 1: loss=2.0631, train=0.2423, test=0.2954 + [FA] Epoch 10: loss=1.8296, train=0.3454, test=0.3479 + [FA] Epoch 20: loss=1.8376, train=0.3439, test=0.3729 + [FA] Epoch 30: loss=1.8384, train=0.3410, test=0.3410 + [FA] Epoch 40: loss=1.8241, train=0.3480, test=0.3644 + [FA] Epoch 50: loss=1.8116, train=0.3500, test=0.3488 + [FA] Epoch 60: loss=1.8008, train=0.3585, test=0.3434 + [FA] Epoch 70: loss=1.7941, train=0.3608, test=0.3357 + [FA] Epoch 80: loss=1.7894, train=0.3592, test=0.3427 + [FA] Epoch 90: loss=1.7872, train=0.3613, test=0.3458 + [FA] Epoch 100: loss=1.7820, train=0.3655, test=0.3464 + Final test acc: 0.3464 + +All results saved to results/fa_dfa_d512_L2_seed2/results_cifar10.json + seed=3 (Sun Apr 26 07:14:48 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 3 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0491, train=0.2519, test=0.2855 + [DFA] Epoch 10: loss=2.0729, train=0.2429, test=0.2372 + [DFA] Epoch 20: loss=2.0572, train=0.2480, test=0.2718 + [DFA] Epoch 30: loss=2.0397, train=0.2584, test=0.2648 + [DFA] Epoch 40: loss=2.0308, train=0.2633, test=0.2861 + [DFA] Epoch 50: loss=2.0252, train=0.2686, test=0.2873 + [DFA] Epoch 60: loss=2.0212, train=0.2711, test=0.2754 + [DFA] Epoch 70: loss=2.0171, train=0.2705, test=0.2885 + [DFA] Epoch 80: loss=2.0116, train=0.2746, test=0.2942 + [DFA] Epoch 90: loss=2.0133, train=0.2721, test=0.2908 + [DFA] Epoch 100: loss=2.0157, train=0.2732, test=0.2922 + Final test acc: 0.2922 + +--- FA --- + [FA] Epoch 1: loss=2.0629, train=0.2512, test=0.2764 + [FA] Epoch 10: loss=1.8419, train=0.3392, test=0.3642 + [FA] Epoch 20: loss=1.7965, train=0.3568, test=0.3832 + [FA] Epoch 30: loss=1.7868, train=0.3598, test=0.3835 + [FA] Epoch 40: loss=1.7775, train=0.3646, test=0.3865 + [FA] Epoch 50: loss=1.7635, train=0.3726, test=0.3924 + [FA] Epoch 60: loss=1.7572, train=0.3736, test=0.3852 + [FA] Epoch 70: loss=1.7471, train=0.3794, test=0.3979 + [FA] Epoch 80: loss=1.7353, train=0.3832, test=0.3941 + [FA] Epoch 90: loss=1.7358, train=0.3846, test=0.3993 + [FA] Epoch 100: loss=1.7309, train=0.3885, test=0.4012 + Final test acc: 0.4012 + +All results saved to results/fa_dfa_d512_L2_seed3/results_cifar10.json + seed=4 (Sun Apr 26 07:27:09 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 4 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0539, train=0.2433, test=0.2691 + [DFA] Epoch 10: loss=2.0219, train=0.2587, test=0.2707 + [DFA] Epoch 20: loss=2.0191, train=0.2646, test=0.2732 + [DFA] Epoch 30: loss=2.0129, train=0.2656, test=0.2762 + [DFA] Epoch 40: loss=2.0132, train=0.2675, test=0.2768 + [DFA] Epoch 50: loss=2.0101, train=0.2685, test=0.2703 + [DFA] Epoch 60: loss=2.0089, train=0.2689, test=0.2861 + [DFA] Epoch 70: loss=2.0098, train=0.2676, test=0.2876 + [DFA] Epoch 80: loss=2.0067, train=0.2690, test=0.2911 + [DFA] Epoch 90: loss=2.0036, train=0.2691, test=0.2899 + [DFA] Epoch 100: loss=2.0048, train=0.2729, test=0.2861 + Final test acc: 0.2861 + +--- FA --- + [FA] Epoch 1: loss=2.0835, train=0.2421, test=0.2873 + [FA] Epoch 10: loss=1.8501, train=0.3365, test=0.3577 + [FA] Epoch 20: loss=1.8507, train=0.3369, test=0.3386 + [FA] Epoch 30: loss=1.8146, train=0.3477, test=0.3402 + [FA] Epoch 40: loss=1.8072, train=0.3527, test=0.3409 + [FA] Epoch 50: loss=1.7822, train=0.3630, test=0.3474 + [FA] Epoch 60: loss=1.7633, train=0.3659, test=0.3393 + [FA] Epoch 70: loss=1.7476, train=0.3725, test=0.3393 + [FA] Epoch 80: loss=1.7440, train=0.3747, test=0.3537 + [FA] Epoch 90: loss=1.7396, train=0.3730, test=0.3495 + [FA] Epoch 100: loss=1.7415, train=0.3765, test=0.3501 + Final test acc: 0.3501 + +All results saved to results/fa_dfa_d512_L2_seed4/results_cifar10.json + seed=5 (Sun Apr 26 07:39:42 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 5 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0501, train=0.2507, test=0.2723 + [DFA] Epoch 10: loss=2.0202, train=0.2642, test=0.2896 + [DFA] Epoch 20: loss=2.0155, train=0.2692, test=0.2842 + [DFA] Epoch 30: loss=2.0135, train=0.2699, test=0.2989 + [DFA] Epoch 40: loss=2.0102, train=0.2722, test=0.2843 + [DFA] Epoch 50: loss=2.0075, train=0.2735, test=0.2972 + [DFA] Epoch 60: loss=2.0089, train=0.2730, test=0.2916 + [DFA] Epoch 70: loss=2.0058, train=0.2753, test=0.2993 + [DFA] Epoch 80: loss=2.0035, train=0.2769, test=0.2966 + [DFA] Epoch 90: loss=2.0024, train=0.2759, test=0.2998 + [DFA] Epoch 100: loss=2.0019, train=0.2791, test=0.2963 + Final test acc: 0.2963 + +--- FA --- + [FA] Epoch 1: loss=2.0579, train=0.2506, test=0.2861 + [FA] Epoch 10: loss=1.8242, train=0.3435, test=0.3683 + [FA] Epoch 20: loss=1.8064, train=0.3523, test=0.3588 + [FA] Epoch 30: loss=1.7973, train=0.3574, test=0.3697 + [FA] Epoch 40: loss=1.7889, train=0.3577, test=0.3599 + [FA] Epoch 50: loss=1.7732, train=0.3651, test=0.3452 + [FA] Epoch 60: loss=1.7614, train=0.3703, test=0.3428 + [FA] Epoch 70: loss=1.7496, train=0.3739, test=0.3374 + [FA] Epoch 80: loss=1.7420, train=0.3770, test=0.3406 + [FA] Epoch 90: loss=1.7398, train=0.3786, test=0.3446 + [FA] Epoch 100: loss=1.7391, train=0.3797, test=0.3410 + Final test acc: 0.3410 + +All results saved to results/fa_dfa_d512_L2_seed5/results_cifar10.json + seed=6 (Sun Apr 26 07:52:11 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 6 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0559, train=0.2468, test=0.2713 + [DFA] Epoch 10: loss=2.0195, train=0.2639, test=0.2738 + [DFA] Epoch 20: loss=2.0121, train=0.2678, test=0.2808 + [DFA] Epoch 30: loss=2.0112, train=0.2643, test=0.2850 + [DFA] Epoch 40: loss=2.0103, train=0.2670, test=0.2961 + [DFA] Epoch 50: loss=2.0083, train=0.2673, test=0.2983 + [DFA] Epoch 60: loss=2.0041, train=0.2734, test=0.2863 + [DFA] Epoch 70: loss=2.0043, train=0.2745, test=0.2956 + [DFA] Epoch 80: loss=1.9999, train=0.2742, test=0.2878 + [DFA] Epoch 90: loss=2.0039, train=0.2733, test=0.2938 + [DFA] Epoch 100: loss=2.0008, train=0.2720, test=0.2933 + Final test acc: 0.2933 + +--- FA --- + [FA] Epoch 1: loss=2.0650, train=0.2447, test=0.2947 + [FA] Epoch 10: loss=1.8200, train=0.3500, test=0.3699 + [FA] Epoch 20: loss=1.7577, train=0.3697, test=0.3955 + [FA] Epoch 30: loss=1.7312, train=0.3816, test=0.3770 + [FA] Epoch 40: loss=1.7424, train=0.3782, test=0.3734 + [FA] Epoch 50: loss=1.7344, train=0.3837, test=0.3712 + [FA] Epoch 60: loss=1.7249, train=0.3860, test=0.3726 + [FA] Epoch 70: loss=1.7168, train=0.3903, test=0.3812 + [FA] Epoch 80: loss=1.7108, train=0.3942, test=0.3826 + [FA] Epoch 90: loss=1.7121, train=0.3952, test=0.3857 + [FA] Epoch 100: loss=1.7054, train=0.3982, test=0.3865 + Final test acc: 0.3865 + +All results saved to results/fa_dfa_d512_L2_seed6/results_cifar10.json + seed=7 (Sun Apr 26 08:04:41 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 7 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0552, train=0.2426, test=0.2524 + [DFA] Epoch 10: loss=2.0217, train=0.2618, test=0.2860 + [DFA] Epoch 20: loss=2.0057, train=0.2688, test=0.2705 + [DFA] Epoch 30: loss=1.9929, train=0.2755, test=0.2971 + [DFA] Epoch 40: loss=1.9826, train=0.2822, test=0.3056 + [DFA] Epoch 50: loss=1.9748, train=0.2864, test=0.2965 + [DFA] Epoch 60: loss=1.9717, train=0.2885, test=0.3153 + [DFA] Epoch 70: loss=1.9681, train=0.2900, test=0.3103 + [DFA] Epoch 80: loss=1.9665, train=0.2913, test=0.3141 + [DFA] Epoch 90: loss=1.9648, train=0.2903, test=0.3161 + [DFA] Epoch 100: loss=1.9661, train=0.2882, test=0.3157 + Final test acc: 0.3157 + +--- FA --- + [FA] Epoch 1: loss=2.0654, train=0.2482, test=0.2781 + [FA] Epoch 10: loss=1.8355, train=0.3425, test=0.3543 + [FA] Epoch 20: loss=1.8289, train=0.3434, test=0.3387 + [FA] Epoch 30: loss=1.8337, train=0.3407, test=0.3371 + [FA] Epoch 40: loss=1.8044, train=0.3554, test=0.3577 + [FA] Epoch 50: loss=1.7871, train=0.3601, test=0.3339 + [FA] Epoch 60: loss=1.7844, train=0.3637, test=0.3612 + [FA] Epoch 70: loss=1.7775, train=0.3683, test=0.3508 + [FA] Epoch 80: loss=1.7746, train=0.3716, test=0.3519 + [FA] Epoch 90: loss=1.7633, train=0.3720, test=0.3570 + [FA] Epoch 100: loss=1.7649, train=0.3746, test=0.3537 + Final test acc: 0.3537 + +All results saved to results/fa_dfa_d512_L2_seed7/results_cifar10.json + seed=8 (Sun Apr 26 08:17:07 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 8 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0461, train=0.2457, test=0.2641 + [DFA] Epoch 10: loss=2.0045, train=0.2637, test=0.2704 + [DFA] Epoch 20: loss=1.9953, train=0.2706, test=0.2883 + [DFA] Epoch 30: loss=1.9922, train=0.2731, test=0.2966 + [DFA] Epoch 40: loss=1.9912, train=0.2745, test=0.2928 + [DFA] Epoch 50: loss=1.9866, train=0.2774, test=0.2894 + [DFA] Epoch 60: loss=1.9862, train=0.2775, test=0.2875 + [DFA] Epoch 70: loss=1.9859, train=0.2799, test=0.2968 + [DFA] Epoch 80: loss=1.9836, train=0.2803, test=0.2973 + [DFA] Epoch 90: loss=1.9812, train=0.2794, test=0.2930 + [DFA] Epoch 100: loss=1.9818, train=0.2811, test=0.2967 + Final test acc: 0.2967 + +--- FA --- + [FA] Epoch 1: loss=2.0653, train=0.2482, test=0.2970 + [FA] Epoch 10: loss=1.8267, train=0.3462, test=0.3510 + [FA] Epoch 20: loss=1.7845, train=0.3614, test=0.3742 + [FA] Epoch 30: loss=1.7457, train=0.3738, test=0.3967 + [FA] Epoch 40: loss=1.7335, train=0.3797, test=0.3871 + [FA] Epoch 50: loss=1.7340, train=0.3824, test=0.4010 + [FA] Epoch 60: loss=1.7299, train=0.3851, test=0.3996 + [FA] Epoch 70: loss=1.7155, train=0.3908, test=0.4084 + [FA] Epoch 80: loss=1.7086, train=0.3952, test=0.4138 + [FA] Epoch 90: loss=1.7057, train=0.3961, test=0.4136 + [FA] Epoch 100: loss=1.7031, train=0.3992, test=0.4146 + Final test acc: 0.4146 + +All results saved to results/fa_dfa_d512_L2_seed8/results_cifar10.json + seed=9 (Sun Apr 26 08:29:32 AM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 9 +============================================================ + +--- DFA --- + [DFA] Epoch 1: loss=2.0386, train=0.2506, test=0.2683 + [DFA] Epoch 10: loss=1.9998, train=0.2732, test=0.2891 + [DFA] Epoch 20: loss=1.9913, train=0.2765, test=0.3009 + [DFA] Epoch 30: loss=1.9860, train=0.2787, test=0.2937 + [DFA] Epoch 40: loss=1.9853, train=0.2809, test=0.3125 + [DFA] Epoch 50: loss=1.9797, train=0.2828, test=0.2959 + [DFA] Epoch 60: loss=1.9782, train=0.2845, test=0.3108 + [DFA] Epoch 70: loss=1.9766, train=0.2833, test=0.3020 + [DFA] Epoch 80: loss=1.9734, train=0.2889, test=0.3189 + [DFA] Epoch 90: loss=1.9717, train=0.2894, test=0.3147 + [DFA] Epoch 100: loss=1.9705, train=0.2891, test=0.3156 + Final test acc: 0.3156 + +--- FA --- + [FA] Epoch 1: loss=2.0596, train=0.2467, test=0.2753 + [FA] Epoch 10: loss=1.8433, train=0.3403, test=0.3507 + [FA] Epoch 20: loss=1.8299, train=0.3471, test=0.3675 + [FA] Epoch 30: loss=1.8167, train=0.3509, test=0.3671 + [FA] Epoch 40: loss=1.8146, train=0.3520, test=0.3633 + [FA] Epoch 50: loss=1.8032, train=0.3550, test=0.3489 + [FA] Epoch 60: loss=1.7958, train=0.3599, test=0.3571 + [FA] Epoch 70: loss=1.7952, train=0.3629, test=0.3583 + [FA] Epoch 80: loss=1.7971, train=0.3645, test=0.3653 + [FA] Epoch 90: loss=1.7904, train=0.3673, test=0.3676 + [FA] Epoch 100: loss=1.7879, train=0.3666, test=0.3661 + Final test acc: 0.3661 + +All results saved to results/fa_dfa_d512_L2_seed9/results_cifar10.json +=== MULTI-SEED SCAN DONE (Sun Apr 26 08:41:52 AM CDT 2026) === diff --git a/results/fa_dfa_d512_L2_seed0/results_cifar10.json b/results/fa_dfa_d512_L2_seed0/results_cifar10.json new file mode 100644 index 0000000..7d68e48 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed0/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "0": { + "dfa": { + "log": { + "train_loss": [ + 2.062894982070923, + 2.042903160247803, + 2.0362783416748047, + 2.0289835427093506, + 2.0305548177719115, + 2.027372160644531, + 2.0264537405395506, + 2.024415637550354, + 2.0239928167343137, + 2.0210681298828126, + 2.0223728131103518, + 2.0176786605072023, + 2.0183338179016115, + 2.015661160888672, + 2.017801855773926, + 2.0157006243896483, + 2.013551385803223, + 2.013131450958252, + 2.012682000961304, + 2.0142799925231936, + 2.010995005722046, + 2.01094291557312, + 2.010355429763794, + 2.00825265914917, + 2.0103073081207277, + 2.0115878075790405, + 2.00740104183197, + 2.008656568450928, + 2.007300491104126, + 2.0068678036499024, + 2.0060133220672607, + 2.005689826889038, + 2.0071374508666993, + 2.0041723249053955, + 2.007439960021973, + 2.004801420516968, + 2.007543896255493, + 2.008051019935608, + 2.0040309619140624, + 2.002984625930786, + 2.0036395249176024, + 2.0038933919906614, + 2.000672343521118, + 2.002527554702759, + 2.0040715547180175, + 2.002382646026611, + 2.002094281616211, + 2.0012509018707276, + 2.0035453020477294, + 2.0021117819213865, + 2.0006714540863038, + 2.0024239098739622, + 2.000547568016052, + 2.0004410552978515, + 2.000826226043701, + 1.9995253698730469, + 2.0001726056671143, + 2.0008310513687135, + 1.998274679107666, + 1.9978595404052735, + 1.9991289748382568, + 1.9976552139663697, + 1.9974908544540406, + 1.9972500681304932, + 1.9982564139556884, + 1.9990739729309082, + 1.9982671169281006, + 1.9968176255035401, + 1.9985010131072998, + 1.997684263381958, + 1.9971537530517578, + 1.9961954132843018, + 1.9972638996124268, + 1.9965271125793458, + 1.9960845895385741, + 1.9970510581970216, + 1.9964770040512085, + 1.9966202107238769, + 1.99665422416687, + 1.9977766342163086, + 1.9974106648254395, + 1.9950730167007447, + 1.9971654697418213, + 1.995940417022705, + 1.9962149201202393, + 1.9963797412109374, + 1.9944206412506102, + 1.9952468407440185, + 1.9964326373291015, + 1.99426823387146, + 1.9950427423477173, + 1.9937485892105102, + 1.995298335914612, + 1.9937747159194947, + 1.99329125541687, + 1.9960021032333375, + 1.9957950158309936, + 1.9929876937103272, + 1.994628695678711, + 1.9954896281433105 + ], + "train_acc": [ + 0.2387, + 0.24782, + 0.24834, + 0.255, + 0.2542, + 0.25346, + 0.2542, + 0.25464, + 0.25776, + 0.25884, + 0.25932, + 0.26184, + 0.26154, + 0.2615, + 0.26154, + 0.26214, + 0.26598, + 0.26548, + 0.26678, + 0.26396, + 0.2672, + 0.2662, + 0.26618, + 0.26934, + 0.2649, + 0.26488, + 0.26854, + 0.2669, + 0.26922, + 0.27032, + 0.2688, + 0.26942, + 0.26852, + 0.2698, + 0.2708, + 0.27092, + 0.27064, + 0.2683, + 0.27222, + 0.2735, + 0.27254, + 0.26794, + 0.27416, + 0.27154, + 0.2709, + 0.2735, + 0.27234, + 0.2735, + 0.2704, + 0.27326, + 0.27246, + 0.2728, + 0.27374, + 0.2747, + 0.27344, + 0.2739, + 0.2744, + 0.27318, + 0.27582, + 0.27816, + 0.27576, + 0.27492, + 0.27696, + 0.27608, + 0.27436, + 0.27578, + 0.27502, + 0.27806, + 0.274, + 0.27654, + 0.27598, + 0.27728, + 0.2783, + 0.27636, + 0.27394, + 0.27612, + 0.2788, + 0.27772, + 0.27614, + 0.27662, + 0.27458, + 0.27506, + 0.27614, + 0.2755, + 0.27668, + 0.27622, + 0.27808, + 0.27856, + 0.27796, + 0.27642, + 0.27808, + 0.28, + 0.27782, + 0.27638, + 0.27682, + 0.2774, + 0.27772, + 0.27642, + 0.27804, + 0.27828 + ], + "test_acc": [ + 0.2555, + 0.2749, + 0.2689, + 0.2689, + 0.2829, + 0.2478, + 0.2839, + 0.2723, + 0.2772, + 0.2879, + 0.2743, + 0.2663, + 0.2681, + 0.2905, + 0.2902, + 0.284, + 0.2788, + 0.2919, + 0.2851, + 0.2912, + 0.2889, + 0.2929, + 0.2868, + 0.2687, + 0.2855, + 0.304, + 0.292, + 0.2903, + 0.2882, + 0.2887, + 0.2923, + 0.2995, + 0.2973, + 0.2915, + 0.2914, + 0.2869, + 0.2737, + 0.2922, + 0.2898, + 0.2812, + 0.2924, + 0.3001, + 0.305, + 0.2996, + 0.2954, + 0.3002, + 0.2926, + 0.2858, + 0.3015, + 0.2967, + 0.2964, + 0.2897, + 0.3042, + 0.286, + 0.294, + 0.2986, + 0.2865, + 0.3013, + 0.2935, + 0.2824, + 0.2985, + 0.2908, + 0.3008, + 0.2949, + 0.3059, + 0.289, + 0.2914, + 0.305, + 0.2975, + 0.3007, + 0.3019, + 0.3009, + 0.2996, + 0.2976, + 0.2945, + 0.2986, + 0.3022, + 0.2991, + 0.2976, + 0.2988, + 0.2988, + 0.2979, + 0.2974, + 0.298, + 0.3014, + 0.2972, + 0.3003, + 0.2981, + 0.3005, + 0.3031, + 0.3026, + 0.3007, + 0.3004, + 0.2997, + 0.3016, + 0.3003, + 0.2993, + 0.2995, + 0.3002, + 0.3001 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.38934820890426636, + -0.0007739475695416331 + ], + "perturbation_rho": [ + -0.011460170149803162, + 0.0 + ], + "nudging": { + "0.001": [ + -4.163011908531189e-07, + 0.0 + ], + "0.003": [ + -1.0598450899124146e-06, + 0.0 + ], + "0.01": [ + -3.500375896692276e-06, + 0.0 + ] + }, + "hidden_norms_per_layer": [ + 52784.04296875, + 1514456960.0, + 4954951168.0 + ], + "bp_grad_norms_per_layer": [ + 2.3542017402178317e-07, + 3.12540243685433e-10, + 3.13137349383652e-10 + ] + }, + "drift": { + "embed.weight": 333.9178900630863, + "embed.bias": 277.5308018858926, + "blocks.0.ln.weight": 9.301640363801488, + "blocks.0.w1.weight": 311.12199748150937, + "blocks.0.w1.bias": 280.3822361773791, + "blocks.0.w2.weight": 500.10543177407226, + "blocks.1.ln.weight": 9.416271027168449, + "blocks.1.w1.weight": 397.745772813683, + "blocks.1.w1.bias": 381.1405085841875, + "blocks.1.w2.weight": 396.08258246162814, + "out_ln.weight": 0.5525577953421494, + "out_head.weight": 8.092380783169268, + "out_head.bias": 3.6707232448212266 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.074207492980957, + 1.9848071645736693, + 1.943194331741333, + 1.915086039428711, + 1.8987552841186524, + 1.8773444555282592, + 1.8620493222808838, + 1.8454281293106078, + 1.843613411216736, + 1.8376456963348389, + 1.8340519763946532, + 1.8251803881072999, + 1.8269455047607421, + 1.825891435623169, + 1.8311459658432008, + 1.8294941538238525, + 1.8284651235580445, + 1.8239954425811769, + 1.824907121887207, + 1.826982011680603, + 1.8236070413970946, + 1.8217376638412475, + 1.82049068069458, + 1.8151681272125244, + 1.8155477234649657, + 1.8134433920288087, + 1.806967512512207, + 1.8068813376235962, + 1.801505942955017, + 1.8042971012115478, + 1.800074999961853, + 1.7939591689300538, + 1.792923740272522, + 1.7907480084991456, + 1.7882287372589112, + 1.7897650888824463, + 1.7839135689926147, + 1.7830314935302733, + 1.7808781423187257, + 1.7737779104614257, + 1.7762451565933228, + 1.7733127856826782, + 1.7701874697875977, + 1.7670377206802368, + 1.7671197887802124, + 1.7665505773162842, + 1.7623907526016236, + 1.7633579796600343, + 1.7624771506500243, + 1.7620069681167603, + 1.7585726612091064, + 1.7578929253387452, + 1.7546437027740478, + 1.7530586669921875, + 1.7568175424575805, + 1.7472540600585937, + 1.7497945638656616, + 1.7486419037246703, + 1.7454212425994873, + 1.7467387356948854, + 1.742750499343872, + 1.7459785363006592, + 1.7462484993743896, + 1.7419252165985109, + 1.7439173141479491, + 1.7400725153350831, + 1.7417476744842528, + 1.7432426296234131, + 1.7410478344726563, + 1.735186776046753, + 1.7373164199447633, + 1.7379439109039307, + 1.7370698404312135, + 1.7362006621932984, + 1.7366024084091187, + 1.734513318862915, + 1.7316589307403565, + 1.734174518661499, + 1.7326682236480713, + 1.7339326565170288, + 1.734504469909668, + 1.728739596824646, + 1.7321637844848632, + 1.7324034854125976, + 1.7288938723754883, + 1.7308405535888671, + 1.7306951354980469, + 1.7300820154190064, + 1.7285541592788696, + 1.7268984225845336, + 1.727956399230957, + 1.724655009765625, + 1.7292626781845093, + 1.7271040807723999, + 1.7264983687973023, + 1.7272509225845336, + 1.7300393019104003, + 1.7276934866714477, + 1.7269879544067384, + 1.72983638671875 + ], + "train_acc": [ + 0.24412, + 0.28062, + 0.29426, + 0.30988, + 0.31292, + 0.32326, + 0.32948, + 0.33612, + 0.33972, + 0.34158, + 0.34404, + 0.34674, + 0.34434, + 0.34662, + 0.3424, + 0.3445, + 0.34358, + 0.34696, + 0.34462, + 0.3454, + 0.34674, + 0.3471, + 0.34642, + 0.34934, + 0.34898, + 0.35068, + 0.35258, + 0.3542, + 0.35448, + 0.35074, + 0.354, + 0.35526, + 0.35522, + 0.35664, + 0.36026, + 0.36056, + 0.36136, + 0.35988, + 0.3653, + 0.3664, + 0.36336, + 0.36712, + 0.36668, + 0.36956, + 0.36908, + 0.36934, + 0.3731, + 0.36976, + 0.36908, + 0.3699, + 0.37166, + 0.37488, + 0.3728, + 0.37648, + 0.37234, + 0.3769, + 0.3778, + 0.37956, + 0.37994, + 0.37854, + 0.3793, + 0.38058, + 0.3788, + 0.3813, + 0.3811, + 0.38082, + 0.3786, + 0.38008, + 0.37906, + 0.383, + 0.3824, + 0.38106, + 0.38238, + 0.38434, + 0.37922, + 0.3843, + 0.38626, + 0.38438, + 0.38332, + 0.38316, + 0.3814, + 0.3837, + 0.38628, + 0.3838, + 0.38572, + 0.38528, + 0.3855, + 0.38382, + 0.38534, + 0.38696, + 0.3866, + 0.38526, + 0.38498, + 0.38372, + 0.38534, + 0.38782, + 0.38812, + 0.38732, + 0.38564, + 0.38654 + ], + "test_acc": [ + 0.2891, + 0.3185, + 0.3279, + 0.3405, + 0.3535, + 0.3324, + 0.3578, + 0.3568, + 0.3611, + 0.368, + 0.3716, + 0.3657, + 0.3624, + 0.37, + 0.3642, + 0.3776, + 0.3742, + 0.3629, + 0.3623, + 0.3785, + 0.3689, + 0.3538, + 0.3593, + 0.3592, + 0.3713, + 0.3788, + 0.3722, + 0.3675, + 0.3747, + 0.3758, + 0.3671, + 0.3866, + 0.3809, + 0.3827, + 0.3781, + 0.3711, + 0.375, + 0.3801, + 0.3829, + 0.3886, + 0.3844, + 0.3859, + 0.3842, + 0.3854, + 0.3855, + 0.3871, + 0.3835, + 0.3874, + 0.3931, + 0.3825, + 0.3856, + 0.3901, + 0.3877, + 0.3854, + 0.3894, + 0.3873, + 0.3886, + 0.3982, + 0.3923, + 0.3881, + 0.3883, + 0.3901, + 0.3907, + 0.3926, + 0.3942, + 0.3933, + 0.3851, + 0.3912, + 0.393, + 0.3928, + 0.3961, + 0.3907, + 0.3876, + 0.3881, + 0.3875, + 0.3885, + 0.3909, + 0.3893, + 0.394, + 0.3959, + 0.3886, + 0.3945, + 0.3932, + 0.3919, + 0.3906, + 0.3944, + 0.3939, + 0.3946, + 0.3923, + 0.393, + 0.3929, + 0.3923, + 0.3915, + 0.3948, + 0.3934, + 0.3929, + 0.3932, + 0.3936, + 0.394, + 0.3938 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.017680617049336433, + 0.8765696287155151 + ], + "perturbation_rho": [ + -0.00557959359139204, + 0.06313759833574295 + ], + "nudging": { + "0.001": [ + -2.048211172223091e-06, + -3.955443389713764e-06 + ], + "0.003": [ + -6.195507012307644e-06, + -1.213036011904478e-05 + ], + "0.01": [ + -2.056185621768236e-05, + -4.089018329977989e-05 + ] + }, + "hidden_norms_per_layer": [ + 5668.7490234375, + 427742.1875, + 626308.1875 + ], + "bp_grad_norms_per_layer": [ + 2.2485837689600885e-05, + 1.0125075959876995e-06, + 7.618949666721164e-07 + ] + }, + "drift": { + "embed.weight": 43.36126831717888, + "embed.bias": 11.647761057592673, + "blocks.0.ln.weight": 1.7734721493198804, + "blocks.0.w1.weight": 28.788589686086883, + "blocks.0.w1.bias": 16.981347027161885, + "blocks.0.w2.weight": 61.91338151396999, + "blocks.1.ln.weight": 1.2975357664484612, + "blocks.1.w1.weight": 23.055782710684934, + "blocks.1.w1.bias": 17.973225870539384, + "blocks.1.w2.weight": 38.08861834310624, + "out_ln.weight": 0.38876131505924627, + "out_head.weight": 5.716507111968003, + "out_head.bias": 4.553337714995325 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 0 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed0", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed1/results_cifar10.json b/results/fa_dfa_d512_L2_seed1/results_cifar10.json new file mode 100644 index 0000000..0dd9e9b --- /dev/null +++ b/results/fa_dfa_d512_L2_seed1/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "1": { + "dfa": { + "log": { + "train_loss": [ + 2.059094401855469, + 2.044447731781006, + 2.0331237268066404, + 2.029790968055725, + 2.0203696198272705, + 2.019836251373291, + 2.0172710887908933, + 2.0157542709350587, + 2.01279447265625, + 2.0095205695343017, + 2.0108920065689087, + 2.0079552867889405, + 2.0044708168792726, + 2.0045814642333983, + 2.008353828582764, + 2.0057097484970092, + 2.0044997080993654, + 2.008672186203003, + 2.007922875289917, + 2.005835166091919, + 2.0033903245544433, + 2.0044442585754396, + 2.002520341949463, + 2.00441702545166, + 2.0058324599456787, + 2.004086421432495, + 2.0051950196838377, + 2.0038982134246828, + 2.004509604110718, + 2.0052878368377685, + 2.002380726890564, + 2.004787947921753, + 2.0059314285278322, + 2.00248601272583, + 2.0011240525817873, + 2.0024265776062014, + 2.000450315246582, + 2.003341311035156, + 2.0023523516082764, + 2.0043818824768067, + 2.0009307250213624, + 1.9997900833129882, + 1.99952888671875, + 1.9998678870010376, + 2.000314803161621, + 2.0005521519470215, + 1.9993469972991944, + 2.0002477783966066, + 1.9992239441680908, + 2.000459006500244, + 1.999985199661255, + 1.9989831493377686, + 1.9985046770477295, + 1.999920511932373, + 2.001435203552246, + 1.999222763671875, + 1.9992314191436769, + 1.9978362256240845, + 1.9966827032470704, + 1.9963577353286743, + 1.9981027877807618, + 1.998446220703125, + 1.999375821762085, + 1.9966239959716796, + 1.9978697190856933, + 1.9981718766021728, + 1.9964544176864625, + 1.996339178161621, + 1.995297778892517, + 1.9981674702453613, + 1.9945542547988893, + 1.9964004328918457, + 1.9957070028686523, + 1.9959317153930665, + 1.996950538787842, + 1.9957062911987304, + 1.9937946883392335, + 1.9961785815429687, + 1.996840117225647, + 1.9966120117950439, + 1.9950081676483153, + 1.9959315377807618, + 1.9947433185195922, + 1.9953646588897704, + 1.9961627129745483, + 1.9956305498504638, + 1.9933074866104126, + 1.9939569332122802, + 1.9948224069213867, + 1.9938344969940185, + 1.9951542666625977, + 1.994460453224182, + 1.9951913018798828, + 1.992417700653076, + 1.9943852715682984, + 1.9950172023010253, + 1.9913645124816894, + 1.9934076538467407, + 1.993856616821289, + 1.9950498567962647 + ], + "train_acc": [ + 0.24538, + 0.2525, + 0.24984, + 0.25596, + 0.25824, + 0.26188, + 0.26018, + 0.26256, + 0.26788, + 0.26836, + 0.26572, + 0.26712, + 0.26946, + 0.2707, + 0.26872, + 0.26986, + 0.26942, + 0.26706, + 0.2688, + 0.26814, + 0.2714, + 0.26896, + 0.26966, + 0.27136, + 0.26994, + 0.27172, + 0.27024, + 0.2708, + 0.27018, + 0.27212, + 0.27326, + 0.2728, + 0.2694, + 0.27118, + 0.27332, + 0.2738, + 0.27492, + 0.27382, + 0.27348, + 0.27124, + 0.27338, + 0.27632, + 0.27474, + 0.27438, + 0.27608, + 0.27216, + 0.27472, + 0.27174, + 0.2764, + 0.27286, + 0.27426, + 0.27344, + 0.27646, + 0.27692, + 0.27348, + 0.27398, + 0.2745, + 0.27462, + 0.27702, + 0.2742, + 0.27524, + 0.27522, + 0.2742, + 0.27684, + 0.27522, + 0.2755, + 0.2759, + 0.27692, + 0.27698, + 0.27796, + 0.27768, + 0.27376, + 0.27822, + 0.2749, + 0.27752, + 0.2773, + 0.27838, + 0.27666, + 0.27632, + 0.27526, + 0.27792, + 0.27644, + 0.2755, + 0.27696, + 0.27616, + 0.27868, + 0.27898, + 0.27856, + 0.27766, + 0.27746, + 0.27564, + 0.27836, + 0.27742, + 0.27794, + 0.27654, + 0.277, + 0.28076, + 0.27934, + 0.27814, + 0.27694 + ], + "test_acc": [ + 0.2655, + 0.2561, + 0.2732, + 0.2647, + 0.2937, + 0.2959, + 0.2856, + 0.2952, + 0.2909, + 0.2903, + 0.2948, + 0.2937, + 0.2853, + 0.2983, + 0.2836, + 0.2904, + 0.2751, + 0.294, + 0.2826, + 0.2783, + 0.3006, + 0.2986, + 0.2981, + 0.2913, + 0.27, + 0.2959, + 0.2941, + 0.2893, + 0.2816, + 0.2762, + 0.2842, + 0.2888, + 0.294, + 0.3006, + 0.2761, + 0.2995, + 0.2824, + 0.2895, + 0.2918, + 0.2959, + 0.2896, + 0.2977, + 0.2867, + 0.2906, + 0.2955, + 0.2965, + 0.2819, + 0.2848, + 0.293, + 0.2901, + 0.2903, + 0.2893, + 0.2946, + 0.286, + 0.2811, + 0.2948, + 0.2884, + 0.2963, + 0.2914, + 0.2919, + 0.2882, + 0.2923, + 0.2964, + 0.3036, + 0.3028, + 0.2942, + 0.2974, + 0.286, + 0.2997, + 0.2913, + 0.2942, + 0.2914, + 0.295, + 0.2929, + 0.2945, + 0.2905, + 0.2979, + 0.3003, + 0.2997, + 0.295, + 0.2937, + 0.2884, + 0.2946, + 0.2959, + 0.2977, + 0.2996, + 0.2999, + 0.2952, + 0.297, + 0.2936, + 0.2981, + 0.2977, + 0.2982, + 0.2978, + 0.2979, + 0.2976, + 0.2983, + 0.2977, + 0.2978, + 0.2978 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.41370105743408203, + -0.001240116311237216 + ], + "perturbation_rho": [ + 0.015020761638879776, + 0.0 + ], + "nudging": { + "0.001": [ + -4.163011908531189e-07, + 0.0 + ], + "0.003": [ + -1.1897645890712738e-06, + 0.0 + ], + "0.01": [ + -3.778841346502304e-06, + 9.313225746154785e-10 + ] + }, + "hidden_norms_per_layer": [ + 56646.390625, + 1782275456.0, + 3970110464.0 + ], + "bp_grad_norms_per_layer": [ + 2.493850104201556e-07, + 3.667711412358443e-10, + 3.6680619652784685e-10 + ] + }, + "drift": { + "embed.weight": 340.80661711332743, + "embed.bias": 276.58178672612297, + "blocks.0.ln.weight": 9.889068312709053, + "blocks.0.w1.weight": 318.9410688527977, + "blocks.0.w1.bias": 284.37318853440576, + "blocks.0.w2.weight": 489.2728313372279, + "blocks.1.ln.weight": 9.298819559476074, + "blocks.1.w1.weight": 366.2278399546544, + "blocks.1.w1.bias": 340.3854434800174, + "blocks.1.w2.weight": 341.5363762004093, + "out_ln.weight": 0.5301057051408113, + "out_head.weight": 8.039116173269838, + "out_head.bias": 2.5371126762269562 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.05515945022583, + 1.9500655529785156, + 1.9110925933456422, + 1.8819315839004516, + 1.8608471837615967, + 1.8478346523666382, + 1.8411366244125367, + 1.8377560669326782, + 1.8311604034805298, + 1.8246206870269774, + 1.8166842291641236, + 1.8113972366714477, + 1.8071093799209594, + 1.7970440604400635, + 1.7988952663803102, + 1.7952887838363647, + 1.7926777132415772, + 1.790428493423462, + 1.7899820502090453, + 1.7887613116836547, + 1.7811174698257446, + 1.781726747894287, + 1.782828377418518, + 1.7825099970245362, + 1.7798745111846923, + 1.7748807119369507, + 1.769409514427185, + 1.7675959386444091, + 1.7701947622680665, + 1.7674276565551759, + 1.7609349981307982, + 1.7605391665267944, + 1.762060026626587, + 1.7580202764129638, + 1.7636039975357056, + 1.7673666812133788, + 1.7655566649627685, + 1.7669985010147096, + 1.7656920331192016, + 1.7639919567108153, + 1.7634542723846436, + 1.7601457201766968, + 1.761037275123596, + 1.7602271603775024, + 1.7624055517578125, + 1.7582754214477538, + 1.7596718951416015, + 1.7575918433380127, + 1.757328048439026, + 1.7560129935455322, + 1.7593553677749634, + 1.75662100856781, + 1.7546059392547608, + 1.7581290004730226, + 1.7570533187103272, + 1.7555389776992798, + 1.755943459587097, + 1.752056611404419, + 1.752125501060486, + 1.7542531232452392, + 1.7525837285614014, + 1.7537434762954711, + 1.7566586553573609, + 1.754645963783264, + 1.7561360279083251, + 1.7577947814178467, + 1.7517323531723021, + 1.7506846601104735, + 1.7523996893310547, + 1.7564484720611573, + 1.7542920526885986, + 1.7536059958648682, + 1.753201146583557, + 1.7531446646118165, + 1.7524780932617188, + 1.7523272793197633, + 1.752608459777832, + 1.7483984720611572, + 1.7505473443984985, + 1.750156759109497, + 1.7495093181991577, + 1.7508859223175048, + 1.7515791293716432, + 1.7466206330108642, + 1.751723793373108, + 1.750410927658081, + 1.7431821329116821, + 1.7460082646942139, + 1.7443738849639892, + 1.7485774746704101, + 1.747268639907837, + 1.7440908068847656, + 1.7475049467468262, + 1.7451262094497682, + 1.747353721961975, + 1.7474739736557008, + 1.7444730795669556, + 1.748691442642212, + 1.743288662033081, + 1.7470912045669555 + ], + "train_acc": [ + 0.25462, + 0.29498, + 0.31256, + 0.32432, + 0.33206, + 0.33724, + 0.33776, + 0.34132, + 0.34392, + 0.34576, + 0.35074, + 0.34868, + 0.35426, + 0.35624, + 0.35478, + 0.35702, + 0.35892, + 0.3572, + 0.35788, + 0.35796, + 0.3596, + 0.36274, + 0.3592, + 0.36092, + 0.36282, + 0.36558, + 0.36466, + 0.36602, + 0.36356, + 0.36714, + 0.36908, + 0.36888, + 0.36582, + 0.368, + 0.365, + 0.3677, + 0.3674, + 0.3658, + 0.36706, + 0.368, + 0.3688, + 0.3678, + 0.36976, + 0.36634, + 0.36654, + 0.3687, + 0.36986, + 0.36986, + 0.3703, + 0.37068, + 0.36806, + 0.37002, + 0.37306, + 0.3722, + 0.36826, + 0.37146, + 0.3704, + 0.37428, + 0.37186, + 0.37382, + 0.37398, + 0.3711, + 0.3723, + 0.37532, + 0.37158, + 0.37096, + 0.37172, + 0.37316, + 0.37314, + 0.37438, + 0.37196, + 0.3737, + 0.37458, + 0.37306, + 0.37588, + 0.37444, + 0.37408, + 0.37832, + 0.37416, + 0.37576, + 0.3746, + 0.375, + 0.37462, + 0.37456, + 0.37364, + 0.37744, + 0.37852, + 0.37832, + 0.3778, + 0.37706, + 0.37604, + 0.37654, + 0.37828, + 0.37762, + 0.37686, + 0.377, + 0.37886, + 0.37662, + 0.37522, + 0.37652 + ], + "test_acc": [ + 0.2978, + 0.3334, + 0.3399, + 0.3342, + 0.3634, + 0.3618, + 0.3505, + 0.3521, + 0.3696, + 0.3722, + 0.3679, + 0.3787, + 0.3663, + 0.3811, + 0.3611, + 0.365, + 0.3543, + 0.3666, + 0.3501, + 0.36, + 0.3572, + 0.3504, + 0.3467, + 0.3598, + 0.3353, + 0.3474, + 0.3382, + 0.3432, + 0.3617, + 0.3397, + 0.353, + 0.3288, + 0.3385, + 0.3414, + 0.3467, + 0.3389, + 0.3341, + 0.3434, + 0.3537, + 0.3496, + 0.3526, + 0.346, + 0.3603, + 0.3526, + 0.3342, + 0.346, + 0.3392, + 0.3277, + 0.3535, + 0.3419, + 0.3396, + 0.3404, + 0.3433, + 0.3468, + 0.339, + 0.3361, + 0.3421, + 0.3363, + 0.3544, + 0.3465, + 0.3454, + 0.3472, + 0.3401, + 0.3454, + 0.3511, + 0.343, + 0.3445, + 0.3368, + 0.3424, + 0.3425, + 0.3381, + 0.3435, + 0.3496, + 0.346, + 0.3411, + 0.3408, + 0.3512, + 0.3469, + 0.3473, + 0.3419, + 0.3554, + 0.3449, + 0.345, + 0.3431, + 0.3442, + 0.3501, + 0.3534, + 0.3469, + 0.3496, + 0.3436, + 0.3478, + 0.3501, + 0.3492, + 0.3481, + 0.3474, + 0.3466, + 0.3476, + 0.3468, + 0.3473, + 0.3471 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.007958251982927322, + 0.9600946307182312 + ], + "perturbation_rho": [ + 0.029625363647937775, + 0.18801619112491608 + ], + "nudging": { + "0.001": [ + 4.2957253754138947e-07, + -8.616363629698753e-06 + ], + "0.003": [ + 1.2052478268742561e-06, + -2.6152702048420906e-05 + ], + "0.01": [ + 4.0422892197966576e-06, + -8.745735976845026e-05 + ] + }, + "hidden_norms_per_layer": [ + 6253.97314453125, + 339687.71875, + 217173.15625 + ], + "bp_grad_norms_per_layer": [ + 1.8587075828691013e-05, + 8.973949547907978e-07, + 7.725411705905572e-07 + ] + }, + "drift": { + "embed.weight": 35.41278218973057, + "embed.bias": 29.83841191850244, + "blocks.0.ln.weight": 1.4875314025636917, + "blocks.0.w1.weight": 25.996758962578763, + "blocks.0.w1.bias": 15.439034204651744, + "blocks.0.w2.weight": 58.48910743468254, + "blocks.1.ln.weight": 1.2261067193625912, + "blocks.1.w1.weight": 24.34549540914942, + "blocks.1.w1.bias": 12.006957572444392, + "blocks.1.w2.weight": 29.62795958064609, + "out_ln.weight": 0.4822583681094011, + "out_head.weight": 3.604309031698496, + "out_head.bias": 10.930487632919764 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 1 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed1", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed2/results_cifar10.json b/results/fa_dfa_d512_L2_seed2/results_cifar10.json new file mode 100644 index 0000000..35785f5 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed2/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "2": { + "dfa": { + "log": { + "train_loss": [ + 2.058412809791565, + 2.0388552247619627, + 2.0358241034317017, + 2.027553281555176, + 2.030367368774414, + 2.027717582397461, + 2.023588889923096, + 2.022263825721741, + 2.0241731433105468, + 2.0235239835357666, + 2.023893345298767, + 2.0222612105560303, + 2.028809658164978, + 2.023522481994629, + 2.023535331993103, + 2.022634300880432, + 2.0261124728393556, + 2.027209746055603, + 2.024631557235718, + 2.0272511472320556, + 2.0248798042297365, + 2.0252146803283693, + 2.021958327026367, + 2.0255450409317017, + 2.0234208207702635, + 2.0240782318496704, + 2.0233148418807985, + 2.022000389137268, + 2.024022118911743, + 2.023005671234131, + 2.0231490325927735, + 2.024285191497803, + 2.0222519898986815, + 2.023439755783081, + 2.023540186843872, + 2.0246322383117676, + 2.0237508728027342, + 2.024476735076904, + 2.0262484502410887, + 2.0244761968612672, + 2.023149300994873, + 2.0230725455093386, + 2.0237448512268066, + 2.0239054921722412, + 2.0238359395980834, + 2.0231575259399412, + 2.022854585723877, + 2.022712833328247, + 2.022943268432617, + 2.023613681335449, + 2.022585965042114, + 2.0229795806503295, + 2.0224396421051027, + 2.024438477630615, + 2.0226136851501466, + 2.0218412887191772, + 2.020350592918396, + 2.021553417892456, + 2.0207389210510254, + 2.022400696411133, + 2.0226713175201416, + 2.0205454586029052, + 2.022753714904785, + 2.0204339378356932, + 2.0214138391113283, + 2.019458531036377, + 2.018622496871948, + 2.0211455561828613, + 2.020830050392151, + 2.0200049672698976, + 2.0226647985839845, + 2.020451495361328, + 2.0213021927642822, + 2.018636092834473, + 2.0201395918273928, + 2.0197924830627443, + 2.0189465530395507, + 2.0179246043395995, + 2.0190702261352538, + 2.018002661895752, + 2.0188750094604493, + 2.0186908027648927, + 2.018660167617798, + 2.016986283912659, + 2.016896242828369, + 2.0198894203948976, + 2.0182291090393067, + 2.0168725957870484, + 2.0189685047912596, + 2.0170926274108885, + 2.0171643753814696, + 2.0170515134429934, + 2.017839207763672, + 2.0162586155700684, + 2.016942921524048, + 2.0172505876159668, + 2.016627454185486, + 2.01472389213562, + 2.016936295089722, + 2.0162249319458008 + ], + "train_acc": [ + 0.24042, + 0.24948, + 0.25212, + 0.25448, + 0.25656, + 0.25446, + 0.26134, + 0.2625, + 0.25946, + 0.26152, + 0.26086, + 0.26258, + 0.2591, + 0.25878, + 0.25972, + 0.2621, + 0.26092, + 0.25914, + 0.25866, + 0.25816, + 0.26052, + 0.25714, + 0.26258, + 0.25996, + 0.26152, + 0.2614, + 0.2641, + 0.26338, + 0.26254, + 0.26234, + 0.26476, + 0.26218, + 0.2652, + 0.26612, + 0.2645, + 0.26434, + 0.26412, + 0.26428, + 0.26476, + 0.26552, + 0.26592, + 0.2645, + 0.26798, + 0.26322, + 0.26514, + 0.2651, + 0.26806, + 0.26828, + 0.26474, + 0.26502, + 0.26398, + 0.26496, + 0.26396, + 0.265, + 0.26866, + 0.26958, + 0.2685, + 0.26846, + 0.26926, + 0.27102, + 0.26792, + 0.26856, + 0.26608, + 0.26824, + 0.27114, + 0.26984, + 0.2685, + 0.26856, + 0.26812, + 0.26898, + 0.26946, + 0.26888, + 0.26992, + 0.2687, + 0.2679, + 0.2708, + 0.26838, + 0.27112, + 0.27018, + 0.26964, + 0.27128, + 0.27114, + 0.27024, + 0.27024, + 0.27036, + 0.271, + 0.27062, + 0.27166, + 0.27026, + 0.27294, + 0.2727, + 0.27276, + 0.27036, + 0.27272, + 0.26992, + 0.26996, + 0.2709, + 0.27316, + 0.2732, + 0.27336 + ], + "test_acc": [ + 0.278, + 0.2738, + 0.2583, + 0.2773, + 0.2658, + 0.2783, + 0.2681, + 0.2957, + 0.2671, + 0.2647, + 0.287, + 0.2846, + 0.284, + 0.287, + 0.2724, + 0.2925, + 0.2763, + 0.2711, + 0.2614, + 0.277, + 0.2817, + 0.2791, + 0.2834, + 0.2821, + 0.2641, + 0.2764, + 0.2963, + 0.2882, + 0.3005, + 0.2923, + 0.2579, + 0.2889, + 0.2999, + 0.2958, + 0.2739, + 0.2984, + 0.286, + 0.2643, + 0.2829, + 0.2919, + 0.292, + 0.2875, + 0.2932, + 0.2778, + 0.2753, + 0.291, + 0.2955, + 0.3014, + 0.2887, + 0.2856, + 0.2937, + 0.2946, + 0.298, + 0.2932, + 0.3005, + 0.2989, + 0.2924, + 0.2977, + 0.2922, + 0.2876, + 0.2943, + 0.2844, + 0.3003, + 0.2952, + 0.2881, + 0.2821, + 0.2934, + 0.28, + 0.2958, + 0.2933, + 0.2923, + 0.2941, + 0.2872, + 0.2946, + 0.2966, + 0.2962, + 0.2915, + 0.2969, + 0.2988, + 0.2968, + 0.2991, + 0.2972, + 0.297, + 0.2955, + 0.2973, + 0.2994, + 0.2959, + 0.2962, + 0.3, + 0.2971, + 0.2977, + 0.2962, + 0.2963, + 0.2955, + 0.2959, + 0.2964, + 0.2962, + 0.2963, + 0.2968, + 0.2968 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.3571905195713043, + -2.4819419195409864e-05 + ], + "perturbation_rho": [ + -0.020394116640090942, + 0.0 + ], + "nudging": { + "0.001": [ + -3.2084062695503235e-07, + 0.0 + ], + "0.003": [ + -8.596107363700867e-07, + 0.0 + ], + "0.01": [ + -2.905726432800293e-06, + 0.0 + ] + }, + "hidden_norms_per_layer": [ + 53198.71875, + 2227773184.0, + 5250976256.0 + ], + "bp_grad_norms_per_layer": [ + 2.0440073456029495e-07, + 3.7151820508896094e-10, + 3.7148767395578375e-10 + ] + }, + "drift": { + "embed.weight": 342.0501862953785, + "embed.bias": 318.0357815319326, + "blocks.0.ln.weight": 9.910735164962347, + "blocks.0.w1.weight": 324.42259888737226, + "blocks.0.w1.bias": 352.6779805020607, + "blocks.0.w2.weight": 492.2710513712036, + "blocks.1.ln.weight": 9.723967303148715, + "blocks.1.w1.weight": 403.31369239097415, + "blocks.1.w1.bias": 386.0984329662907, + "blocks.1.w2.weight": 397.9343379416526, + "out_ln.weight": 0.5056100406607513, + "out_head.weight": 8.24905668320801, + "out_head.bias": 2.0933992602399067 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.063060122451782, + 1.9739086376571655, + 1.9282753707504272, + 1.8921499670410156, + 1.8799907299804688, + 1.8701481190490723, + 1.8541997326660156, + 1.8432505453109742, + 1.8392359512710572, + 1.8295656701278686, + 1.8286965405273437, + 1.8275740008163452, + 1.82835955909729, + 1.8276023639297485, + 1.8295053287506104, + 1.8279134408569335, + 1.8347502783966065, + 1.8358594959259034, + 1.832546812095642, + 1.8375971967315674, + 1.8337256093978882, + 1.8388268453598022, + 1.8368015246200562, + 1.8367653689956664, + 1.837699118347168, + 1.8374284832763672, + 1.8332337552642821, + 1.8305227688217163, + 1.8352981665802002, + 1.8384290933227538, + 1.8319619774627685, + 1.8331132510375976, + 1.8284437879180908, + 1.8323823754882813, + 1.8337605920791626, + 1.8296470736312866, + 1.824372737045288, + 1.8253158220672607, + 1.8254631029891968, + 1.8240845959472656, + 1.819489220275879, + 1.8196685235214234, + 1.8173167004013062, + 1.8210680514526367, + 1.812099649734497, + 1.8139883127212524, + 1.810353461265564, + 1.8109969634246825, + 1.808874561843872, + 1.8115904225158692, + 1.8089021768188476, + 1.8070925519561767, + 1.807039188232422, + 1.8048222934341431, + 1.8031300540924071, + 1.8039923257446289, + 1.8042795065307617, + 1.8037835544586183, + 1.8002930462646485, + 1.8007701531982423, + 1.8004488678359984, + 1.8029598141098022, + 1.8020868547821045, + 1.7990351900482178, + 1.8022524103546143, + 1.7987030548477172, + 1.7969628913116455, + 1.7972235089874267, + 1.7918564282226563, + 1.7940529309082032, + 1.7972945267105103, + 1.7940796615219117, + 1.7942549993133545, + 1.7947112902450562, + 1.7927878363037109, + 1.7915658866119384, + 1.7944716805267333, + 1.790445680809021, + 1.7891779144668578, + 1.789405579185486, + 1.7867662616729736, + 1.7887766037368775, + 1.7875445496368407, + 1.7851270193862916, + 1.7886884448623657, + 1.786267046775818, + 1.785640559425354, + 1.7853610308074952, + 1.7856942657470702, + 1.7871524771499634, + 1.7883876535797119, + 1.782026519203186, + 1.786798097305298, + 1.7790444551849365, + 1.7840377599334716, + 1.787256519126892, + 1.7822379702758788, + 1.7822947565460205, + 1.7829025159072875, + 1.7819519442749023 + ], + "train_acc": [ + 0.24228, + 0.28314, + 0.30214, + 0.31828, + 0.32302, + 0.3277, + 0.33338, + 0.33888, + 0.34066, + 0.34536, + 0.34814, + 0.34648, + 0.34716, + 0.34526, + 0.3462, + 0.3422, + 0.343, + 0.34162, + 0.3438, + 0.34388, + 0.34354, + 0.33854, + 0.34312, + 0.34176, + 0.34514, + 0.34304, + 0.34458, + 0.34162, + 0.34348, + 0.34098, + 0.34374, + 0.34488, + 0.34328, + 0.34504, + 0.34414, + 0.34548, + 0.34552, + 0.34848, + 0.34836, + 0.34802, + 0.34918, + 0.346, + 0.34986, + 0.3473, + 0.3515, + 0.3527, + 0.3532, + 0.35212, + 0.35068, + 0.34998, + 0.3505, + 0.3565, + 0.35392, + 0.35556, + 0.35556, + 0.35462, + 0.3548, + 0.35674, + 0.35552, + 0.35848, + 0.35704, + 0.35594, + 0.35598, + 0.35716, + 0.35818, + 0.35842, + 0.35966, + 0.35968, + 0.36022, + 0.36082, + 0.3571, + 0.3612, + 0.3604, + 0.3607, + 0.3602, + 0.36134, + 0.35734, + 0.36262, + 0.36352, + 0.35916, + 0.36462, + 0.3617, + 0.3608, + 0.3617, + 0.36106, + 0.36262, + 0.36296, + 0.36354, + 0.36138, + 0.36132, + 0.36446, + 0.36452, + 0.36388, + 0.3645, + 0.36378, + 0.36342, + 0.36496, + 0.36556, + 0.36274, + 0.36546 + ], + "test_acc": [ + 0.2954, + 0.3137, + 0.3295, + 0.3474, + 0.3551, + 0.3544, + 0.3515, + 0.3694, + 0.3633, + 0.3479, + 0.3744, + 0.3726, + 0.371, + 0.3618, + 0.3649, + 0.3685, + 0.3658, + 0.3704, + 0.3504, + 0.3729, + 0.3578, + 0.3685, + 0.3676, + 0.3643, + 0.3562, + 0.3486, + 0.3732, + 0.3624, + 0.3693, + 0.341, + 0.3467, + 0.3544, + 0.3679, + 0.3526, + 0.3593, + 0.3647, + 0.3604, + 0.3662, + 0.3632, + 0.3644, + 0.3628, + 0.3598, + 0.3564, + 0.3633, + 0.3531, + 0.3607, + 0.3616, + 0.3636, + 0.3514, + 0.3488, + 0.3408, + 0.3522, + 0.3651, + 0.3408, + 0.3548, + 0.3504, + 0.3469, + 0.3507, + 0.3515, + 0.3434, + 0.3548, + 0.351, + 0.351, + 0.3545, + 0.3487, + 0.3467, + 0.3492, + 0.3439, + 0.346, + 0.3357, + 0.332, + 0.3387, + 0.3546, + 0.3415, + 0.3458, + 0.3451, + 0.3477, + 0.342, + 0.3446, + 0.3427, + 0.3426, + 0.3423, + 0.3481, + 0.3467, + 0.3443, + 0.3437, + 0.3466, + 0.3459, + 0.3448, + 0.3458, + 0.3433, + 0.3461, + 0.3479, + 0.3447, + 0.3452, + 0.3467, + 0.3461, + 0.3467, + 0.3464, + 0.3464 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.01565447635948658, + 0.9284564256668091 + ], + "perturbation_rho": [ + 0.08889118582010269, + 0.1751367598772049 + ], + "nudging": { + "0.001": [ + -8.567003533244133e-07, + -7.14592169970274e-06 + ], + "0.003": [ + -2.5028130039572716e-06, + -2.1490384824573994e-05 + ], + "0.01": [ + -8.566654287278652e-06, + -7.176969666033983e-05 + ] + }, + "hidden_norms_per_layer": [ + 5015.7705078125, + 205831.875, + 281989.59375 + ], + "bp_grad_norms_per_layer": [ + 1.882538890640717e-05, + 8.274267884189612e-07, + 8.043497246035258e-07 + ] + }, + "drift": { + "embed.weight": 31.381319271594045, + "embed.bias": 16.70881646410694, + "blocks.0.ln.weight": 1.5550725713952875, + "blocks.0.w1.weight": 29.555012389084762, + "blocks.0.w1.bias": 17.153011809958922, + "blocks.0.w2.weight": 62.82184507015774, + "blocks.1.ln.weight": 1.2734634249220895, + "blocks.1.w1.weight": 21.575349693350912, + "blocks.1.w1.bias": 10.895115401942027, + "blocks.1.w2.weight": 35.67159460439069, + "out_ln.weight": 0.4183849023916904, + "out_head.weight": 4.1274562698411295, + "out_head.bias": 12.873937810986282 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 2 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed2", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed3/results_cifar10.json b/results/fa_dfa_d512_L2_seed3/results_cifar10.json new file mode 100644 index 0000000..8cf5599 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed3/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "3": { + "dfa": { + "log": { + "train_loss": [ + 2.049092872314453, + 2.044887622756958, + 2.062239527206421, + 2.0690095069122316, + 2.0704432483673094, + 2.076322864227295, + 2.0749057723999025, + 2.0722323740005493, + 2.0724080812835695, + 2.072884361419678, + 2.0764317562866212, + 2.070321410064697, + 2.0702889765930177, + 2.067789299621582, + 2.0646444761657716, + 2.0636547829818728, + 2.062572795333862, + 2.0582885260772703, + 2.059502869415283, + 2.0571998697280884, + 2.0527090773010253, + 2.056031491851807, + 2.0513366117095946, + 2.0520858401870727, + 2.0493150717926025, + 2.0478450300598143, + 2.0498407819366453, + 2.043566691818237, + 2.042290360069275, + 2.039692493896484, + 2.041093152770996, + 2.0399529009246824, + 2.039473579330444, + 2.0396062004089357, + 2.0359111277770996, + 2.0347996918487548, + 2.034452969818115, + 2.033417662200928, + 2.0335788008880615, + 2.0308428141784667, + 2.027684537124634, + 2.0306087942504885, + 2.0288818074798582, + 2.028032119369507, + 2.0284718214416504, + 2.025789548034668, + 2.0262292552948, + 2.023198474197388, + 2.024526604385376, + 2.025226473312378, + 2.0236859022521974, + 2.021138525657654, + 2.022560397567749, + 2.0220771756744385, + 2.0260347727203367, + 2.022141723327637, + 2.021563760910034, + 2.018922806472778, + 2.021806682510376, + 2.0212029346466065, + 2.01982150100708, + 2.019622989349365, + 2.0191136405944823, + 2.0185832523345946, + 2.0185110153961183, + 2.0175517141342163, + 2.0180020709991453, + 2.0152184085845946, + 2.0153978774261474, + 2.017112913208008, + 2.0173879592895507, + 2.0179903555297853, + 2.013354239501953, + 2.016451063537598, + 2.013973571510315, + 2.018112847671509, + 2.015232135925293, + 2.014464662742615, + 2.0156791037368773, + 2.0115500025177, + 2.014401368560791, + 2.0142960264587404, + 2.012123734397888, + 2.0130788822937014, + 2.0141370764923097, + 2.014659231796265, + 2.0124325815582274, + 2.014389346160889, + 2.0128879175567627, + 2.0132610288238526, + 2.014250590438843, + 2.012836150970459, + 2.0106287144470216, + 2.014688469619751, + 2.0122887326812746, + 2.01329588142395, + 2.0123015225601195, + 2.010611874694824, + 2.015006597442627, + 2.0157135874176024 + ], + "train_acc": [ + 0.25188, + 0.25386, + 0.24578, + 0.2458, + 0.24326, + 0.2434, + 0.23896, + 0.24338, + 0.24348, + 0.24286, + 0.24246, + 0.24364, + 0.24108, + 0.24176, + 0.24598, + 0.24638, + 0.24538, + 0.24996, + 0.24744, + 0.24802, + 0.25254, + 0.2511, + 0.2547, + 0.2523, + 0.25268, + 0.25586, + 0.2522, + 0.25328, + 0.25624, + 0.25838, + 0.25822, + 0.25672, + 0.25938, + 0.25924, + 0.25908, + 0.259, + 0.2602, + 0.26116, + 0.26152, + 0.2633, + 0.26614, + 0.2616, + 0.26606, + 0.26386, + 0.26708, + 0.26482, + 0.2682, + 0.26386, + 0.26484, + 0.26862, + 0.26774, + 0.26798, + 0.2692, + 0.2687, + 0.26776, + 0.27036, + 0.26802, + 0.2715, + 0.26714, + 0.27108, + 0.27032, + 0.27288, + 0.271, + 0.27048, + 0.26948, + 0.27382, + 0.27056, + 0.27404, + 0.27162, + 0.27054, + 0.27232, + 0.27328, + 0.27242, + 0.26978, + 0.2725, + 0.27236, + 0.27228, + 0.27282, + 0.27318, + 0.2746, + 0.27434, + 0.27072, + 0.27108, + 0.2737, + 0.273, + 0.2735, + 0.27328, + 0.2745, + 0.27388, + 0.27212, + 0.2744, + 0.27352, + 0.2723, + 0.27264, + 0.27482, + 0.2727, + 0.2734, + 0.27328, + 0.27136, + 0.27322 + ], + "test_acc": [ + 0.2855, + 0.2551, + 0.2775, + 0.2743, + 0.2562, + 0.2696, + 0.2672, + 0.2538, + 0.2753, + 0.2372, + 0.2459, + 0.2719, + 0.2666, + 0.2665, + 0.2673, + 0.2632, + 0.2686, + 0.2624, + 0.2753, + 0.2718, + 0.2735, + 0.2621, + 0.2539, + 0.247, + 0.2764, + 0.2782, + 0.2758, + 0.2809, + 0.2735, + 0.2648, + 0.2652, + 0.2743, + 0.2866, + 0.2829, + 0.2732, + 0.2647, + 0.282, + 0.2671, + 0.2715, + 0.2861, + 0.2807, + 0.2831, + 0.2899, + 0.2743, + 0.2842, + 0.2796, + 0.2721, + 0.2816, + 0.2874, + 0.2873, + 0.2878, + 0.2746, + 0.2911, + 0.2938, + 0.2945, + 0.2906, + 0.2921, + 0.2921, + 0.2867, + 0.2754, + 0.2951, + 0.2874, + 0.2816, + 0.2894, + 0.2686, + 0.277, + 0.2915, + 0.295, + 0.2958, + 0.2885, + 0.2861, + 0.2907, + 0.2913, + 0.2959, + 0.2897, + 0.2872, + 0.2932, + 0.2882, + 0.2933, + 0.2942, + 0.2939, + 0.2946, + 0.2902, + 0.2888, + 0.2957, + 0.2944, + 0.2917, + 0.294, + 0.2895, + 0.2908, + 0.2929, + 0.2916, + 0.2912, + 0.2927, + 0.2923, + 0.2931, + 0.2924, + 0.2923, + 0.2925, + 0.2922 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.3577328324317932, + -0.002256808802485466 + ], + "perturbation_rho": [ + 0.006564013194292784, + 0.0245413389056921 + ], + "nudging": { + "0.001": [ + -3.3527612686157227e-07, + 0.0 + ], + "0.003": [ + -9.671784937381744e-07, + 9.313225746154785e-10 + ], + "0.01": [ + -3.1115487217903137e-06, + 3.725290298461914e-09 + ] + }, + "hidden_norms_per_layer": [ + 61062.07421875, + 1830187136.0, + 2154652416.0 + ], + "bp_grad_norms_per_layer": [ + 2.2415125044972228e-07, + 2.5406721171350455e-10, + 2.541881982676131e-10 + ] + }, + "drift": { + "embed.weight": 362.6998630270425, + "embed.bias": 226.82991769893152, + "blocks.0.ln.weight": 11.27242096475318, + "blocks.0.w1.weight": 342.7047730991574, + "blocks.0.w1.bias": 255.335818330469, + "blocks.0.w2.weight": 571.0736810874553, + "blocks.1.ln.weight": 6.709290843015378, + "blocks.1.w1.weight": 234.88441196904554, + "blocks.1.w1.bias": 208.6820413337303, + "blocks.1.w2.weight": 259.6972453610996, + "out_ln.weight": 0.47275142976863993, + "out_head.weight": 5.14258998677984, + "out_head.bias": 3.632529706269454 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.062916795501709, + 1.9544608277130127, + 1.9338681225204468, + 1.9168871939468384, + 1.8961386334609986, + 1.8842207279205323, + 1.8664908053970337, + 1.8596847631454467, + 1.8500450038909912, + 1.8418613220596314, + 1.8393701876449584, + 1.828501918258667, + 1.8237099225234985, + 1.8169598587799072, + 1.8136948993682862, + 1.8099042462539672, + 1.8040387533569335, + 1.7993673914337158, + 1.801761696434021, + 1.7964957051849366, + 1.7922449420166016, + 1.7918584114837646, + 1.7928801208496095, + 1.7959136298370362, + 1.7932421494293214, + 1.7945690915298462, + 1.798325119857788, + 1.7924271218490602, + 1.795842142982483, + 1.786767998275757, + 1.7907304748153687, + 1.7930365305328368, + 1.7880332135009767, + 1.7910264099884032, + 1.786202894821167, + 1.7834789249038696, + 1.7853226938247682, + 1.7844206839752197, + 1.779783320388794, + 1.777475042037964, + 1.7763150534820558, + 1.7789424829101563, + 1.7754264488983154, + 1.7715616805648804, + 1.7733828824615478, + 1.7673900942993164, + 1.7679843439102172, + 1.7700712796020508, + 1.7681101821517944, + 1.763485518951416, + 1.7642558453369142, + 1.7612942028427123, + 1.7610017293548583, + 1.7577664344024657, + 1.7626592459869386, + 1.7603577539443969, + 1.754944856262207, + 1.7521989233398438, + 1.7558432748031616, + 1.757179810218811, + 1.7534725143051146, + 1.7556944265365602, + 1.7504646623992919, + 1.7463257875823974, + 1.7487027013397216, + 1.7462505680084228, + 1.7486760431289672, + 1.7459170078277588, + 1.740626848487854, + 1.7471276846694945, + 1.7445780047225952, + 1.743365991783142, + 1.7377847635650634, + 1.7428057806396484, + 1.7423382642364502, + 1.739464909362793, + 1.7394498482513427, + 1.7388129167556763, + 1.743233028869629, + 1.7352628125, + 1.740815905380249, + 1.7376227493667602, + 1.7356027941894532, + 1.7353747326278686, + 1.7357349237060546, + 1.7323332436370849, + 1.7330560940170288, + 1.736462532081604, + 1.7307056594085692, + 1.7358304736709596, + 1.7327582055282593, + 1.7328367833709717, + 1.730865809020996, + 1.7343174131011962, + 1.7336546128082275, + 1.729335763282776, + 1.731069825515747, + 1.7309973764801025, + 1.7332309117889404, + 1.7309427404022217 + ], + "train_acc": [ + 0.25124, + 0.29078, + 0.30174, + 0.30854, + 0.31558, + 0.32094, + 0.32596, + 0.32986, + 0.3375, + 0.33924, + 0.34048, + 0.34462, + 0.34806, + 0.34818, + 0.34936, + 0.35036, + 0.35452, + 0.35762, + 0.35606, + 0.35678, + 0.35724, + 0.3556, + 0.35836, + 0.35558, + 0.35872, + 0.35734, + 0.35428, + 0.35594, + 0.35682, + 0.35978, + 0.36006, + 0.35966, + 0.36212, + 0.361, + 0.3593, + 0.36056, + 0.36182, + 0.36036, + 0.3643, + 0.36462, + 0.36522, + 0.3643, + 0.3701, + 0.37198, + 0.36678, + 0.36798, + 0.37018, + 0.3697, + 0.3685, + 0.37262, + 0.36694, + 0.3722, + 0.37228, + 0.37302, + 0.37096, + 0.37208, + 0.3721, + 0.37598, + 0.37528, + 0.37364, + 0.37788, + 0.37682, + 0.37554, + 0.37848, + 0.37854, + 0.3787, + 0.37996, + 0.37764, + 0.3765, + 0.37938, + 0.37982, + 0.38054, + 0.383, + 0.38022, + 0.382, + 0.3806, + 0.38078, + 0.38504, + 0.38192, + 0.38324, + 0.38318, + 0.38192, + 0.38326, + 0.38492, + 0.38542, + 0.38452, + 0.3843, + 0.38424, + 0.38694, + 0.38464, + 0.38664, + 0.3882, + 0.38476, + 0.3842, + 0.38732, + 0.38674, + 0.38546, + 0.38734, + 0.38404, + 0.38848 + ], + "test_acc": [ + 0.2764, + 0.3182, + 0.3354, + 0.3476, + 0.3539, + 0.339, + 0.3478, + 0.3737, + 0.3595, + 0.3642, + 0.3624, + 0.3739, + 0.3717, + 0.3809, + 0.3776, + 0.3763, + 0.3824, + 0.3807, + 0.3754, + 0.3832, + 0.3864, + 0.391, + 0.3762, + 0.3797, + 0.3875, + 0.3857, + 0.3897, + 0.3769, + 0.3752, + 0.3835, + 0.3795, + 0.3635, + 0.3828, + 0.3894, + 0.3827, + 0.3777, + 0.3865, + 0.3855, + 0.3863, + 0.3865, + 0.3888, + 0.3807, + 0.3986, + 0.3886, + 0.389, + 0.3891, + 0.3782, + 0.384, + 0.3912, + 0.3924, + 0.3923, + 0.3923, + 0.3917, + 0.3969, + 0.3968, + 0.3932, + 0.3964, + 0.3876, + 0.3991, + 0.3852, + 0.4036, + 0.3989, + 0.3858, + 0.3974, + 0.3968, + 0.4013, + 0.403, + 0.4009, + 0.3967, + 0.3979, + 0.397, + 0.3962, + 0.3943, + 0.4013, + 0.4027, + 0.3986, + 0.4013, + 0.3944, + 0.4021, + 0.3941, + 0.3972, + 0.3984, + 0.4003, + 0.3986, + 0.3997, + 0.3989, + 0.3996, + 0.4009, + 0.4004, + 0.3993, + 0.4001, + 0.3999, + 0.4013, + 0.4007, + 0.4007, + 0.4011, + 0.401, + 0.4007, + 0.4011, + 0.4012 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.018671220168471336, + 0.9215916395187378 + ], + "perturbation_rho": [ + 0.05869613587856293, + -0.006530101411044598 + ], + "nudging": { + "0.001": [ + -1.8319697119295597e-06, + -3.306486178189516e-06 + ], + "0.003": [ + -5.463924026116729e-06, + -1.0309304343536496e-05 + ], + "0.01": [ + -1.8201360944658518e-05, + -3.473513061180711e-05 + ] + }, + "hidden_norms_per_layer": [ + 6274.154296875, + 343870.53125, + 375719.59375 + ], + "bp_grad_norms_per_layer": [ + 2.1775536879431456e-05, + 1.1024598052244983e-06, + 8.687474064572598e-07 + ] + }, + "drift": { + "embed.weight": 47.037152107454325, + "embed.bias": 16.388390598893867, + "blocks.0.ln.weight": 1.5269469673782543, + "blocks.0.w1.weight": 25.58456877637686, + "blocks.0.w1.bias": 18.02406186908715, + "blocks.0.w2.weight": 64.17096592056014, + "blocks.1.ln.weight": 1.3407223784512263, + "blocks.1.w1.weight": 21.242765670034206, + "blocks.1.w1.bias": 17.953849785001406, + "blocks.1.w2.weight": 23.00878057790312, + "out_ln.weight": 0.3700761416548859, + "out_head.weight": 5.133209234946238, + "out_head.bias": 3.885876738755204 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 3 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed3", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed4/results_cifar10.json b/results/fa_dfa_d512_L2_seed4/results_cifar10.json new file mode 100644 index 0000000..9e35093 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed4/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "4": { + "dfa": { + "log": { + "train_loss": [ + 2.0538645137786866, + 2.0414622520446777, + 2.0403961332702636, + 2.0313826718139647, + 2.028870590591431, + 2.0276684378814696, + 2.025082506790161, + 2.027707338409424, + 2.024693886566162, + 2.0219140999603273, + 2.024347680053711, + 2.0226476731109617, + 2.0212312582397463, + 2.019639465713501, + 2.0200840225982666, + 2.022379661102295, + 2.0195740895843506, + 2.018866870689392, + 2.0154102825164797, + 2.019121079330444, + 2.0158904475402832, + 2.0182619565582276, + 2.0165906146240236, + 2.0169575183868407, + 2.0150413411712647, + 2.0142192585754395, + 2.0149837336730956, + 2.014163435897827, + 2.0160239248275755, + 2.012873958053589, + 2.014381752700806, + 2.0157296450805666, + 2.0139720638656615, + 2.0141250953674317, + 2.0129578493881226, + 2.0142643730926513, + 2.011523956451416, + 2.0130753877258303, + 2.0103190279388428, + 2.0131986405181883, + 2.011784356765747, + 2.0098198655700683, + 2.0134172485351565, + 2.011796088027954, + 2.009907284927368, + 2.0113371464538576, + 2.011871915245056, + 2.012780702667236, + 2.0106706674194337, + 2.0101452950286864, + 2.0108949478149416, + 2.0081811878967284, + 2.0113219532775877, + 2.0078719809722902, + 2.008186681213379, + 2.0087371883392335, + 2.010032041015625, + 2.0062248846054076, + 2.008138601531982, + 2.008908841209412, + 2.006502756500244, + 2.0063726428604127, + 2.006322613143921, + 2.0073862936401365, + 2.0092133827209473, + 2.0075592138671876, + 2.0070631226348876, + 2.0061116324615478, + 2.0080163690567017, + 2.0098242531585693, + 2.004972350997925, + 2.0045611180877687, + 2.0060609978485107, + 2.0061137674713136, + 2.0058565605163574, + 2.0072338876342775, + 2.0047287912750242, + 2.0041282120132444, + 2.007643541030884, + 2.0067062200546264, + 2.005947174911499, + 2.0044366609191893, + 2.0044241131591796, + 2.0034669429016114, + 2.0061908039855956, + 2.004900071258545, + 2.002522025909424, + 2.0041396823120117, + 2.0040297746276856, + 2.0036116275787355, + 2.003599824256897, + 2.003862293243408, + 2.002604514427185, + 2.003420566329956, + 2.0044958754730224, + 2.0024802682876586, + 2.001940184288025, + 2.004589928436279, + 2.00544753616333, + 2.004831633987427 + ], + "train_acc": [ + 0.24328, + 0.2486, + 0.2526, + 0.25192, + 0.25386, + 0.25942, + 0.2579, + 0.25428, + 0.2558, + 0.2587, + 0.25714, + 0.25432, + 0.2597, + 0.26176, + 0.25932, + 0.25756, + 0.26362, + 0.26154, + 0.26296, + 0.26458, + 0.2631, + 0.26162, + 0.2629, + 0.26272, + 0.26022, + 0.26332, + 0.26328, + 0.26642, + 0.26188, + 0.26562, + 0.26456, + 0.26682, + 0.2653, + 0.2635, + 0.26696, + 0.26618, + 0.26604, + 0.26422, + 0.26612, + 0.26754, + 0.26958, + 0.2674, + 0.26656, + 0.2681, + 0.2651, + 0.2682, + 0.266, + 0.26656, + 0.26712, + 0.26854, + 0.26726, + 0.26844, + 0.26866, + 0.26848, + 0.2684, + 0.26968, + 0.2694, + 0.27024, + 0.26914, + 0.2689, + 0.27068, + 0.26872, + 0.2698, + 0.2698, + 0.27038, + 0.27008, + 0.26956, + 0.26956, + 0.26738, + 0.2676, + 0.27036, + 0.27014, + 0.26994, + 0.27158, + 0.27192, + 0.27026, + 0.27058, + 0.27212, + 0.26966, + 0.26896, + 0.27094, + 0.27132, + 0.27096, + 0.27272, + 0.2715, + 0.2717, + 0.27104, + 0.27176, + 0.26956, + 0.26908, + 0.27146, + 0.2707, + 0.27088, + 0.26976, + 0.2708, + 0.27234, + 0.27182, + 0.27142, + 0.27088, + 0.27292 + ], + "test_acc": [ + 0.2691, + 0.2657, + 0.2489, + 0.2683, + 0.2735, + 0.2856, + 0.2869, + 0.2671, + 0.2775, + 0.2707, + 0.2668, + 0.2677, + 0.2711, + 0.284, + 0.28, + 0.2978, + 0.2649, + 0.2558, + 0.2813, + 0.2732, + 0.2875, + 0.2844, + 0.2664, + 0.2731, + 0.2948, + 0.2757, + 0.2818, + 0.2811, + 0.2842, + 0.2762, + 0.2852, + 0.2639, + 0.2834, + 0.2859, + 0.2804, + 0.2727, + 0.2794, + 0.2916, + 0.2746, + 0.2768, + 0.2903, + 0.2722, + 0.2896, + 0.2856, + 0.2906, + 0.2848, + 0.2776, + 0.29, + 0.2918, + 0.2703, + 0.2847, + 0.2838, + 0.2816, + 0.2894, + 0.2815, + 0.2783, + 0.2917, + 0.2712, + 0.285, + 0.2861, + 0.2844, + 0.2898, + 0.2839, + 0.2886, + 0.2809, + 0.2826, + 0.2792, + 0.2864, + 0.2978, + 0.2876, + 0.2855, + 0.2997, + 0.2912, + 0.2887, + 0.2867, + 0.2842, + 0.284, + 0.2781, + 0.2817, + 0.2911, + 0.2842, + 0.2834, + 0.2867, + 0.2861, + 0.2867, + 0.284, + 0.2887, + 0.2864, + 0.2875, + 0.2899, + 0.2864, + 0.287, + 0.2901, + 0.2856, + 0.2866, + 0.2863, + 0.2859, + 0.2861, + 0.2861, + 0.2861 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.3827235698699951, + -0.000831119017675519 + ], + "perturbation_rho": [ + -0.01562047004699707, + 0.0 + ], + "nudging": { + "0.001": [ + -2.812594175338745e-07, + 0.0 + ], + "0.003": [ + -9.383074939250946e-07, + 0.0 + ], + "0.01": [ + -3.2153911888599396e-06, + 0.0 + ] + }, + "hidden_norms_per_layer": [ + 55950.5390625, + 1669607168.0, + 3780465152.0 + ], + "bp_grad_norms_per_layer": [ + 2.2658105081063695e-07, + 2.3708951468748296e-10, + 2.3738841448128767e-10 + ] + }, + "drift": { + "embed.weight": 337.5136698692365, + "embed.bias": 273.5401823379648, + "blocks.0.ln.weight": 9.322317945209601, + "blocks.0.w1.weight": 322.0064838288751, + "blocks.0.w1.bias": 292.5825702677452, + "blocks.0.w2.weight": 505.89197780725414, + "blocks.1.ln.weight": 9.09606961243359, + "blocks.1.w1.weight": 353.39130820313096, + "blocks.1.w1.bias": 346.5478039461049, + "blocks.1.w2.weight": 349.3876744406949, + "out_ln.weight": 0.435883262668521, + "out_head.weight": 7.086345015239166, + "out_head.bias": 3.9170311167015974 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.083539034461975, + 1.9695448753356934, + 1.9397349642562867, + 1.9066173551177978, + 1.886961284866333, + 1.8710758205795288, + 1.8628956564331054, + 1.8567623680877685, + 1.8501425790405273, + 1.850063798828125, + 1.8536383200073243, + 1.851845763206482, + 1.8578751723480225, + 1.8505859604263306, + 1.8504567489242554, + 1.8505229975128175, + 1.8486209188079834, + 1.8476799687957763, + 1.8461452635192872, + 1.8506561833953858, + 1.8442412594985962, + 1.8458462058258056, + 1.842808702659607, + 1.8387870932769776, + 1.8379171523666382, + 1.828637089920044, + 1.824536803894043, + 1.822955154800415, + 1.8200513568878174, + 1.814583374671936, + 1.8130536968231201, + 1.8116273685073851, + 1.810110004234314, + 1.808945005569458, + 1.8023032517242432, + 1.8105073015975952, + 1.81240649143219, + 1.8092195032501222, + 1.8054952936172486, + 1.8071781116485595, + 1.8033202910614015, + 1.8002542707061768, + 1.7996792233276366, + 1.794723889541626, + 1.788533415184021, + 1.7873220053100587, + 1.7867383001327515, + 1.7845936785888672, + 1.7811818826675414, + 1.7822032886505126, + 1.7786470770263672, + 1.7798991229629517, + 1.7747682898712158, + 1.771372130050659, + 1.7732186395645142, + 1.767907095336914, + 1.7648547525405884, + 1.7638111727523804, + 1.7647753219223024, + 1.7633140285491944, + 1.758088468322754, + 1.7619236371612548, + 1.7579612891006469, + 1.7559792065811157, + 1.7589319384002686, + 1.7530268384170533, + 1.7540921353530883, + 1.7525135382080077, + 1.7494160869979858, + 1.7476271326446533, + 1.7485824377059938, + 1.7481532715225219, + 1.7493890603256226, + 1.753240061607361, + 1.7450246924209594, + 1.7479931283950805, + 1.7435640298461914, + 1.7442714694595336, + 1.7464028844833375, + 1.7440459408950806, + 1.7451819785308837, + 1.7424172634124755, + 1.7414722381591796, + 1.7440056212997437, + 1.7441547555160521, + 1.741508984413147, + 1.7379715420913697, + 1.743159959335327, + 1.737216314048767, + 1.73957718044281, + 1.7399621460342407, + 1.7403172652435304, + 1.7397853451919556, + 1.7367944290161133, + 1.7394116397476196, + 1.7405242685317994, + 1.7411947003173829, + 1.7401781003189087, + 1.7426191668319702, + 1.7414960692977905 + ], + "train_acc": [ + 0.24206, + 0.28608, + 0.30284, + 0.31378, + 0.32214, + 0.33222, + 0.33458, + 0.33384, + 0.339, + 0.3365, + 0.33642, + 0.33268, + 0.3321, + 0.3336, + 0.33528, + 0.33412, + 0.3342, + 0.33516, + 0.3349, + 0.3369, + 0.33338, + 0.33768, + 0.3359, + 0.3389, + 0.33848, + 0.33996, + 0.34268, + 0.34516, + 0.34514, + 0.34766, + 0.3483, + 0.35214, + 0.34858, + 0.34992, + 0.35344, + 0.35134, + 0.34842, + 0.35144, + 0.35262, + 0.35272, + 0.35618, + 0.35276, + 0.35816, + 0.3578, + 0.35874, + 0.3609, + 0.36208, + 0.3595, + 0.3629, + 0.36296, + 0.36438, + 0.36234, + 0.3654, + 0.36716, + 0.36458, + 0.36772, + 0.36752, + 0.36904, + 0.3677, + 0.3659, + 0.3724, + 0.36748, + 0.37008, + 0.37166, + 0.37228, + 0.37126, + 0.37246, + 0.3711, + 0.37376, + 0.3725, + 0.37384, + 0.37248, + 0.37364, + 0.37298, + 0.37408, + 0.37304, + 0.37604, + 0.37672, + 0.3764, + 0.37472, + 0.37532, + 0.3761, + 0.37432, + 0.37696, + 0.3772, + 0.37738, + 0.37684, + 0.37452, + 0.37578, + 0.37302, + 0.37482, + 0.3745, + 0.37608, + 0.37402, + 0.37522, + 0.37624, + 0.377, + 0.37676, + 0.37804, + 0.3765 + ], + "test_acc": [ + 0.2873, + 0.2968, + 0.3102, + 0.3278, + 0.3267, + 0.3394, + 0.3594, + 0.3508, + 0.348, + 0.3577, + 0.3237, + 0.3288, + 0.3256, + 0.3468, + 0.3314, + 0.3353, + 0.3342, + 0.3343, + 0.3363, + 0.3386, + 0.3359, + 0.3401, + 0.3353, + 0.3296, + 0.3361, + 0.3547, + 0.3367, + 0.342, + 0.3335, + 0.3402, + 0.3509, + 0.3259, + 0.3411, + 0.342, + 0.3459, + 0.3338, + 0.3447, + 0.3462, + 0.3413, + 0.3409, + 0.3561, + 0.3417, + 0.3429, + 0.3539, + 0.3554, + 0.3409, + 0.3557, + 0.3535, + 0.3559, + 0.3474, + 0.3477, + 0.3647, + 0.3483, + 0.337, + 0.3458, + 0.3461, + 0.3488, + 0.3399, + 0.351, + 0.3393, + 0.3521, + 0.3505, + 0.354, + 0.3489, + 0.3421, + 0.3435, + 0.3427, + 0.345, + 0.3509, + 0.3393, + 0.3513, + 0.3592, + 0.3435, + 0.3491, + 0.3462, + 0.3486, + 0.3435, + 0.3421, + 0.3463, + 0.3537, + 0.3507, + 0.3476, + 0.3523, + 0.346, + 0.3548, + 0.3469, + 0.3489, + 0.3457, + 0.3483, + 0.3495, + 0.3507, + 0.3498, + 0.3509, + 0.3498, + 0.3485, + 0.3503, + 0.3507, + 0.3502, + 0.3501, + 0.3501 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.03285929560661316, + 0.9476211071014404 + ], + "perturbation_rho": [ + 0.010396776720881462, + 0.21192286908626556 + ], + "nudging": { + "0.001": [ + -4.190136678516865e-06, + -7.935799658298492e-06 + ], + "0.003": [ + -1.2663658708333969e-05, + -2.3964676074683666e-05 + ], + "0.01": [ + -4.219170659780502e-05, + -8.006719872355461e-05 + ] + }, + "hidden_norms_per_layer": [ + 3884.70703125, + 284714.21875, + 219214.828125 + ], + "bp_grad_norms_per_layer": [ + 1.9074828742304817e-05, + 6.634672331529146e-07, + 6.522877811221406e-07 + ] + }, + "drift": { + "embed.weight": 27.016438956012585, + "embed.bias": 16.969540152339597, + "blocks.0.ln.weight": 1.5836171150936886, + "blocks.0.w1.weight": 25.706341628627694, + "blocks.0.w1.bias": 21.000244171066868, + "blocks.0.w2.weight": 64.31974159860198, + "blocks.1.ln.weight": 0.9270002241369442, + "blocks.1.w1.weight": 14.313388328940205, + "blocks.1.w1.bias": 6.9567229997624445, + "blocks.1.w2.weight": 31.78188983358227, + "out_ln.weight": 0.46949971605667673, + "out_head.weight": 3.9796491257565023, + "out_head.bias": 12.763196315235756 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 4 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed4", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed5/results_cifar10.json b/results/fa_dfa_d512_L2_seed5/results_cifar10.json new file mode 100644 index 0000000..d5c8ea9 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed5/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "5": { + "dfa": { + "log": { + "train_loss": [ + 2.0501297537231444, + 2.0391713497924804, + 2.0405313035583497, + 2.0335817265319824, + 2.0266058160018923, + 2.0228387924957274, + 2.0232558026504517, + 2.0218732793426515, + 2.0191759132385254, + 2.0202307054901123, + 2.018844593505859, + 2.017033308792114, + 2.018027767868042, + 2.018668871307373, + 2.015074652633667, + 2.015589296875, + 2.016558995742798, + 2.019203759498596, + 2.0183322762298586, + 2.0155471825408937, + 2.018300429916382, + 2.0150776953125, + 2.0150013764953614, + 2.012721175842285, + 2.016366043663025, + 2.012075606918335, + 2.016505323944092, + 2.01135826461792, + 2.0136885620117186, + 2.013543639678955, + 2.013746788291931, + 2.0128848664093018, + 2.012771292114258, + 2.01191774017334, + 2.010462758102417, + 2.010521863632202, + 2.009425925369263, + 2.0124644644546508, + 2.0108147170639037, + 2.010228076095581, + 2.0088936639022825, + 2.0095657570648195, + 2.007657756729126, + 2.009475396652222, + 2.010636301422119, + 2.0098936741638185, + 2.0092102942276, + 2.009080667991638, + 2.007752434120178, + 2.0074822512054444, + 2.011261424026489, + 2.006237863006592, + 2.008023773574829, + 2.0073289957427978, + 2.006207457962036, + 2.0055681369781495, + 2.006566423034668, + 2.0086398764038087, + 2.005377918167114, + 2.008910454864502, + 2.0054182757568357, + 2.0053556463241575, + 2.005029055404663, + 2.0059510985565185, + 2.005951479187012, + 2.0057564875793457, + 2.0048374938964844, + 2.0036859961700437, + 2.005258397903442, + 2.005781553649902, + 2.004468630065918, + 2.0044502724456787, + 2.0023633781433103, + 2.002619049530029, + 2.003770675506592, + 2.0034406785583494, + 2.0038263823699953, + 2.004770489578247, + 2.0045603427505494, + 2.003471921005249, + 2.0053932760620117, + 2.002159836387634, + 2.0036436273574827, + 2.002813849029541, + 2.00232444770813, + 2.0031388371276857, + 2.003156742095947, + 2.0042506001281737, + 2.001820291137695, + 2.0023800246810914, + 2.0012053105926513, + 2.0006055866241454, + 2.0019854052352906, + 2.001209078979492, + 2.0008096754455567, + 2.0021328547668458, + 2.0031403130340575, + 1.9998649573516847, + 2.002096244430542, + 2.0019282162475585 + ], + "train_acc": [ + 0.25066, + 0.25648, + 0.25718, + 0.25752, + 0.26214, + 0.26184, + 0.26432, + 0.2653, + 0.26402, + 0.26416, + 0.26932, + 0.26444, + 0.26816, + 0.26564, + 0.2684, + 0.27134, + 0.26932, + 0.267, + 0.26686, + 0.26922, + 0.26814, + 0.26974, + 0.2686, + 0.26834, + 0.2661, + 0.26928, + 0.26724, + 0.26998, + 0.26744, + 0.26994, + 0.26896, + 0.2696, + 0.26804, + 0.2696, + 0.27204, + 0.26932, + 0.27152, + 0.26992, + 0.27178, + 0.27218, + 0.27082, + 0.27092, + 0.2737, + 0.27122, + 0.27138, + 0.27174, + 0.272, + 0.2722, + 0.2723, + 0.27346, + 0.26892, + 0.27434, + 0.27432, + 0.27238, + 0.27332, + 0.27454, + 0.27508, + 0.27256, + 0.27376, + 0.27296, + 0.27436, + 0.27456, + 0.27532, + 0.27364, + 0.2735, + 0.27452, + 0.27372, + 0.27588, + 0.27558, + 0.27532, + 0.27558, + 0.2755, + 0.27522, + 0.27524, + 0.2758, + 0.27718, + 0.27532, + 0.27618, + 0.27686, + 0.27688, + 0.27632, + 0.27328, + 0.27464, + 0.27712, + 0.27804, + 0.27666, + 0.2755, + 0.27784, + 0.27688, + 0.27586, + 0.27696, + 0.27672, + 0.27692, + 0.2786, + 0.27814, + 0.27484, + 0.27636, + 0.27772, + 0.27436, + 0.27908 + ], + "test_acc": [ + 0.2723, + 0.2812, + 0.2672, + 0.2819, + 0.2601, + 0.279, + 0.2898, + 0.289, + 0.2685, + 0.2896, + 0.2766, + 0.2838, + 0.2792, + 0.2863, + 0.3016, + 0.292, + 0.2951, + 0.2919, + 0.2929, + 0.2842, + 0.2837, + 0.2864, + 0.3007, + 0.2981, + 0.2983, + 0.2929, + 0.2962, + 0.3, + 0.2923, + 0.2989, + 0.2802, + 0.2939, + 0.269, + 0.2905, + 0.278, + 0.2981, + 0.3057, + 0.2984, + 0.3025, + 0.2843, + 0.3007, + 0.2824, + 0.2941, + 0.3064, + 0.284, + 0.2892, + 0.2981, + 0.2894, + 0.2983, + 0.2972, + 0.2969, + 0.2958, + 0.2892, + 0.299, + 0.2959, + 0.2899, + 0.2863, + 0.3044, + 0.2926, + 0.2916, + 0.2983, + 0.2914, + 0.3015, + 0.2956, + 0.2904, + 0.2972, + 0.2893, + 0.289, + 0.2961, + 0.2993, + 0.2911, + 0.289, + 0.2944, + 0.2942, + 0.2955, + 0.2939, + 0.2961, + 0.295, + 0.2949, + 0.2966, + 0.2977, + 0.2997, + 0.2982, + 0.2984, + 0.2965, + 0.2941, + 0.2911, + 0.2957, + 0.2962, + 0.2998, + 0.2916, + 0.2965, + 0.2952, + 0.2956, + 0.2959, + 0.2965, + 0.2957, + 0.296, + 0.2963, + 0.2963 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.38937193155288696, + -0.00133989576715976 + ], + "perturbation_rho": [ + -0.03419807553291321, + 0.0 + ], + "nudging": { + "0.001": [ + -3.7206336855888367e-07, + 0.0 + ], + "0.003": [ + -1.1147931218147278e-06, + 0.0 + ], + "0.01": [ + -3.5390257835388184e-06, + 3.725290298461914e-09 + ] + }, + "hidden_norms_per_layer": [ + 55290.671875, + 1164404992.0, + 1809793664.0 + ], + "bp_grad_norms_per_layer": [ + 2.527334572732798e-07, + 4.3922129822071554e-10, + 4.3930301063532795e-10 + ] + }, + "drift": { + "embed.weight": 334.45692579883826, + "embed.bias": 254.1500844207983, + "blocks.0.ln.weight": 10.378101238923492, + "blocks.0.w1.weight": 277.98485791253944, + "blocks.0.w1.bias": 241.48477586542305, + "blocks.0.w2.weight": 467.05002700094724, + "blocks.1.ln.weight": 6.909499166461763, + "blocks.1.w1.weight": 263.9133275584596, + "blocks.1.w1.bias": 244.87238078164592, + "blocks.1.w2.weight": 283.49051001246613, + "out_ln.weight": 0.43727802835016244, + "out_head.weight": 6.508111271175387, + "out_head.bias": 3.1430989188106846 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.057919206390381, + 1.9476089878845215, + 1.9109919176864625, + 1.8918181984710694, + 1.8734441751861572, + 1.859958696899414, + 1.8465128494644165, + 1.8378619164657592, + 1.8330042699813842, + 1.824237583885193, + 1.82030082321167, + 1.8168681536102296, + 1.8147564684677124, + 1.8174263799285888, + 1.8145158917236328, + 1.8134905625152589, + 1.816219719581604, + 1.8107114935684203, + 1.809884515991211, + 1.8064439716339111, + 1.8076383264160156, + 1.8070710248184203, + 1.799020498085022, + 1.799975090560913, + 1.7974333989715576, + 1.7963496450042724, + 1.8014785760498047, + 1.7986563251113892, + 1.7955734420776368, + 1.7972684210968017, + 1.7941396160125733, + 1.7950014336395264, + 1.7955820244979859, + 1.7959071813201903, + 1.79216160282135, + 1.7948664249038697, + 1.7919580102920531, + 1.7955779486465455, + 1.7922123107147216, + 1.7889366255950927, + 1.7819780150604247, + 1.7830965274429322, + 1.782749061050415, + 1.7800398288726806, + 1.780551498336792, + 1.7743108095932008, + 1.783085387916565, + 1.774484405479431, + 1.7695070098114014, + 1.7732080498504639, + 1.7756617493057252, + 1.7731031243133546, + 1.7697893057632446, + 1.7669785211181641, + 1.7650962362289428, + 1.763043028793335, + 1.7600419904708862, + 1.7632893228912354, + 1.7593432830429077, + 1.7613518405532838, + 1.754806079711914, + 1.7532699264144898, + 1.7546887551498414, + 1.7581820251083373, + 1.751631008605957, + 1.754752312927246, + 1.7504624893951417, + 1.7475749392318725, + 1.751591441116333, + 1.7496289820098876, + 1.748469584388733, + 1.747327193069458, + 1.7421584380722046, + 1.7428951406097413, + 1.7426397796249389, + 1.7432606916046143, + 1.7435048538589477, + 1.7410756462860106, + 1.743000368347168, + 1.7420079108047486, + 1.7423530182266236, + 1.7415100763320923, + 1.7408926866912842, + 1.7403043130874634, + 1.7404012395477295, + 1.7404220779418946, + 1.7388356212615967, + 1.7423903681182862, + 1.7363825454711914, + 1.7398050061035155, + 1.7384051816558839, + 1.7355061944198609, + 1.7398744315338135, + 1.7368783473968505, + 1.7394122046661378, + 1.7361857733154298, + 1.742128964920044, + 1.7403847741699219, + 1.737286774673462, + 1.7390742868041993 + ], + "train_acc": [ + 0.25062, + 0.29808, + 0.31302, + 0.32052, + 0.3262, + 0.33122, + 0.33844, + 0.3404, + 0.34186, + 0.34348, + 0.34654, + 0.34656, + 0.34834, + 0.34672, + 0.3507, + 0.34978, + 0.34784, + 0.35204, + 0.34968, + 0.35226, + 0.35198, + 0.35068, + 0.35748, + 0.35378, + 0.35458, + 0.35536, + 0.35374, + 0.3558, + 0.35606, + 0.35744, + 0.35864, + 0.3575, + 0.35448, + 0.35674, + 0.35792, + 0.3562, + 0.3585, + 0.35812, + 0.35828, + 0.3577, + 0.36192, + 0.36198, + 0.3611, + 0.36338, + 0.36406, + 0.3631, + 0.36, + 0.36512, + 0.36844, + 0.36508, + 0.3646, + 0.36468, + 0.36598, + 0.36518, + 0.3704, + 0.37062, + 0.3711, + 0.37078, + 0.36744, + 0.37034, + 0.37248, + 0.37374, + 0.3702, + 0.37384, + 0.37386, + 0.37256, + 0.3758, + 0.37518, + 0.3727, + 0.37388, + 0.37536, + 0.37548, + 0.37722, + 0.37788, + 0.37788, + 0.37554, + 0.37874, + 0.37764, + 0.37696, + 0.37702, + 0.37834, + 0.37902, + 0.37876, + 0.37774, + 0.37714, + 0.37816, + 0.38042, + 0.37544, + 0.38068, + 0.37858, + 0.38066, + 0.37826, + 0.38016, + 0.3802, + 0.37906, + 0.37884, + 0.37926, + 0.37942, + 0.37788, + 0.37972 + ], + "test_acc": [ + 0.2861, + 0.3325, + 0.3291, + 0.348, + 0.3452, + 0.3574, + 0.3412, + 0.3599, + 0.3485, + 0.3683, + 0.3663, + 0.3682, + 0.3616, + 0.367, + 0.3701, + 0.3702, + 0.3592, + 0.3731, + 0.3691, + 0.3588, + 0.3585, + 0.3754, + 0.3752, + 0.3647, + 0.3568, + 0.3678, + 0.3622, + 0.3687, + 0.3662, + 0.3697, + 0.3614, + 0.3669, + 0.3387, + 0.3594, + 0.3452, + 0.3532, + 0.3488, + 0.3581, + 0.3475, + 0.3599, + 0.3428, + 0.347, + 0.3621, + 0.3565, + 0.3451, + 0.3385, + 0.3365, + 0.3324, + 0.3551, + 0.3452, + 0.353, + 0.3504, + 0.3508, + 0.3448, + 0.3299, + 0.3413, + 0.3342, + 0.3506, + 0.3532, + 0.3428, + 0.3468, + 0.3395, + 0.3551, + 0.3435, + 0.334, + 0.3427, + 0.3352, + 0.3403, + 0.3419, + 0.3374, + 0.3358, + 0.3438, + 0.3401, + 0.3381, + 0.3437, + 0.3412, + 0.3362, + 0.3404, + 0.339, + 0.3406, + 0.3378, + 0.3481, + 0.3449, + 0.3412, + 0.334, + 0.3405, + 0.3431, + 0.3434, + 0.3401, + 0.3446, + 0.3386, + 0.3425, + 0.3413, + 0.3399, + 0.3425, + 0.3407, + 0.3402, + 0.3413, + 0.341, + 0.341 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.01863543502986431, + 0.966240406036377 + ], + "perturbation_rho": [ + 0.049932606518268585, + 0.11253532767295837 + ], + "nudging": { + "0.001": [ + -3.632623702287674e-06, + -7.5231073424220085e-06 + ], + "0.003": [ + -1.0822783224284649e-05, + -2.275872975587845e-05 + ], + "0.01": [ + -3.596034366637468e-05, + -7.59701943024993e-05 + ] + }, + "hidden_norms_per_layer": [ + 3561.231201171875, + 299735.8125, + 271834.53125 + ], + "bp_grad_norms_per_layer": [ + 2.061038867395837e-05, + 5.578897344094003e-07, + 5.395349376158265e-07 + ] + }, + "drift": { + "embed.weight": 25.27185200703667, + "embed.bias": 13.597802808023353, + "blocks.0.ln.weight": 1.7243371838975772, + "blocks.0.w1.weight": 26.803582221943408, + "blocks.0.w1.bias": 17.489010687530495, + "blocks.0.w2.weight": 62.11992968105335, + "blocks.1.ln.weight": 1.1537855920658724, + "blocks.1.w1.weight": 16.75993270124035, + "blocks.1.w1.bias": 9.810533469107403, + "blocks.1.w2.weight": 38.80660527636343, + "out_ln.weight": 0.460453695659119, + "out_head.weight": 4.155803649723169, + "out_head.bias": 13.238295370364474 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 5 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed5", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed6/results_cifar10.json b/results/fa_dfa_d512_L2_seed6/results_cifar10.json new file mode 100644 index 0000000..a0da849 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed6/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "6": { + "dfa": { + "log": { + "train_loss": [ + 2.055892396621704, + 2.0292765501403807, + 2.023764776611328, + 2.020959409866333, + 2.019423359222412, + 2.0197315069580077, + 2.023359367675781, + 2.017872388153076, + 2.01921712928772, + 2.0194658319854737, + 2.021274960708618, + 2.016373544845581, + 2.0175488451766967, + 2.0179581134796143, + 2.0176988045501707, + 2.0177309605407716, + 2.0153829962921144, + 2.013690803070068, + 2.016160803833008, + 2.012104736251831, + 2.0164508349990844, + 2.0172074625396728, + 2.011329041824341, + 2.014233197860718, + 2.0119160061264036, + 2.0112428046417237, + 2.0135288095092774, + 2.0137764200592043, + 2.0114110033416748, + 2.011166734046936, + 2.01347174949646, + 2.0128866384887694, + 2.012690868988037, + 2.0142732004547117, + 2.013253126296997, + 2.012407520980835, + 2.0118109189605713, + 2.012256015167236, + 2.0096606398773194, + 2.010287212905884, + 2.0091940085601805, + 2.0101572361755373, + 2.0091212350845336, + 2.010151270904541, + 2.0109262997817994, + 2.0077846411514284, + 2.0088822582626342, + 2.0098557024383545, + 2.0110063369750977, + 2.0083373790740966, + 2.0075834423828125, + 2.0072854569244383, + 2.0076236277771, + 2.009549726829529, + 2.00914709564209, + 2.0072284757232666, + 2.0073054156112673, + 2.008326266860962, + 2.0065047761535646, + 2.0040750775527956, + 2.008966516571045, + 2.006659435195923, + 2.006130950393677, + 2.0071768866729736, + 2.004858699607849, + 2.005933465194702, + 2.0051698098754884, + 2.0040941329193114, + 2.004839565887451, + 2.004348826370239, + 2.0043614904785154, + 2.005439692611694, + 2.0047981397247314, + 2.0036350199127195, + 2.0039494177627564, + 2.002428635940552, + 2.002635501022339, + 2.005249846954346, + 2.0045922426605225, + 1.9999085303497315, + 2.0022091399383544, + 2.001223135147095, + 2.005691502914429, + 2.0016508586883544, + 2.0019130290222167, + 2.0011878201293944, + 2.002061458091736, + 2.001959662628174, + 2.0005823556518556, + 2.003897890167236, + 2.003386359901428, + 2.0019060565185547, + 2.0028662380981443, + 2.000936403427124, + 2.003748550567627, + 2.00183758392334, + 2.001099390411377, + 2.002016563873291, + 2.001500345993042, + 2.000833380355835 + ], + "train_acc": [ + 0.2468, + 0.25656, + 0.2583, + 0.25842, + 0.26294, + 0.26512, + 0.25982, + 0.26482, + 0.26072, + 0.26388, + 0.25908, + 0.26332, + 0.26268, + 0.26474, + 0.26364, + 0.26604, + 0.2647, + 0.26542, + 0.26462, + 0.26776, + 0.26296, + 0.26282, + 0.26768, + 0.2662, + 0.26584, + 0.26868, + 0.26712, + 0.2637, + 0.26642, + 0.2643, + 0.26304, + 0.2664, + 0.26832, + 0.2695, + 0.26582, + 0.26628, + 0.2661, + 0.26608, + 0.26796, + 0.26702, + 0.26834, + 0.26798, + 0.27016, + 0.26738, + 0.2666, + 0.26772, + 0.26924, + 0.2708, + 0.26744, + 0.26734, + 0.27232, + 0.26954, + 0.27186, + 0.2691, + 0.26756, + 0.27156, + 0.26944, + 0.26982, + 0.27, + 0.2734, + 0.27014, + 0.27428, + 0.2721, + 0.27192, + 0.27394, + 0.27134, + 0.2706, + 0.27214, + 0.2724, + 0.27452, + 0.27194, + 0.27366, + 0.273, + 0.27294, + 0.2731, + 0.27446, + 0.27392, + 0.2735, + 0.274, + 0.2742, + 0.2727, + 0.27306, + 0.27192, + 0.27572, + 0.27528, + 0.27414, + 0.2743, + 0.27318, + 0.27448, + 0.27334, + 0.27238, + 0.27438, + 0.27446, + 0.27388, + 0.27336, + 0.27486, + 0.27176, + 0.27446, + 0.2743, + 0.27204 + ], + "test_acc": [ + 0.2713, + 0.2844, + 0.2811, + 0.28, + 0.299, + 0.2643, + 0.2824, + 0.2917, + 0.2912, + 0.2738, + 0.2633, + 0.2873, + 0.285, + 0.2787, + 0.2918, + 0.2776, + 0.2763, + 0.278, + 0.2677, + 0.2808, + 0.2944, + 0.2953, + 0.2923, + 0.2954, + 0.297, + 0.2825, + 0.2932, + 0.2826, + 0.2908, + 0.285, + 0.2799, + 0.2848, + 0.3029, + 0.2813, + 0.2915, + 0.2709, + 0.3065, + 0.2788, + 0.3023, + 0.2961, + 0.2864, + 0.2937, + 0.297, + 0.3001, + 0.2957, + 0.282, + 0.2855, + 0.2917, + 0.2897, + 0.2983, + 0.2881, + 0.2843, + 0.2775, + 0.2932, + 0.2953, + 0.2891, + 0.3012, + 0.2843, + 0.2914, + 0.2863, + 0.2978, + 0.2948, + 0.2863, + 0.2944, + 0.2897, + 0.2924, + 0.2925, + 0.2939, + 0.2888, + 0.2956, + 0.2964, + 0.2878, + 0.2855, + 0.2904, + 0.2999, + 0.2861, + 0.3025, + 0.2935, + 0.2953, + 0.2878, + 0.2974, + 0.2959, + 0.2944, + 0.2924, + 0.2937, + 0.2953, + 0.2911, + 0.2938, + 0.2924, + 0.2938, + 0.2953, + 0.2933, + 0.2955, + 0.2954, + 0.2939, + 0.2937, + 0.2938, + 0.2936, + 0.2933, + 0.2933 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.3805140256881714, + -0.00048614124534651637 + ], + "perturbation_rho": [ + 0.010780533775687218, + 0.0 + ], + "nudging": { + "0.001": [ + -3.3294782042503357e-07, + 0.0 + ], + "0.003": [ + -1.0631047189235687e-06, + 0.0 + ], + "0.01": [ + -3.623776137828827e-06, + 9.313225746154785e-10 + ] + }, + "hidden_norms_per_layer": [ + 54183.453125, + 1543298944.0, + 3004247040.0 + ], + "bp_grad_norms_per_layer": [ + 2.45587500558031e-07, + 3.692772754249063e-10, + 3.6931127600503544e-10 + ] + }, + "drift": { + "embed.weight": 334.1717985048247, + "embed.bias": 247.7536312263247, + "blocks.0.ln.weight": 9.54033246951961, + "blocks.0.w1.weight": 315.1784989948421, + "blocks.0.w1.bias": 269.25844995073356, + "blocks.0.w2.weight": 501.0628341707343, + "blocks.1.ln.weight": 8.010501646129491, + "blocks.1.w1.weight": 317.0441501524571, + "blocks.1.w1.bias": 304.79370018295725, + "blocks.1.w2.weight": 308.3494986335005, + "out_ln.weight": 0.4575791541009348, + "out_head.weight": 6.637383076518475, + "out_head.bias": 3.6615792171842636 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.065024294166565, + 1.9630456279754638, + 1.9233733417129517, + 1.9067592349243163, + 1.8861662637329102, + 1.8703109120941162, + 1.8611061660766601, + 1.8445787530899047, + 1.829258919067383, + 1.8199622161102296, + 1.8095393975448608, + 1.7991895624542236, + 1.79670791015625, + 1.7896576220321656, + 1.7787932138061524, + 1.777765901145935, + 1.768475433731079, + 1.7715507403945923, + 1.7699084145736694, + 1.7577085567474364, + 1.7492800643539428, + 1.7376883664703369, + 1.7303459451293945, + 1.7293071484375, + 1.7250730394744873, + 1.726899693222046, + 1.7227447848129271, + 1.7317372454452515, + 1.7304015692138672, + 1.7312020352172852, + 1.7351370712280274, + 1.7342145778656006, + 1.7334492336654663, + 1.731919374732971, + 1.7354538122940064, + 1.7322556600952148, + 1.7360534030914307, + 1.7412740663909911, + 1.7402722408676148, + 1.7424089098358155, + 1.7433867425918579, + 1.7484364804077148, + 1.7482806644058229, + 1.7441494891738891, + 1.7457399251937866, + 1.740269787902832, + 1.7429518072128296, + 1.7410654845428466, + 1.7409823212432862, + 1.7343881818008422, + 1.736907575340271, + 1.7340854137802124, + 1.734343325843811, + 1.7312401904296875, + 1.7296781457138062, + 1.7250156705474853, + 1.7303392902374268, + 1.7267954447174072, + 1.7268179355621338, + 1.7248962731552124, + 1.7265133374786377, + 1.725437378578186, + 1.7230116376495361, + 1.721951872406006, + 1.7212245053482056, + 1.7194156966781615, + 1.7177197817993164, + 1.7192226972198486, + 1.7195494704437255, + 1.7167850009536743, + 1.7179081740570068, + 1.7162539065551758, + 1.7143080168914795, + 1.7166276223373413, + 1.716199371986389, + 1.7125089855957032, + 1.7071118188858032, + 1.7140238445281983, + 1.711059020729065, + 1.7108189041137696, + 1.7138464722061157, + 1.712218090133667, + 1.7137628644561766, + 1.7120113299179076, + 1.7139518783950807, + 1.7121368927383422, + 1.7085959339523316, + 1.7080107479095459, + 1.7105122137069702, + 1.7120901413345337, + 1.7093008170318604, + 1.707778765335083, + 1.7079006521987916, + 1.7071793605804444, + 1.7064016004180909, + 1.7067470293426514, + 1.7087606598281861, + 1.7090342914581298, + 1.707086050338745, + 1.7054454077911376 + ], + "train_acc": [ + 0.24472, + 0.28828, + 0.3077, + 0.31704, + 0.32668, + 0.33122, + 0.3323, + 0.3406, + 0.34542, + 0.35004, + 0.35242, + 0.3556, + 0.35868, + 0.36042, + 0.36382, + 0.36434, + 0.3649, + 0.36774, + 0.36866, + 0.36968, + 0.37606, + 0.3764, + 0.3827, + 0.38044, + 0.38462, + 0.38274, + 0.38134, + 0.3805, + 0.38172, + 0.38162, + 0.37954, + 0.3812, + 0.38154, + 0.38246, + 0.38214, + 0.38156, + 0.38276, + 0.37756, + 0.37994, + 0.37818, + 0.37948, + 0.3745, + 0.3761, + 0.37776, + 0.37658, + 0.37964, + 0.3789, + 0.37938, + 0.37994, + 0.38366, + 0.38024, + 0.38026, + 0.38234, + 0.38264, + 0.38382, + 0.38478, + 0.3855, + 0.38412, + 0.38446, + 0.38598, + 0.38618, + 0.38722, + 0.38892, + 0.3878, + 0.38716, + 0.38748, + 0.38934, + 0.3868, + 0.391, + 0.39034, + 0.39048, + 0.39066, + 0.39448, + 0.3907, + 0.39146, + 0.39476, + 0.39424, + 0.39456, + 0.39384, + 0.3942, + 0.39406, + 0.39414, + 0.3935, + 0.39276, + 0.39352, + 0.3918, + 0.3955, + 0.39626, + 0.39398, + 0.39516, + 0.3936, + 0.39472, + 0.39608, + 0.39546, + 0.3953, + 0.3935, + 0.39494, + 0.39556, + 0.39782, + 0.39818 + ], + "test_acc": [ + 0.2947, + 0.3403, + 0.3462, + 0.3468, + 0.3593, + 0.3653, + 0.3723, + 0.3689, + 0.372, + 0.3699, + 0.3635, + 0.3648, + 0.3749, + 0.3716, + 0.3759, + 0.3936, + 0.3903, + 0.3867, + 0.3853, + 0.3955, + 0.3909, + 0.4035, + 0.3969, + 0.3965, + 0.3878, + 0.3929, + 0.3895, + 0.3756, + 0.3801, + 0.377, + 0.3725, + 0.3941, + 0.3851, + 0.3949, + 0.3941, + 0.3753, + 0.3733, + 0.3731, + 0.3846, + 0.3734, + 0.3695, + 0.3857, + 0.3871, + 0.3798, + 0.3841, + 0.3771, + 0.3762, + 0.3814, + 0.3893, + 0.3712, + 0.3673, + 0.3771, + 0.3882, + 0.3816, + 0.3894, + 0.3825, + 0.3915, + 0.369, + 0.3847, + 0.3726, + 0.3876, + 0.3789, + 0.3679, + 0.3797, + 0.3785, + 0.3897, + 0.3758, + 0.3883, + 0.3789, + 0.3812, + 0.375, + 0.3779, + 0.3876, + 0.3884, + 0.3853, + 0.3865, + 0.3874, + 0.3859, + 0.3907, + 0.3826, + 0.385, + 0.3929, + 0.3856, + 0.3809, + 0.385, + 0.3849, + 0.3824, + 0.3854, + 0.3838, + 0.3857, + 0.3835, + 0.3842, + 0.385, + 0.3866, + 0.3864, + 0.3868, + 0.3865, + 0.3866, + 0.3865, + 0.3865 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.017966898158192635, + 0.9611225724220276 + ], + "perturbation_rho": [ + -0.022423196583986282, + 0.05839487165212631 + ], + "nudging": { + "0.001": [ + -1.0927324183285236e-06, + -6.873218808323145e-06 + ], + "0.003": [ + -3.30101465806365e-06, + -2.0737992599606514e-05 + ], + "0.01": [ + -1.1010735761374235e-05, + -6.911164382472634e-05 + ] + }, + "hidden_norms_per_layer": [ + 5436.04931640625, + 220695.28125, + 94910.3203125 + ], + "bp_grad_norms_per_layer": [ + 2.4236895114881918e-05, + 2.0322845557529945e-06, + 1.8303358046978246e-06 + ] + }, + "drift": { + "embed.weight": 32.57084729521678, + "embed.bias": 26.337782289300907, + "blocks.0.ln.weight": 1.3633438568396286, + "blocks.0.w1.weight": 19.737032593158407, + "blocks.0.w1.bias": 13.272555669066517, + "blocks.0.w2.weight": 55.66342599708415, + "blocks.1.ln.weight": 1.0296837730331478, + "blocks.1.w1.weight": 17.104661462082266, + "blocks.1.w1.bias": 9.996125311409836, + "blocks.1.w2.weight": 40.18024027628224, + "out_ln.weight": 0.49660679931516427, + "out_head.weight": 3.626674560113674, + "out_head.bias": 8.095867414428609 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 6 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed6", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed7/results_cifar10.json b/results/fa_dfa_d512_L2_seed7/results_cifar10.json new file mode 100644 index 0000000..481338c --- /dev/null +++ b/results/fa_dfa_d512_L2_seed7/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "7": { + "dfa": { + "log": { + "train_loss": [ + 2.055191882095337, + 2.043570478172302, + 2.0389851570892334, + 2.035236138343811, + 2.034318865890503, + 2.028933349571228, + 2.027525556945801, + 2.0267114724731443, + 2.0244813998031614, + 2.0217119938659667, + 2.0195889541625975, + 2.0210990812683107, + 2.0200074877166747, + 2.0166528286743164, + 2.0172128536224365, + 2.0116504791259766, + 2.0116141784667967, + 2.007682880592346, + 2.0073163960266114, + 2.0056607851791384, + 2.0035111226654054, + 2.0027661852264402, + 2.001486244812012, + 1.9986013162231446, + 1.997822530975342, + 1.994211967010498, + 1.9924970539093017, + 1.991006000213623, + 1.9917142455291748, + 1.9929195302581788, + 1.991425333518982, + 1.9895027528381348, + 1.9874201156997682, + 1.9884889245605468, + 1.986476892967224, + 1.9826913553237915, + 1.9848502500152587, + 1.982286188697815, + 1.9824313144683838, + 1.9826200205230713, + 1.9815959204864502, + 1.9807573030853272, + 1.9788639946746827, + 1.9796723455810548, + 1.976197322998047, + 1.9779787089538574, + 1.9767974209594728, + 1.9764958350372315, + 1.9774945835876465, + 1.9747930298614502, + 1.9727936906433106, + 1.9740131066131592, + 1.9746413722991942, + 1.9732583834838868, + 1.9723580028533936, + 1.9721147798919678, + 1.972632116394043, + 1.970795502243042, + 1.9690511889648437, + 1.9716948537826537, + 1.9707828464508057, + 1.9699502618026734, + 1.969847806854248, + 1.9696522138214112, + 1.9674137978744506, + 1.9683316724014281, + 1.96831518825531, + 1.9711441744613647, + 1.9686481335830688, + 1.9681067206573486, + 1.9675092895507813, + 1.9689227291107179, + 1.9677108968353272, + 1.9652254253387451, + 1.9658482390594483, + 1.9671778398895263, + 1.9676923946762086, + 1.9644599740219115, + 1.9651101944351197, + 1.9664753457260131, + 1.9669549044799806, + 1.9642051581573485, + 1.9663040605163575, + 1.9670575283432006, + 1.9643160835266114, + 1.9654435306167604, + 1.964134133644104, + 1.9641572988891602, + 1.9649527727508544, + 1.9648468869781495, + 1.9629023349380492, + 1.9663221743774415, + 1.9669214375305175, + 1.9649668531036377, + 1.9656098468780518, + 1.966679817123413, + 1.9638590016174315, + 1.9619427868270873, + 1.9663854665756226, + 1.966093027381897 + ], + "train_acc": [ + 0.24264, + 0.24694, + 0.24798, + 0.24876, + 0.25246, + 0.25632, + 0.25708, + 0.25698, + 0.25824, + 0.26182, + 0.2633, + 0.26056, + 0.26094, + 0.26186, + 0.26334, + 0.26518, + 0.26768, + 0.26548, + 0.26774, + 0.26878, + 0.2695, + 0.2733, + 0.26882, + 0.27458, + 0.2744, + 0.27714, + 0.27356, + 0.27912, + 0.27626, + 0.27554, + 0.2775, + 0.2796, + 0.27856, + 0.27752, + 0.27988, + 0.28044, + 0.28096, + 0.2833, + 0.28072, + 0.28224, + 0.28246, + 0.28386, + 0.28302, + 0.28622, + 0.28438, + 0.28368, + 0.28372, + 0.28316, + 0.28436, + 0.28636, + 0.28652, + 0.28768, + 0.2859, + 0.28832, + 0.28462, + 0.28838, + 0.2877, + 0.28634, + 0.28904, + 0.28846, + 0.28746, + 0.28738, + 0.2889, + 0.28872, + 0.28988, + 0.28838, + 0.2881, + 0.29144, + 0.28876, + 0.28998, + 0.29006, + 0.28796, + 0.28852, + 0.29134, + 0.29044, + 0.29064, + 0.28968, + 0.2908, + 0.28898, + 0.29126, + 0.28764, + 0.29144, + 0.2911, + 0.29084, + 0.29148, + 0.28904, + 0.29322, + 0.2902, + 0.2903, + 0.2903, + 0.29346, + 0.2899, + 0.28924, + 0.29046, + 0.28966, + 0.29182, + 0.2899, + 0.28986, + 0.2916, + 0.28824 + ], + "test_acc": [ + 0.2524, + 0.2673, + 0.2815, + 0.2506, + 0.2816, + 0.2711, + 0.2809, + 0.2844, + 0.2929, + 0.286, + 0.2744, + 0.289, + 0.2842, + 0.2845, + 0.2813, + 0.2816, + 0.2741, + 0.2919, + 0.2911, + 0.2705, + 0.2945, + 0.3034, + 0.2878, + 0.2798, + 0.2791, + 0.298, + 0.2976, + 0.293, + 0.2996, + 0.2971, + 0.3057, + 0.3087, + 0.3011, + 0.3061, + 0.3061, + 0.2919, + 0.3051, + 0.306, + 0.3078, + 0.3056, + 0.3019, + 0.3022, + 0.3075, + 0.3081, + 0.3008, + 0.2993, + 0.2992, + 0.3129, + 0.3147, + 0.2965, + 0.31, + 0.2998, + 0.3082, + 0.312, + 0.3096, + 0.3084, + 0.306, + 0.3072, + 0.308, + 0.3153, + 0.3089, + 0.3032, + 0.3083, + 0.3208, + 0.3021, + 0.3158, + 0.3101, + 0.3162, + 0.3125, + 0.3103, + 0.3118, + 0.309, + 0.3117, + 0.3098, + 0.3162, + 0.3136, + 0.3107, + 0.3074, + 0.3122, + 0.3141, + 0.3127, + 0.314, + 0.3129, + 0.3145, + 0.3133, + 0.3138, + 0.3131, + 0.3158, + 0.3159, + 0.3161, + 0.3141, + 0.3163, + 0.3165, + 0.3154, + 0.3163, + 0.3164, + 0.3155, + 0.3155, + 0.3157, + 0.3157 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.4341934025287628, + -0.0012890032958239317 + ], + "perturbation_rho": [ + -0.016353363171219826, + 0.0 + ], + "nudging": { + "0.001": [ + -6.253831088542938e-07, + 0.0 + ], + "0.003": [ + -1.7937272787094116e-06, + 1.862645149230957e-09 + ], + "0.01": [ + -6.007961928844452e-06, + 5.587935447692871e-09 + ] + }, + "hidden_norms_per_layer": [ + 51618.765625, + 586580544.0, + 3527968768.0 + ], + "bp_grad_norms_per_layer": [ + 3.575926825760689e-07, + 3.2419492090873803e-10, + 3.250963664935824e-10 + ] + }, + "drift": { + "embed.weight": 305.91109235091693, + "embed.bias": 181.83596330336093, + "blocks.0.ln.weight": 10.507483299521725, + "blocks.0.w1.weight": 259.244642347064, + "blocks.0.w1.bias": 195.88858302375624, + "blocks.0.w2.weight": 487.57558684361754, + "blocks.1.ln.weight": 9.741938719661329, + "blocks.1.w1.weight": 370.35850714262523, + "blocks.1.w1.bias": 288.30321808565367, + "blocks.1.w2.weight": 411.72437143655895, + "out_ln.weight": 0.4751365199666905, + "out_head.weight": 6.958959988926819, + "out_head.bias": 2.336896476076589 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.065423551902771, + 1.959441243019104, + 1.922898962135315, + 1.8985358092498779, + 1.8807485336685181, + 1.8632725805664063, + 1.8535257900238038, + 1.8472118298339844, + 1.8383856839752197, + 1.8354701901626587, + 1.8370730047988892, + 1.837422192993164, + 1.8389883419036865, + 1.8328349654388427, + 1.8349895798110962, + 1.8335126385116578, + 1.8306356255722045, + 1.8322726831817626, + 1.8286183166885377, + 1.8288718432617188, + 1.8341219143295289, + 1.8368035592269898, + 1.836811799583435, + 1.836639772872925, + 1.8373656529998779, + 1.836424264526367, + 1.8352625201797486, + 1.835800665588379, + 1.832284771118164, + 1.833700514755249, + 1.8298065981674194, + 1.8254384775543213, + 1.821793406639099, + 1.8201812942123412, + 1.8183945245742799, + 1.8144645150375367, + 1.8105643741607667, + 1.809894971961975, + 1.804399468612671, + 1.8044112042999267, + 1.8014125130844116, + 1.8024961044311523, + 1.7954880487442018, + 1.798938335647583, + 1.790709694480896, + 1.7935989169311524, + 1.7944183544921875, + 1.7913234217071534, + 1.7892272469711303, + 1.7870802127838135, + 1.7874060748291016, + 1.7849993677139282, + 1.7821664669418336, + 1.784739924659729, + 1.7879353686523438, + 1.786940495300293, + 1.784608112564087, + 1.7870150913238525, + 1.7842396997070313, + 1.7843707611465454, + 1.7866798919296265, + 1.7842874060440064, + 1.7791111608886718, + 1.7806335921859742, + 1.7796328140640258, + 1.7788975708770751, + 1.7777235781097411, + 1.7809623500823974, + 1.7793159552764892, + 1.7774666768646241, + 1.7777468250274657, + 1.7797447577667236, + 1.7749399490356446, + 1.7782198504257203, + 1.7768975478744506, + 1.7724751330184936, + 1.7744919674682618, + 1.771957764816284, + 1.773150245323181, + 1.7745548120880128, + 1.7708911740112305, + 1.7687720165634155, + 1.7715272402191162, + 1.7657094653320313, + 1.7654507334899903, + 1.7671412448120116, + 1.7664664395904541, + 1.768566664352417, + 1.7679725888824462, + 1.7632778255844117, + 1.7664907150268554, + 1.7670541214370727, + 1.7653729449081421, + 1.7616971383666993, + 1.7609278664398194, + 1.7637566931915283, + 1.764719966506958, + 1.7616608474349975, + 1.7632052117538453, + 1.7648505276107789 + ], + "train_acc": [ + 0.24818, + 0.2905, + 0.3052, + 0.31584, + 0.32536, + 0.33246, + 0.33778, + 0.33886, + 0.33934, + 0.34248, + 0.3433, + 0.34108, + 0.33788, + 0.34122, + 0.3401, + 0.34124, + 0.3422, + 0.3437, + 0.34032, + 0.34338, + 0.33892, + 0.33994, + 0.3359, + 0.33762, + 0.33858, + 0.33796, + 0.33882, + 0.34338, + 0.34164, + 0.34066, + 0.34704, + 0.3459, + 0.34644, + 0.34698, + 0.34912, + 0.34876, + 0.3523, + 0.35206, + 0.35342, + 0.35536, + 0.35396, + 0.3552, + 0.35924, + 0.35432, + 0.35796, + 0.35682, + 0.35832, + 0.35936, + 0.36088, + 0.36014, + 0.36244, + 0.3632, + 0.36318, + 0.36202, + 0.36224, + 0.36138, + 0.36264, + 0.35992, + 0.3656, + 0.36372, + 0.3631, + 0.3644, + 0.3645, + 0.36648, + 0.36714, + 0.36738, + 0.3652, + 0.36398, + 0.36904, + 0.3683, + 0.36936, + 0.3682, + 0.36688, + 0.36788, + 0.3702, + 0.36986, + 0.36914, + 0.3728, + 0.3697, + 0.37158, + 0.37106, + 0.37164, + 0.37216, + 0.3765, + 0.37258, + 0.3718, + 0.37274, + 0.37224, + 0.3709, + 0.372, + 0.37362, + 0.3726, + 0.37316, + 0.37384, + 0.37496, + 0.37558, + 0.3719, + 0.37734, + 0.37454, + 0.37464 + ], + "test_acc": [ + 0.2781, + 0.3155, + 0.3352, + 0.3255, + 0.3405, + 0.3545, + 0.3621, + 0.3461, + 0.3528, + 0.3543, + 0.331, + 0.3589, + 0.3453, + 0.349, + 0.3481, + 0.3479, + 0.355, + 0.3516, + 0.347, + 0.3387, + 0.3431, + 0.3418, + 0.3297, + 0.3231, + 0.3156, + 0.3425, + 0.3359, + 0.3325, + 0.3627, + 0.3371, + 0.351, + 0.3523, + 0.3384, + 0.3573, + 0.3432, + 0.3522, + 0.3549, + 0.3581, + 0.361, + 0.3577, + 0.3477, + 0.3498, + 0.3374, + 0.3495, + 0.337, + 0.3531, + 0.3467, + 0.3464, + 0.3582, + 0.3339, + 0.3416, + 0.3402, + 0.3472, + 0.3388, + 0.3466, + 0.3442, + 0.3452, + 0.3515, + 0.3485, + 0.3612, + 0.3508, + 0.3437, + 0.3639, + 0.3558, + 0.3438, + 0.3503, + 0.3466, + 0.3499, + 0.3569, + 0.3508, + 0.3467, + 0.3479, + 0.3496, + 0.352, + 0.3543, + 0.3569, + 0.3474, + 0.344, + 0.3485, + 0.3519, + 0.3507, + 0.352, + 0.3534, + 0.3545, + 0.353, + 0.3563, + 0.3533, + 0.355, + 0.3552, + 0.357, + 0.354, + 0.3546, + 0.3545, + 0.3529, + 0.3537, + 0.3533, + 0.3535, + 0.3541, + 0.3539, + 0.3537 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.039823852479457855, + 0.9718431234359741 + ], + "perturbation_rho": [ + -0.031080063432455063, + 0.03366325423121452 + ], + "nudging": { + "0.001": [ + -2.9830262064933777e-06, + -5.972804501652718e-06 + ], + "0.003": [ + -8.938834071159363e-06, + -1.8277671188116074e-05 + ], + "0.01": [ + -3.016670234501362e-05, + -6.103678606450558e-05 + ] + }, + "hidden_norms_per_layer": [ + 5119.443359375, + 277281.78125, + 221633.09375 + ], + "bp_grad_norms_per_layer": [ + 2.112029142153915e-05, + 1.5214125141937984e-06, + 1.2520613381639123e-06 + ] + }, + "drift": { + "embed.weight": 32.21887870916367, + "embed.bias": 19.808534592837418, + "blocks.0.ln.weight": 1.571417962464522, + "blocks.0.w1.weight": 27.02065176792147, + "blocks.0.w1.bias": 18.826762362229086, + "blocks.0.w2.weight": 62.552517441341244, + "blocks.1.ln.weight": 1.2403307752162702, + "blocks.1.w1.weight": 19.856236759319962, + "blocks.1.w1.bias": 14.207741981198222, + "blocks.1.w2.weight": 38.602188965118806, + "out_ln.weight": 0.442837113470192, + "out_head.weight": 4.411774526619856, + "out_head.bias": 11.444358576141303 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 7 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed7", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed8/results_cifar10.json b/results/fa_dfa_d512_L2_seed8/results_cifar10.json new file mode 100644 index 0000000..ad4f8bf --- /dev/null +++ b/results/fa_dfa_d512_L2_seed8/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "8": { + "dfa": { + "log": { + "train_loss": [ + 2.0461414054870604, + 2.028953609237671, + 2.0224830960845948, + 2.0189640741348267, + 2.0186043047332762, + 2.011939255371094, + 2.00884051902771, + 2.006375752105713, + 2.0061196725463866, + 2.0044810569763185, + 2.000473762359619, + 1.999425143661499, + 1.9991976582336426, + 1.9983828188323975, + 1.9999516902160646, + 2.000131714401245, + 2.000538674583435, + 1.999087513999939, + 1.9965772060394287, + 1.9952695233154296, + 1.9958783276367187, + 1.9935351610565186, + 1.9946063667297362, + 1.9927719008636475, + 1.9930043865203857, + 1.992498685836792, + 1.9923327257537842, + 1.9919360898208618, + 1.9922407612609863, + 1.992235760192871, + 1.992523240623474, + 1.9913750788116455, + 1.9934618493652343, + 1.9923072975921632, + 1.991889149017334, + 1.9904881185913086, + 1.9916703072357178, + 1.9886171953582763, + 1.9906551904296874, + 1.9912475842285156, + 1.9906490882873535, + 1.9897464477920532, + 1.9902090267562866, + 1.98889538230896, + 1.9899030431365967, + 1.9914473651885987, + 1.987290951499939, + 1.9882332851409912, + 1.987876655883789, + 1.986627554321289, + 1.9874235430526734, + 1.985631577758789, + 1.9859224739074708, + 1.984819174156189, + 1.9877312922668458, + 1.9880256802749634, + 1.9883026748657227, + 1.9880133050918578, + 1.9872900528717041, + 1.9861614239501952, + 1.9879061701202392, + 1.986334129486084, + 1.9889874462509156, + 1.9848816363143922, + 1.985108938598633, + 1.9852813848114013, + 1.9847304062652589, + 1.9842171102142334, + 1.986599141769409, + 1.9858817081451416, + 1.9852649099349975, + 1.982235690460205, + 1.9838326638793946, + 1.9827596533966065, + 1.9836663906097411, + 1.9844815605163575, + 1.9842618350982666, + 1.984488549156189, + 1.9830066680908203, + 1.9835596384048462, + 1.9822921755981446, + 1.9845924890899658, + 1.9827096743392945, + 1.9821208544158935, + 1.983244455909729, + 1.9823134091567993, + 1.9837292085266114, + 1.9827219149017334, + 1.9828275063323975, + 1.9811715400695802, + 1.9821106326675415, + 1.9830516720199585, + 1.9831754531097412, + 1.9825017529296876, + 1.9844729486465453, + 1.9823845765686035, + 1.9825008573913574, + 1.9817371285247802, + 1.9809684440612794, + 1.981824859008789 + ], + "train_acc": [ + 0.24566, + 0.25216, + 0.25672, + 0.2583, + 0.25908, + 0.2611, + 0.26292, + 0.26504, + 0.26436, + 0.26372, + 0.26838, + 0.2684, + 0.26774, + 0.26886, + 0.26864, + 0.26928, + 0.26802, + 0.26984, + 0.27068, + 0.27056, + 0.27002, + 0.27018, + 0.27034, + 0.27138, + 0.27238, + 0.27266, + 0.27208, + 0.27464, + 0.27406, + 0.27306, + 0.27278, + 0.27354, + 0.27164, + 0.27506, + 0.27494, + 0.27682, + 0.27244, + 0.2749, + 0.27302, + 0.2745, + 0.2736, + 0.27382, + 0.27478, + 0.27344, + 0.27582, + 0.2732, + 0.27662, + 0.27616, + 0.2763, + 0.27742, + 0.27642, + 0.27984, + 0.2777, + 0.28112, + 0.27754, + 0.27344, + 0.2769, + 0.27664, + 0.27734, + 0.27754, + 0.27784, + 0.27642, + 0.27598, + 0.2796, + 0.27964, + 0.27868, + 0.27864, + 0.27868, + 0.27826, + 0.27986, + 0.27616, + 0.2817, + 0.28024, + 0.27688, + 0.28002, + 0.27942, + 0.27838, + 0.2787, + 0.27856, + 0.28034, + 0.27972, + 0.27722, + 0.27992, + 0.28172, + 0.27796, + 0.2775, + 0.28122, + 0.2803, + 0.27958, + 0.27936, + 0.28062, + 0.2791, + 0.2797, + 0.28202, + 0.27898, + 0.27894, + 0.27866, + 0.27832, + 0.28, + 0.28114 + ], + "test_acc": [ + 0.2641, + 0.2447, + 0.2423, + 0.2586, + 0.2666, + 0.2868, + 0.2779, + 0.27, + 0.2886, + 0.2704, + 0.2846, + 0.2825, + 0.2801, + 0.2912, + 0.277, + 0.2834, + 0.2838, + 0.2802, + 0.2793, + 0.2883, + 0.2871, + 0.2825, + 0.288, + 0.2913, + 0.2813, + 0.2931, + 0.2994, + 0.3001, + 0.2753, + 0.2966, + 0.2996, + 0.2965, + 0.2835, + 0.3002, + 0.2979, + 0.2783, + 0.3017, + 0.2848, + 0.2846, + 0.2928, + 0.2956, + 0.2705, + 0.2811, + 0.292, + 0.2763, + 0.2834, + 0.2837, + 0.2949, + 0.2998, + 0.2894, + 0.2853, + 0.2921, + 0.2986, + 0.2918, + 0.2933, + 0.28, + 0.2988, + 0.2925, + 0.2828, + 0.2875, + 0.2985, + 0.2927, + 0.2922, + 0.3016, + 0.2947, + 0.2952, + 0.2936, + 0.2923, + 0.2992, + 0.2968, + 0.2915, + 0.2992, + 0.2941, + 0.2969, + 0.2936, + 0.2972, + 0.2964, + 0.2928, + 0.2958, + 0.2973, + 0.2958, + 0.2971, + 0.2983, + 0.2983, + 0.2951, + 0.2963, + 0.2953, + 0.2964, + 0.2963, + 0.293, + 0.2951, + 0.2967, + 0.2967, + 0.297, + 0.2966, + 0.2967, + 0.2967, + 0.2966, + 0.2966, + 0.2967 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.3839848041534424, + -0.0006596383173018694 + ], + "perturbation_rho": [ + 0.02412901259958744, + 0.0 + ], + "nudging": { + "0.001": [ + -4.284083843231201e-07, + 0.0 + ], + "0.003": [ + -1.3029202818870544e-06, + 0.0 + ], + "0.01": [ + -4.258938133716583e-06, + 0.0 + ] + }, + "hidden_norms_per_layer": [ + 53606.07421875, + 782741952.0, + 4561426432.0 + ], + "bp_grad_norms_per_layer": [ + 3.0959373020778003e-07, + 3.211692578553027e-10, + 3.2110258896267396e-10 + ] + }, + "drift": { + "embed.weight": 322.14716880304843, + "embed.bias": 253.7663922411994, + "blocks.0.ln.weight": 9.639264948834146, + "blocks.0.w1.weight": 278.73838749620353, + "blocks.0.w1.bias": 250.4325839565606, + "blocks.0.w2.weight": 488.7145731499842, + "blocks.1.ln.weight": 9.351844341907716, + "blocks.1.w1.weight": 376.1230199590697, + "blocks.1.w1.bias": 376.4920228948829, + "blocks.1.w2.weight": 403.36194320458816, + "out_ln.weight": 0.513534701987637, + "out_head.weight": 7.742896807554914, + "out_head.bias": 1.1408249446090957 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.0653204177856446, + 1.9681754495239259, + 1.9207439241790771, + 1.8943786935806275, + 1.8802625274276734, + 1.8628883916473389, + 1.8490794301605225, + 1.8400364615249634, + 1.8355530837249756, + 1.8266534114837647, + 1.8170614450454712, + 1.8170341843414306, + 1.809728816757202, + 1.80435710231781, + 1.8044258756256104, + 1.7987670265960694, + 1.7999644842147826, + 1.7953315704345703, + 1.7825802756118774, + 1.7845099126434327, + 1.7779603283691405, + 1.7730702184677125, + 1.7695367748260498, + 1.7647147689056397, + 1.7608037928009033, + 1.7535330474090576, + 1.75540680393219, + 1.7487119026947022, + 1.7451758492279053, + 1.745663233718872, + 1.7409377331924438, + 1.7372819427871704, + 1.7388353146743774, + 1.7410107398605346, + 1.7357459603881835, + 1.7290400104522705, + 1.7318879638290405, + 1.729545888900757, + 1.7291300962066651, + 1.733487999343872, + 1.7302352466583253, + 1.7331863732147217, + 1.7339300806045532, + 1.7359314554595948, + 1.7347373684310914, + 1.7375746230316162, + 1.7313235457611085, + 1.7353338851165772, + 1.73820286403656, + 1.7340232833862306, + 1.7341221398925781, + 1.7339955667114257, + 1.7339833393096924, + 1.7300079647064208, + 1.7303287143707275, + 1.7298947993469238, + 1.7274299069213868, + 1.7319763919067384, + 1.7264507590103149, + 1.7299312719726563, + 1.7262817291641235, + 1.7270497384262085, + 1.7246126723098756, + 1.7255360033798217, + 1.7196656997680664, + 1.7214281465911865, + 1.7198693132781981, + 1.7217327685546875, + 1.7179188094329834, + 1.7154542624664306, + 1.7195643279266357, + 1.7149108078384399, + 1.7153279098510743, + 1.7153327802276612, + 1.7094103066253663, + 1.7162176602554322, + 1.7122739435195924, + 1.7128758419418335, + 1.708295757408142, + 1.7086283514785767, + 1.7106998838043213, + 1.711903858909607, + 1.7090504042434693, + 1.7111234002304077, + 1.7090101749420166, + 1.7079448748779298, + 1.707335115890503, + 1.7080174974822997, + 1.705443286972046, + 1.7056840801239013, + 1.708050958175659, + 1.7072754711151124, + 1.7059114752960205, + 1.7063753726959228, + 1.7061457764434815, + 1.7009761544418336, + 1.707751445274353, + 1.7066948248291016, + 1.7041513320922852, + 1.7030947677612305 + ], + "train_acc": [ + 0.24824, + 0.29, + 0.30804, + 0.32064, + 0.32604, + 0.33114, + 0.33818, + 0.3373, + 0.34064, + 0.3462, + 0.34982, + 0.34634, + 0.35334, + 0.35346, + 0.35488, + 0.3554, + 0.35676, + 0.35738, + 0.36206, + 0.36144, + 0.36364, + 0.3671, + 0.36916, + 0.3698, + 0.37032, + 0.3709, + 0.37188, + 0.3757, + 0.37546, + 0.37376, + 0.37888, + 0.37908, + 0.37828, + 0.37822, + 0.37984, + 0.3791, + 0.38144, + 0.38286, + 0.37966, + 0.3797, + 0.38152, + 0.38084, + 0.3812, + 0.3809, + 0.38156, + 0.382, + 0.38494, + 0.38258, + 0.38066, + 0.3824, + 0.38058, + 0.38342, + 0.3823, + 0.38198, + 0.38534, + 0.38368, + 0.38524, + 0.38594, + 0.38536, + 0.38514, + 0.38676, + 0.38838, + 0.38778, + 0.38664, + 0.39118, + 0.38748, + 0.39154, + 0.3882, + 0.38852, + 0.39078, + 0.38918, + 0.39202, + 0.39086, + 0.39128, + 0.39242, + 0.3935, + 0.39422, + 0.39322, + 0.39388, + 0.39516, + 0.39496, + 0.39414, + 0.39428, + 0.39522, + 0.39426, + 0.3953, + 0.39628, + 0.39662, + 0.39662, + 0.39608, + 0.39646, + 0.39568, + 0.39606, + 0.39558, + 0.3961, + 0.3992, + 0.39438, + 0.39786, + 0.39626, + 0.39922 + ], + "test_acc": [ + 0.297, + 0.3142, + 0.3407, + 0.3409, + 0.3537, + 0.3628, + 0.3482, + 0.3626, + 0.3557, + 0.351, + 0.3614, + 0.3693, + 0.363, + 0.3681, + 0.374, + 0.3756, + 0.3778, + 0.3795, + 0.3757, + 0.3742, + 0.3801, + 0.3732, + 0.3782, + 0.3805, + 0.3718, + 0.3786, + 0.3914, + 0.3924, + 0.3797, + 0.3967, + 0.3891, + 0.3829, + 0.3811, + 0.3967, + 0.3876, + 0.3862, + 0.3945, + 0.3851, + 0.3807, + 0.3871, + 0.3924, + 0.3941, + 0.3823, + 0.392, + 0.3942, + 0.3908, + 0.3979, + 0.3954, + 0.4003, + 0.401, + 0.3938, + 0.3938, + 0.3997, + 0.3943, + 0.4069, + 0.4012, + 0.4056, + 0.408, + 0.4037, + 0.3996, + 0.4115, + 0.4098, + 0.4062, + 0.4083, + 0.4115, + 0.4094, + 0.4147, + 0.4109, + 0.4101, + 0.4084, + 0.4112, + 0.4114, + 0.4127, + 0.4113, + 0.407, + 0.4126, + 0.4167, + 0.4132, + 0.4156, + 0.4138, + 0.4122, + 0.4128, + 0.417, + 0.4138, + 0.4128, + 0.4136, + 0.4137, + 0.4132, + 0.4153, + 0.4136, + 0.4123, + 0.413, + 0.4134, + 0.4142, + 0.4155, + 0.4155, + 0.4142, + 0.4147, + 0.4149, + 0.4146 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.02496020682156086, + 0.9624444246292114 + ], + "perturbation_rho": [ + -0.018970193341374397, + 0.1267668902873993 + ], + "nudging": { + "0.001": [ + -2.1758460206910968e-06, + -7.471287972293794e-06 + ], + "0.003": [ + -6.601490895263851e-06, + -2.2840846213512123e-05 + ], + "0.01": [ + -2.1907704649493098e-05, + -7.614441710757092e-05 + ] + }, + "hidden_norms_per_layer": [ + 6420.08984375, + 205601.1875, + 92651.546875 + ], + "bp_grad_norms_per_layer": [ + 2.9750526664429344e-05, + 2.456108632031828e-06, + 2.075838210657821e-06 + ] + }, + "drift": { + "embed.weight": 37.110060620728056, + "embed.bias": 24.00213600327038, + "blocks.0.ln.weight": 1.3952106680756078, + "blocks.0.w1.weight": 19.29938322080228, + "blocks.0.w1.bias": 17.715456046599957, + "blocks.0.w2.weight": 54.87838247174745, + "blocks.1.ln.weight": 1.0669201771550085, + "blocks.1.w1.weight": 18.224214324200183, + "blocks.1.w1.bias": 17.474148202150594, + "blocks.1.w2.weight": 31.754151778532563, + "out_ln.weight": 0.42931707858608786, + "out_head.weight": 4.036471355678216, + "out_head.bias": 7.088871139191845 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 8 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed8", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_dfa_d512_L2_seed9/results_cifar10.json b/results/fa_dfa_d512_L2_seed9/results_cifar10.json new file mode 100644 index 0000000..f8f63d4 --- /dev/null +++ b/results/fa_dfa_d512_L2_seed9/results_cifar10.json @@ -0,0 +1,749 @@ +{ + "9": { + "dfa": { + "log": { + "train_loss": [ + 2.0386180068206787, + 2.0063905084228515, + 2.0115847008895873, + 2.0072783072662355, + 2.0072200788116454, + 2.007624338607788, + 2.004089120826721, + 2.004089771194458, + 1.9991737648010255, + 1.9998447619247437, + 1.9991543921661377, + 2.0006776647186277, + 1.9984499712371826, + 1.9943631922149658, + 1.994661587486267, + 1.993870899658203, + 1.9933709629821776, + 1.9952729042816162, + 1.9926185557556153, + 1.9912919250488281, + 1.9926570781707764, + 1.9911886741638183, + 1.9924334022521972, + 1.9887267460632325, + 1.9918579584121705, + 1.990792886352539, + 1.9901814585113526, + 1.9899697443008424, + 1.989364903640747, + 1.9860043600463868, + 1.9868051830291749, + 1.9893270887756347, + 1.9876267849349976, + 1.9893067332458496, + 1.9857772846221924, + 1.9837282821655273, + 1.9841915615081787, + 1.9825071751403809, + 1.9810244177246095, + 1.9853471765136719, + 1.9825159663391114, + 1.9842312357330323, + 1.9828273846817017, + 1.9831989236450196, + 1.9810442190551758, + 1.9819502600479126, + 1.981037528152466, + 1.9779154906845093, + 1.9809581425476075, + 1.9797189464569092, + 1.9827272113037109, + 1.9808893032073975, + 1.9813728839111329, + 1.9781295204925538, + 1.9786987835311889, + 1.9786434796905517, + 1.9790834701919555, + 1.9781043656921387, + 1.9786688150024414, + 1.9782110293579103, + 1.9773203885269166, + 1.9758285034942626, + 1.977617745285034, + 1.97553291305542, + 1.975182448425293, + 1.9747401064682006, + 1.9758236003875733, + 1.9758562586212158, + 1.9762433794403076, + 1.9765984790420532, + 1.9761512882232666, + 1.9743927758407593, + 1.9739378618621826, + 1.9728027178192138, + 1.9720366858673095, + 1.9760308059692382, + 1.9740920357513427, + 1.9741222943115235, + 1.9719459004974365, + 1.9733767440032959, + 1.9732320972061157, + 1.9731594284057616, + 1.9734650510025025, + 1.9740613651275636, + 1.970825253868103, + 1.97225927734375, + 1.9708090161132812, + 1.972899548110962, + 1.9718735347747802, + 1.9716896154785157, + 1.9735381398010254, + 1.970859426651001, + 1.970537699661255, + 1.9716636752700805, + 1.971408115081787, + 1.97217986328125, + 1.9708149477767944, + 1.9725973287582397, + 1.970340883255005, + 1.970512617111206 + ], + "train_acc": [ + 0.2506, + 0.26582, + 0.26488, + 0.26812, + 0.2673, + 0.26874, + 0.26924, + 0.27152, + 0.2706, + 0.2732, + 0.27076, + 0.26982, + 0.27112, + 0.27538, + 0.27152, + 0.27576, + 0.27424, + 0.2737, + 0.27522, + 0.27654, + 0.27548, + 0.27776, + 0.27694, + 0.27644, + 0.2755, + 0.27554, + 0.27802, + 0.2774, + 0.2763, + 0.27872, + 0.27982, + 0.27564, + 0.27854, + 0.27758, + 0.27964, + 0.28212, + 0.28362, + 0.28318, + 0.28196, + 0.28094, + 0.28112, + 0.28024, + 0.28322, + 0.28244, + 0.28272, + 0.28128, + 0.28152, + 0.28524, + 0.28126, + 0.28276, + 0.28214, + 0.28506, + 0.2806, + 0.28196, + 0.2851, + 0.2821, + 0.28354, + 0.28526, + 0.28296, + 0.28452, + 0.28764, + 0.28686, + 0.28512, + 0.28712, + 0.28588, + 0.28462, + 0.28596, + 0.2847, + 0.28546, + 0.2833, + 0.2868, + 0.28714, + 0.285, + 0.2887, + 0.28604, + 0.28332, + 0.28698, + 0.28862, + 0.2851, + 0.2889, + 0.29026, + 0.28762, + 0.28606, + 0.2889, + 0.28878, + 0.29028, + 0.2883, + 0.28902, + 0.2895, + 0.28942, + 0.29002, + 0.29058, + 0.28932, + 0.28564, + 0.286, + 0.28974, + 0.28982, + 0.28878, + 0.2885, + 0.2891 + ], + "test_acc": [ + 0.2683, + 0.295, + 0.2917, + 0.2896, + 0.2953, + 0.2912, + 0.311, + 0.2938, + 0.2849, + 0.2891, + 0.2916, + 0.3049, + 0.3017, + 0.3203, + 0.3102, + 0.2822, + 0.286, + 0.3035, + 0.3098, + 0.3009, + 0.3053, + 0.3044, + 0.302, + 0.2901, + 0.3166, + 0.3038, + 0.2909, + 0.2941, + 0.3023, + 0.2937, + 0.3143, + 0.2817, + 0.289, + 0.3074, + 0.3052, + 0.3025, + 0.3089, + 0.3155, + 0.3068, + 0.3125, + 0.3145, + 0.3012, + 0.3152, + 0.315, + 0.3013, + 0.3054, + 0.3112, + 0.2904, + 0.3169, + 0.2959, + 0.3056, + 0.3098, + 0.3101, + 0.3139, + 0.309, + 0.2991, + 0.3182, + 0.3204, + 0.2996, + 0.3108, + 0.3082, + 0.3162, + 0.3227, + 0.3027, + 0.3098, + 0.2966, + 0.309, + 0.3138, + 0.3095, + 0.302, + 0.3007, + 0.3096, + 0.3111, + 0.3052, + 0.3121, + 0.3254, + 0.3128, + 0.3127, + 0.3042, + 0.3189, + 0.3148, + 0.3064, + 0.3145, + 0.3172, + 0.3148, + 0.3157, + 0.3164, + 0.3148, + 0.3139, + 0.3147, + 0.315, + 0.3137, + 0.3139, + 0.3151, + 0.3161, + 0.3159, + 0.3157, + 0.3154, + 0.3157, + 0.3156 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.4173116087913513, + -0.0009379963739775121 + ], + "perturbation_rho": [ + 0.029962807893753052, + 0.0 + ], + "nudging": { + "0.001": [ + -5.238689482212067e-07, + 0.0 + ], + "0.003": [ + -1.4784745872020721e-06, + 1.862645149230957e-09 + ], + "0.01": [ + -4.835892468690872e-06, + -4.6566128730773926e-09 + ] + }, + "hidden_norms_per_layer": [ + 52101.48828125, + 755819456.0, + 2329725696.0 + ], + "bp_grad_norms_per_layer": [ + 3.256158436215628e-07, + 4.5803447146219867e-10, + 4.584113644234833e-10 + ] + }, + "drift": { + "embed.weight": 314.4915524892561, + "embed.bias": 242.71554547665664, + "blocks.0.ln.weight": 9.316701961843108, + "blocks.0.w1.weight": 260.9170515894993, + "blocks.0.w1.bias": 217.2580939372968, + "blocks.0.w2.weight": 445.91774813193825, + "blocks.1.ln.weight": 8.403039866788982, + "blocks.1.w1.weight": 293.10409601703697, + "blocks.1.w1.bias": 296.0137857555933, + "blocks.1.w2.weight": 322.56673718802114, + "out_ln.weight": 0.438240662386626, + "out_head.weight": 7.216506109788919, + "out_head.bias": 4.092364299645094 + } + }, + "fa": { + "log": { + "train_loss": [ + 2.059632821121216, + 1.969141392288208, + 1.9324101000595093, + 1.8976149897003174, + 1.8789895468902589, + 1.8676787176513672, + 1.86150774559021, + 1.8536200240707397, + 1.8464182161712646, + 1.843310608215332, + 1.8476687601470947, + 1.8471106829452515, + 1.8417246197128296, + 1.838934436569214, + 1.8480940073013306, + 1.841720523109436, + 1.8359926760482788, + 1.8341460582733153, + 1.8318497580718993, + 1.8298718184661866, + 1.8337750116729736, + 1.8297588416290282, + 1.8286794381332399, + 1.8261895336532592, + 1.8255341546630859, + 1.8234506717300416, + 1.8259339986801149, + 1.821055384902954, + 1.8205125164794922, + 1.8167257778549195, + 1.8076189421081543, + 1.8187879995727538, + 1.8171783072662353, + 1.8175717837524414, + 1.8142739019775391, + 1.8160600145721435, + 1.8160681335449218, + 1.8097417612075806, + 1.8152825867080689, + 1.8145546726226807, + 1.814115087852478, + 1.8139980951690673, + 1.8110767394256593, + 1.8108615531158447, + 1.8082058542633057, + 1.806761900253296, + 1.8031038238143922, + 1.8006043545150756, + 1.8022277898406982, + 1.8031743975067138, + 1.8007837707138061, + 1.8021138620758057, + 1.8031328768920898, + 1.7993617374038697, + 1.7992449984359742, + 1.7986480782699585, + 1.7985403269195557, + 1.7962297436141967, + 1.7968010370254517, + 1.795757841835022, + 1.7965225546646117, + 1.7958138821792602, + 1.7978459465789796, + 1.795165989151001, + 1.79601362575531, + 1.7981088857269287, + 1.7986928964233397, + 1.796389625015259, + 1.7948159000015258, + 1.795204453163147, + 1.7975346917724608, + 1.7945230379486083, + 1.794725383377075, + 1.7934243726348877, + 1.7920495998382568, + 1.7969142938995362, + 1.7927454508209228, + 1.7913766479110718, + 1.792234444503784, + 1.7971493310928344, + 1.7925569116210938, + 1.794579433517456, + 1.7918265099716186, + 1.7970994007110597, + 1.7937700403213501, + 1.79378436958313, + 1.7918543418121338, + 1.7927718152618408, + 1.792527283859253, + 1.7904213320922853, + 1.791894981842041, + 1.7942055084228516, + 1.7854851037597657, + 1.7892465375518798, + 1.7894356133651734, + 1.792232000427246, + 1.7908443558120728, + 1.7891317191314697, + 1.7882942892456055, + 1.7879403518295287 + ], + "train_acc": [ + 0.24674, + 0.28878, + 0.30256, + 0.31774, + 0.32412, + 0.32786, + 0.33236, + 0.33536, + 0.33926, + 0.34026, + 0.33958, + 0.33756, + 0.34346, + 0.34348, + 0.33962, + 0.34076, + 0.34474, + 0.34456, + 0.34354, + 0.34714, + 0.3452, + 0.347, + 0.34758, + 0.34994, + 0.34704, + 0.34678, + 0.34878, + 0.35086, + 0.3489, + 0.35086, + 0.354, + 0.35062, + 0.35238, + 0.35012, + 0.35416, + 0.3535, + 0.35234, + 0.3533, + 0.35412, + 0.35204, + 0.35228, + 0.35076, + 0.35398, + 0.35318, + 0.35272, + 0.3529, + 0.3562, + 0.35708, + 0.35474, + 0.35498, + 0.35806, + 0.35614, + 0.3567, + 0.359, + 0.35654, + 0.3578, + 0.35942, + 0.35754, + 0.36092, + 0.35994, + 0.36174, + 0.36086, + 0.36242, + 0.36584, + 0.36354, + 0.36262, + 0.3625, + 0.35902, + 0.36222, + 0.36288, + 0.36128, + 0.36226, + 0.36288, + 0.36458, + 0.36242, + 0.36178, + 0.3654, + 0.36442, + 0.3654, + 0.36446, + 0.36594, + 0.36486, + 0.36312, + 0.36548, + 0.36586, + 0.3661, + 0.36682, + 0.36846, + 0.3665, + 0.36726, + 0.3648, + 0.36732, + 0.36906, + 0.36904, + 0.36968, + 0.36782, + 0.3685, + 0.36574, + 0.36768, + 0.36662 + ], + "test_acc": [ + 0.2753, + 0.3239, + 0.3396, + 0.3381, + 0.3497, + 0.3518, + 0.3631, + 0.3416, + 0.3506, + 0.3507, + 0.3488, + 0.3577, + 0.3492, + 0.3608, + 0.3625, + 0.3672, + 0.3648, + 0.373, + 0.3637, + 0.3675, + 0.3612, + 0.3701, + 0.3614, + 0.3567, + 0.3695, + 0.3672, + 0.3684, + 0.3628, + 0.3779, + 0.3671, + 0.3748, + 0.3768, + 0.3622, + 0.3679, + 0.3684, + 0.3674, + 0.365, + 0.365, + 0.3596, + 0.3633, + 0.3614, + 0.3725, + 0.3516, + 0.3689, + 0.3514, + 0.356, + 0.3549, + 0.3407, + 0.3627, + 0.3489, + 0.3519, + 0.3524, + 0.3578, + 0.3532, + 0.3396, + 0.3532, + 0.3527, + 0.3599, + 0.3532, + 0.3571, + 0.3569, + 0.3491, + 0.3603, + 0.3561, + 0.3594, + 0.3522, + 0.3517, + 0.3523, + 0.3618, + 0.3583, + 0.3522, + 0.3556, + 0.3617, + 0.3617, + 0.362, + 0.3692, + 0.3621, + 0.3631, + 0.3579, + 0.3653, + 0.3655, + 0.3626, + 0.3669, + 0.3689, + 0.3658, + 0.3664, + 0.3649, + 0.3635, + 0.3656, + 0.3676, + 0.3651, + 0.3669, + 0.3663, + 0.3669, + 0.3663, + 0.3667, + 0.3665, + 0.3667, + 0.3662, + 0.3661 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.005262219812721014, + 0.9548712968826294 + ], + "perturbation_rho": [ + 0.014450715854763985, + 0.06202582269906998 + ], + "nudging": { + "0.001": [ + 1.9065337255597115e-06, + -7.555587217211723e-06 + ], + "0.003": [ + 5.672394763678312e-06, + -2.293300349265337e-05 + ], + "0.01": [ + 1.8914113752543926e-05, + -7.650378393009305e-05 + ] + }, + "hidden_norms_per_layer": [ + 4985.0673828125, + 267249.8125, + 166942.46875 + ], + "bp_grad_norms_per_layer": [ + 2.8030346584273502e-05, + 1.5286594816643628e-06, + 1.4554038898495492e-06 + ] + }, + "drift": { + "embed.weight": 30.89122351236108, + "embed.bias": 29.388918157004504, + "blocks.0.ln.weight": 1.5688009803019956, + "blocks.0.w1.weight": 22.80179549488676, + "blocks.0.w1.bias": 19.927274930467597, + "blocks.0.w2.weight": 66.38185851388366, + "blocks.1.ln.weight": 1.1343783166936168, + "blocks.1.w1.weight": 17.613374844897635, + "blocks.1.w1.bias": 9.480317921179704, + "blocks.1.w2.weight": 41.63261375754297, + "out_ln.weight": 0.4189237920059734, + "out_head.weight": 4.914939575481723, + "out_head.bias": 17.20838803283335 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 9 + ], + "gpu": 0, + "output_dir": "results/fa_dfa_d512_L2_seed9", + "methods": [ + "fa", + "dfa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/frozen_d512_baselines.log b/results/frozen_d512_baselines.log new file mode 100644 index 0000000..7a1a42d --- /dev/null +++ b/results/frozen_d512_baselines.log @@ -0,0 +1,111 @@ +=== FROZEN BASELINES d=512 === +Start: Sat Apr 25 10:42:45 PM CDT 2026 + d=512 L=4 s=42 (Sat Apr 25 10:42:45 PM CDT 2026) + DFA-shallow: 0.3458 + DFA-frozen: 0.3445 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=4 s=123 (Sat Apr 25 11:22:20 PM CDT 2026) + DFA-shallow: 0.3524 + DFA-frozen: 0.3506 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=4 s=456 (Sun Apr 26 12:01:58 AM CDT 2026) + DFA-shallow: 0.3516 + DFA-frozen: 0.3514 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=2 s=42 (Sun Apr 26 12:41:03 AM CDT 2026) + DFA-shallow: 0.3458 + DFA-frozen: 0.3452 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=2 s=123 (Sun Apr 26 01:20:51 AM CDT 2026) + DFA-shallow: 0.3524 + DFA-frozen: 0.3502 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=2 s=456 (Sun Apr 26 01:59:55 AM CDT 2026) + DFA-shallow: 0.3516 + DFA-frozen: 0.3514 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=8 s=42 (Sun Apr 26 02:39:45 AM CDT 2026) + DFA-shallow: 0.3458 + DFA-frozen: 0.3432 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=8 s=123 (Sun Apr 26 03:19:06 AM CDT 2026) + DFA-shallow: 0.3524 + DFA-frozen: 0.3505 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=8 s=456 (Sun Apr 26 03:58:23 AM CDT 2026) + DFA-shallow: 0.3516 + DFA-frozen: 0.3508 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=12 s=42 (Sun Apr 26 04:37:35 AM CDT 2026) + DFA-shallow: 0.3458 + DFA-frozen: 0.3435 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=12 s=123 (Sun Apr 26 05:17:07 AM CDT 2026) + DFA-shallow: 0.3524 + DFA-frozen: 0.3526 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) + d=512 L=12 s=456 (Sun Apr 26 05:56:51 AM CDT 2026) + DFA-shallow: 0.3516 + DFA-frozen: 0.3513 + +Compare to trainable 4-block ResMLP (3-seed): BP=0.6147 100ep / 0.585 30ep, DFA=0.306 100ep / 0.301 30ep + +Interpretation: + If DFA-frozen ≈ DFA-trainable: blocks are passengers, walk-back parallels ViT + If DFA-frozen << DFA-trainable: ResMLP DFA actually trains the blocks (interesting contrast with ViT) +=== FROZEN BASELINES DONE (Sun Apr 26 06:36:08 AM CDT 2026) === |
