diff options
| -rw-r--r-- | experiments/cifar_resmlp.py | 142 | ||||
| -rw-r--r-- | results/fa_core_experiments.log | 220 | ||||
| -rw-r--r-- | results/fa_depth_scan_d512_L12/results_cifar10.json | 499 | ||||
| -rw-r--r-- | results/fa_depth_scan_d512_L2/results_cifar10.json | 389 | ||||
| -rw-r--r-- | results/fa_depth_scan_d512_L4/results_cifar10.json | 411 | ||||
| -rw-r--r-- | results/fa_depth_scan_d512_L6/results_cifar10.json | 433 | ||||
| -rw-r--r-- | results/fa_depth_scan_d512_L8/results_cifar10.json | 455 | ||||
| -rw-r--r-- | results/fa_depth_sweep_redo.log | 90 | ||||
| -rw-r--r-- | results/fa_early_ckpts/results_cifar10.json | 324 | ||||
| -rw-r--r-- | results/fa_extension_experiments.log | 159 | ||||
| -rw-r--r-- | results/fa_main_audit/results_cifar10.json | 1179 | ||||
| -rw-r--r-- | results/fa_no_penalty_30ep/results_cifar10.json | 549 | ||||
| -rw-r--r-- | results/fa_penalty_30ep/results_cifar10.json | 549 | ||||
| -rw-r--r-- | results/fa_penalty_lam1e-4_30ep/results_cifar10.json | 549 | ||||
| -rw-r--r-- | results/fa_random_targets_s42/results_cifar10.json | 411 | ||||
| -rw-r--r-- | results/fa_smoke_test/results_cifar10.json | 120 |
16 files changed, 6477 insertions, 2 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 7aba671..05a355d 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -229,6 +229,116 @@ def train_dfa(model, train_loader, test_loader, device, args): # ============================================================================= +# Vanilla FA (Lillicrap 2016) +# ============================================================================= +def train_fa(model, train_loader, test_loader, device, args): + """ + Vanilla Feedback Alignment (Lillicrap et al. 2016). + Unlike DFA (which projects output error directly to each layer via + a_l = B_l^T @ e_T), FA propagates credit sequentially backward through + the block stack using fixed random d×d feedback matrices: + a_L = exact gradient at h_L through out_head + out_ln + a_l = B_l @ a_{l+1} (random d×d replaces block Jacobian transpose) + Each block is updated with the same local loss as DFA: <f_l(h_l), a_l>. + """ + d = model.d_hidden + num_classes = args.num_classes + L = model.num_blocks + + # Fixed random feedback matrices: d × d (one per block). + # These replace the transpose of the block Jacobian dF_l/dh_l in the + # backward pass. Contrast with DFA's B_l which are d × num_classes. + Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)] + + # Same optimizer structure as DFA + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) + batch = x.size(0) + + # Forward pass + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + + # 1. Update output head (exact CE gradient, h_L detached) + hL_det = hiddens[-1].detach().requires_grad_(True) + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Exact gradient at h_L — FA's starting credit signal + a_credit = hL_det.grad.detach() # (batch, d) + + # 2. Update each block with FA credit (backward sequential) + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + # Normalize credit + rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_credit / rms + # Local surrogate (same form as DFA) + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + block_opts[l].step() + + # Propagate credit backward: FA replaces block Jacobian^T with B_l + a_credit = (a_credit @ Bs[l]).detach() + + # 3. Update embedding with FA credit at h_0 + rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_credit / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + + train_loss = total_loss / total + train_acc = correct / total + test_acc = evaluate(model, test_loader, device) + log['train_loss'].append(train_loss) + log['train_acc'].append(train_acc) + log['test_acc'].append(test_acc) + if epoch % 10 == 0 or epoch == 1: + print(f" [FA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}") + + return log, Bs + + +# ============================================================================= # State Bridge # ============================================================================= def train_state_bridge(model, train_loader, test_loader, device, args): @@ -621,6 +731,18 @@ def compute_diagnostics(model, method_name, test_loader, device, args, 'bp_grad_norms_per_layer': bp_grad_norms_per_layer, } + # Pre-compute FA credits if needed (sequential backward from exact h_L gradient) + _fa_credits = None + if method_name == 'fa' and dfa_Bs is not None: + hL_req = hiddens[L].detach().requires_grad_(True) + logits_fa = model.out_head(model.out_ln(hL_req)) + loss_fa = F.cross_entropy(logits_fa, y, reduction='sum') + _fa_a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach() + _fa_credits = [None] * L + _fa_credits[L - 1] = _fa_a_L + for ll in range(L - 2, -1, -1): + _fa_credits[ll] = (_fa_credits[ll + 1] @ dfa_Bs[ll + 1]).detach() + for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) @@ -630,6 +752,8 @@ def compute_diagnostics(model, method_name, test_loader, device, args, a_l = bp_grads[l] elif method_name == 'dfa': a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'fa': + a_l = _fa_credits[l] elif method_name == 'state_bridge': h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_predictor(h_l_req, t_l, s) @@ -720,6 +844,20 @@ def run_experiment(args): seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") + # ---- FA (vanilla Feedback Alignment, Lillicrap 2016) ---- + if 'fa' in methods_to_run: + print("\n--- FA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_fa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_fa = {n: p.clone().detach() for n, p in model_fa.named_parameters()} + fa_log, fa_Bs = train_fa(model_fa, train_loader, test_loader, device, args) + fa_diag = compute_diagnostics(model_fa, 'fa', test_loader, device, args, dfa_Bs=fa_Bs) + fa_drift = feature_drift(init_fa, {n: p.detach() for n, p in model_fa.named_parameters()}) + seed_results['fa'] = {'log': fa_log, 'diagnostics': fa_diag, 'drift': fa_drift} + print(f" Final test acc: {fa_log['test_acc'][-1]:.4f}") + # ---- State Bridge ---- if 'state_bridge' in methods_to_run: print("\n--- State Bridge ---") @@ -793,8 +931,8 @@ def main(): parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--output_dir', type=str, default='results/cifar10') - parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'state_bridge', 'credit_bridge'], - help='Subset of methods to run.') + parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge'], + help='Subset of methods to run. fa = vanilla Feedback Alignment (Lillicrap 2016).') parser.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') parser.add_argument('--penalty_lam', type=float, default=0.0, diff --git a/results/fa_core_experiments.log b/results/fa_core_experiments.log new file mode 100644 index 0000000..45e92c2 --- /dev/null +++ b/results/fa_core_experiments.log @@ -0,0 +1,220 @@ +========================================== +FA CORE EXPERIMENTS (A-H) +========================================== +Start: Wed Apr 22 09:42:02 PM CDT 2026 + +=== A: FA main audit (100ep, 3 seeds) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0491, train=0.2436, test=0.2789 + [FA] Epoch 10: loss=1.8849, train=0.3172, test=0.3416 + [FA] Epoch 20: loss=1.8524, train=0.3333, test=0.3609 + [FA] Epoch 30: loss=1.8295, train=0.3444, test=0.3714 + [FA] Epoch 40: loss=1.8168, train=0.3508, test=0.3823 + [FA] Epoch 50: loss=1.8001, train=0.3588, test=0.3842 + [FA] Epoch 60: loss=1.7960, train=0.3575, test=0.3806 + [FA] Epoch 70: loss=1.7859, train=0.3609, test=0.3892 + [FA] Epoch 80: loss=1.7819, train=0.3658, test=0.3919 + [FA] Epoch 90: loss=1.7770, train=0.3677, test=0.3913 + [FA] Epoch 100: loss=1.7776, train=0.3685, test=0.3929 + Final test acc: 0.3929 + +============================================================ +Seed 123 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0394, train=0.2504, test=0.2905 + [FA] Epoch 10: loss=1.8597, train=0.3285, test=0.3514 + [FA] Epoch 20: loss=1.8128, train=0.3518, test=0.3801 + [FA] Epoch 30: loss=1.7718, train=0.3646, test=0.3944 + [FA] Epoch 40: loss=1.7571, train=0.3705, test=0.3937 + [FA] Epoch 50: loss=1.7417, train=0.3765, test=0.4046 + [FA] Epoch 60: loss=1.7407, train=0.3791, test=0.4087 + [FA] Epoch 70: loss=1.7346, train=0.3839, test=0.4048 + [FA] Epoch 80: loss=1.7344, train=0.3828, test=0.4070 + [FA] Epoch 90: loss=1.7341, train=0.3838, test=0.4093 + [FA] Epoch 100: loss=1.7313, train=0.3814, test=0.4099 + Final test acc: 0.4099 + +============================================================ +Seed 456 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0641, train=0.2401, test=0.2713 + [FA] Epoch 10: loss=1.8796, train=0.3171, test=0.3422 + [FA] Epoch 20: loss=1.8405, train=0.3365, test=0.3666 + [FA] Epoch 30: loss=1.8124, train=0.3511, test=0.3808 + [FA] Epoch 40: loss=1.7906, train=0.3597, test=0.3872 + [FA] Epoch 50: loss=1.7753, train=0.3662, test=0.3900 + [FA] Epoch 60: loss=1.7704, train=0.3685, test=0.3957 + [FA] Epoch 70: loss=1.7610, train=0.3737, test=0.3968 + [FA] Epoch 80: loss=1.7538, train=0.3723, test=0.3999 + [FA] Epoch 90: loss=1.7540, train=0.3762, test=0.3992 + [FA] Epoch 100: loss=1.7500, train=0.3750, test=0.3996 + Final test acc: 0.3996 + +All results saved to results/fa_main_audit/results_cifar10.json + +=== B: FA+penalty lam=1e-2 (30ep, 3 seeds) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0011, train=0.2748, test=0.3237 + [FA] Epoch 10: loss=1.9015, train=0.3259, test=0.3462 + [FA] Epoch 20: loss=1.8762, train=0.3376, test=0.3620 + [FA] Epoch 30: loss=1.8683, train=0.3472, test=0.3713 + Final test acc: 0.3713 + +============================================================ +Seed 123 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=1.9927, train=0.2807, test=0.3339 + [FA] Epoch 10: loss=1.9047, train=0.3250, test=0.3513 + [FA] Epoch 20: loss=1.8867, train=0.3368, test=0.3600 + [FA] Epoch 30: loss=1.8784, train=0.3428, test=0.3660 + Final test acc: 0.3660 + +============================================================ +Seed 456 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0075, train=0.2735, test=0.3312 + [FA] Epoch 10: loss=1.9005, train=0.3261, test=0.3561 + [FA] Epoch 20: loss=1.8808, train=0.3402, test=0.3682 + [FA] Epoch 30: loss=1.8698, train=0.3499, test=0.3695 + Final test acc: 0.3695 + +All results saved to results/fa_penalty_30ep/results_cifar10.json + +=== C: FA no-pen 30ep (3 seeds) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0491, train=0.2436, test=0.2789 + [FA] Epoch 10: loss=1.8842, train=0.3176, test=0.3427 + [FA] Epoch 20: loss=1.8545, train=0.3324, test=0.3630 + [FA] Epoch 30: loss=1.8426, train=0.3414, test=0.3655 + Final test acc: 0.3655 + +============================================================ +Seed 123 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0394, train=0.2504, test=0.2905 + [FA] Epoch 10: loss=1.8601, train=0.3275, test=0.3514 + [FA] Epoch 20: loss=1.8234, train=0.3491, test=0.3753 + [FA] Epoch 30: loss=1.8114, train=0.3527, test=0.3793 + Final test acc: 0.3793 + +============================================================ +Seed 456 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0641, train=0.2401, test=0.2713 + [FA] Epoch 10: loss=1.8808, train=0.3167, test=0.3426 + [FA] Epoch 20: loss=1.8417, train=0.3367, test=0.3677 + [FA] Epoch 30: loss=1.8354, train=0.3412, test=0.3699 + Final test acc: 0.3699 + +All results saved to results/fa_no_penalty_30ep/results_cifar10.json + +=== D: SKIPPED (frozen blocks needs separate script) === + +=== E: No-terminal-LN FA (100ep, 3 seeds) === +usage: snapshot_evolution_no_outln.py [-h] [--output_dir OUTPUT_DIR] + [--epochs EPOCHS] [--lr LR] [--wd WD] + [--seed SEED] [--depth DEPTH] + [--d_hidden D_HIDDEN] +snapshot_evolution_no_outln.py: error: unrecognized arguments: --method fa + [E] seed 42: snapshot_evolution_no_outln.py may not support FA yet — skipping +usage: snapshot_evolution_no_outln.py [-h] [--output_dir OUTPUT_DIR] + [--epochs EPOCHS] [--lr LR] [--wd WD] + [--seed SEED] [--depth DEPTH] + [--d_hidden D_HIDDEN] +snapshot_evolution_no_outln.py: error: unrecognized arguments: --method fa + [E] seed 123: snapshot_evolution_no_outln.py may not support FA yet — skipping +usage: snapshot_evolution_no_outln.py [-h] [--output_dir OUTPUT_DIR] + [--epochs EPOCHS] [--lr LR] [--wd WD] + [--seed SEED] [--depth DEPTH] + [--d_hidden D_HIDDEN] +snapshot_evolution_no_outln.py: error: unrecognized arguments: --method fa + [E] seed 456: snapshot_evolution_no_outln.py may not support FA yet — skipping + +=== F: Random-target FA (100ep, s42) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.3211, train=0.0992, test=0.0948 + [FA] Epoch 10: loss=2.3113, train=0.1016, test=0.0894 + [FA] Epoch 20: loss=2.3071, train=0.1004, test=0.1045 + [FA] Epoch 30: loss=2.3050, train=0.0974, test=0.1226 + [FA] Epoch 40: loss=2.3039, train=0.1005, test=0.0837 + [FA] Epoch 50: loss=2.3034, train=0.1006, test=0.0945 + [FA] Epoch 60: loss=2.3032, train=0.0982, test=0.1195 + [FA] Epoch 70: loss=2.3029, train=0.0994, test=0.1104 + [FA] Epoch 80: loss=2.3027, train=0.1005, test=0.0927 + [FA] Epoch 90: loss=2.3026, train=0.1001, test=0.1170 + [FA] Epoch 100: loss=2.3026, train=0.0988, test=0.1200 + Final test acc: 0.1200 + +All results saved to results/fa_random_targets_s42/results_cifar10.json + +=== G: FA early checkpoints (5ep, 3 seeds) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0491, train=0.2436, test=0.2789 + Final test acc: 0.3219 + +============================================================ +Seed 123 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0394, train=0.2504, test=0.2905 + Final test acc: 0.3478 + +============================================================ +Seed 456 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0641, train=0.2401, test=0.2713 + Final test acc: 0.3242 + +All results saved to results/fa_early_ckpts/results_cifar10.json + +=== H: FA trajectory — covered by A (log has per-epoch values) === + +========================================== +ALL CORE FA EXPERIMENTS DONE +End: Wed Apr 22 10:21:51 PM CDT 2026 +========================================== diff --git a/results/fa_depth_scan_d512_L12/results_cifar10.json b/results/fa_depth_scan_d512_L12/results_cifar10.json new file mode 100644 index 0000000..15a5792 --- /dev/null +++ b/results/fa_depth_scan_d512_L12/results_cifar10.json @@ -0,0 +1,499 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.0427230606079103, + 1.9547925742340089, + 1.9269662255096436, + 1.9024112060546876, + 1.88389292137146, + 1.874641305809021, + 1.8671076870727539, + 1.8585125997543335, + 1.8507517318725586, + 1.8509845052719116, + 1.8455515924835204, + 1.8402700800323486, + 1.8322524318695068, + 1.8287768964385986, + 1.824175419845581, + 1.8240570126342774, + 1.817736747970581, + 1.8165113877105712, + 1.8092418829345702, + 1.8068542711639404, + 1.8066692431640625, + 1.8037006209945678, + 1.7959260940551758, + 1.797553392868042, + 1.7958658642196654, + 1.7940396698379517, + 1.7932731056976319, + 1.7874454331207275, + 1.7880735418319702, + 1.7836660583496093, + 1.7832960818862915, + 1.7804118716812134, + 1.7788523398208618, + 1.772011728515625, + 1.7708375668716432, + 1.7682569339370728, + 1.7717147860336304, + 1.7703663568115235, + 1.7637487329864503, + 1.7676782482910156, + 1.7649735053253175, + 1.7634011583709717, + 1.7645813860702515, + 1.7605182119369507, + 1.7599670697021483, + 1.7621903922271729, + 1.756671106300354, + 1.7577124716949464, + 1.7564089126205444, + 1.7521357986450194, + 1.7525928847885133, + 1.7497260301971436, + 1.7528332720184325, + 1.7475884844207763, + 1.7469696446990968, + 1.7445055263519287, + 1.752090754776001, + 1.7456406805038451, + 1.747734097251892, + 1.7441029932403564, + 1.743490291824341, + 1.7414577841949463, + 1.742126556777954, + 1.7391094409942627, + 1.7401588076019288, + 1.7370625415420533, + 1.7388630017471314, + 1.7353392750930785, + 1.736816965751648, + 1.7355695288085937, + 1.73539924369812, + 1.7323005257797242, + 1.7321935805511475, + 1.7312652449798585, + 1.7298928540420533, + 1.73121183719635, + 1.7288608194351196, + 1.7294834777069092, + 1.726299252166748, + 1.73074989528656, + 1.7293904189682008, + 1.7276834594345092, + 1.7300372393035888, + 1.731684062461853, + 1.7275019702529908, + 1.729561651916504, + 1.7284230736923218, + 1.7257295697021484, + 1.7226565048217772, + 1.726801979408264, + 1.7250631372833252, + 1.7279036038970947, + 1.726708238143921, + 1.7242388882064819, + 1.7248469045257568, + 1.7260398196411133, + 1.721121542892456, + 1.7220706603240967, + 1.7268135174942016, + 1.729515742111206 + ], + "train_acc": [ + 0.24062, + 0.28346, + 0.2951, + 0.30692, + 0.31778, + 0.3211, + 0.32388, + 0.32854, + 0.33456, + 0.33332, + 0.33668, + 0.33984, + 0.34108, + 0.34376, + 0.346, + 0.34636, + 0.35004, + 0.34922, + 0.35104, + 0.35196, + 0.35292, + 0.35532, + 0.3572, + 0.35614, + 0.3578, + 0.35934, + 0.35694, + 0.35918, + 0.36018, + 0.35928, + 0.3622, + 0.362, + 0.36236, + 0.36548, + 0.36778, + 0.36536, + 0.36884, + 0.367, + 0.36852, + 0.36842, + 0.36944, + 0.36784, + 0.3697, + 0.37156, + 0.37064, + 0.3725, + 0.37072, + 0.37062, + 0.37082, + 0.37298, + 0.37268, + 0.37554, + 0.3732, + 0.3752, + 0.37724, + 0.37652, + 0.37506, + 0.37708, + 0.37712, + 0.37668, + 0.37814, + 0.37944, + 0.37878, + 0.37886, + 0.38042, + 0.3778, + 0.38022, + 0.37988, + 0.38096, + 0.3807, + 0.3811, + 0.38084, + 0.37994, + 0.37914, + 0.38104, + 0.38152, + 0.38376, + 0.38158, + 0.38472, + 0.38244, + 0.38066, + 0.38392, + 0.3839, + 0.38222, + 0.38382, + 0.38384, + 0.38472, + 0.385, + 0.38792, + 0.3839, + 0.38516, + 0.3858, + 0.38262, + 0.38758, + 0.38396, + 0.38268, + 0.38628, + 0.38618, + 0.3846, + 0.38336 + ], + "test_acc": [ + 0.2963, + 0.3132, + 0.3406, + 0.3377, + 0.3353, + 0.3407, + 0.3367, + 0.3657, + 0.3618, + 0.3712, + 0.3519, + 0.3722, + 0.371, + 0.379, + 0.3666, + 0.372, + 0.3674, + 0.3754, + 0.3733, + 0.3747, + 0.3712, + 0.3772, + 0.3792, + 0.3804, + 0.3742, + 0.3852, + 0.3864, + 0.3907, + 0.3861, + 0.3827, + 0.3852, + 0.3833, + 0.3885, + 0.3865, + 0.3933, + 0.391, + 0.3837, + 0.3915, + 0.4, + 0.4088, + 0.3956, + 0.3878, + 0.407, + 0.3928, + 0.3983, + 0.4042, + 0.4064, + 0.4019, + 0.3975, + 0.3905, + 0.399, + 0.3953, + 0.3981, + 0.3986, + 0.399, + 0.3963, + 0.4013, + 0.3988, + 0.401, + 0.4042, + 0.3978, + 0.399, + 0.4035, + 0.4022, + 0.4052, + 0.4033, + 0.4035, + 0.4041, + 0.4026, + 0.4046, + 0.4012, + 0.3996, + 0.4052, + 0.4034, + 0.3962, + 0.4036, + 0.3996, + 0.4008, + 0.4053, + 0.4037, + 0.4033, + 0.4058, + 0.408, + 0.4066, + 0.4017, + 0.4038, + 0.4015, + 0.4035, + 0.4052, + 0.4014, + 0.4046, + 0.4027, + 0.4046, + 0.4036, + 0.4035, + 0.4042, + 0.4039, + 0.4033, + 0.4034, + 0.4035 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.04570477455854416, + 0.10783404111862183, + -0.03488321602344513, + -0.0594203844666481, + -0.04577865079045296, + -0.024857502430677414, + -0.03406952694058418, + 0.030555542558431625, + -0.00125521095469594, + 0.010614164173603058, + 0.05838795006275177, + 0.9949829578399658 + ], + "perturbation_rho": [ + -0.010346438735723495, + 0.04423713684082031, + 0.01146540604531765, + 0.03663264587521553, + -0.011911284178495407, + 0.0016591008752584457, + -0.02494950033724308, + 0.02218942530453205, + 0.0039163315668702126, + 0.07129630446434021, + 0.00445366557687521, + -0.01873880997300148 + ], + "nudging": { + "0.001": [ + -2.829881850630045e-06, + -3.539607860147953e-07, + 7.35744833946228e-08, + 8.218921720981598e-08, + 4.21423465013504e-08, + 2.7241185307502747e-08, + 7.008202373981476e-08, + -3.632158041000366e-08, + -3.958120942115784e-09, + -3.073364496231079e-08, + -9.918585419654846e-08, + -1.0794028639793396e-06 + ], + "0.003": [ + -8.66112532094121e-06, + -9.669456630945206e-07, + 1.387670636177063e-07, + 2.7869828045368195e-07, + 1.7974525690078735e-07, + 1.0058283805847168e-07, + 1.448206603527069e-07, + -2.1746382117271423e-07, + -1.0011717677116394e-08, + -1.1711381375789642e-07, + -2.377200871706009e-07, + -3.898283466696739e-06 + ], + "0.01": [ + -2.8825539629906416e-05, + -3.2445532269775867e-06, + 4.411558620631695e-07, + 9.136274456977844e-07, + 6.3673360273242e-07, + 3.696768544614315e-07, + 4.895846359431744e-07, + -5.699694156646729e-07, + -7.869675755500793e-08, + -1.8370337784290314e-07, + -8.582137525081635e-07, + -1.4071993064135313e-05 + ] + }, + "hidden_norms_per_layer": [ + 7405.54296875, + 126799.640625, + 722124.125, + 1197351.125, + 1311770.625, + 1445317.875, + 1575290.0, + 1646774.0, + 1672649.625, + 1692752.625, + 1726090.75, + 1747446.5, + 1108917.0 + ], + "bp_grad_norms_per_layer": [ + 2.2136026018415578e-05, + 1.0887044936680468e-06, + 6.514224537568225e-07, + 6.450830483117898e-07, + 6.397550009751285e-07, + 6.399715175575693e-07, + 6.393033231688605e-07, + 6.382227297763166e-07, + 6.320739203147241e-07, + 6.241282335395226e-07, + 6.161063197396288e-07, + 6.117401767369302e-07, + 5.839314667355211e-07 + ] + }, + "drift": { + "embed.weight": 52.2471866434948, + "embed.bias": 18.97250327390356, + "blocks.0.ln.weight": 1.2566883045780834, + "blocks.0.w1.weight": 17.12006786604635, + "blocks.0.w1.bias": 14.342858083073688, + "blocks.0.w2.weight": 63.004850525998116, + "blocks.1.ln.weight": 1.0284467630081824, + "blocks.1.w1.weight": 21.225831949490352, + "blocks.1.w1.bias": 18.554260026881323, + "blocks.1.w2.weight": 42.67940842745791, + "blocks.2.ln.weight": 0.5431388174757359, + "blocks.2.w1.weight": 20.036899880398536, + "blocks.2.w1.bias": 21.66446091790098, + "blocks.2.w2.weight": 26.571695423395656, + "blocks.3.ln.weight": 0.4791053623183397, + "blocks.3.w1.weight": 17.857825652507895, + "blocks.3.w1.bias": 19.196407269205075, + "blocks.3.w2.weight": 24.453272281036483, + "blocks.4.ln.weight": 0.40179147711472446, + "blocks.4.w1.weight": 16.142007729189494, + "blocks.4.w1.bias": 18.38189338630125, + "blocks.4.w2.weight": 22.251118320035413, + "blocks.5.ln.weight": 0.3999388280630244, + "blocks.5.w1.weight": 16.08867610628456, + "blocks.5.w1.bias": 18.228301269113516, + "blocks.5.w2.weight": 22.282579138688718, + "blocks.6.ln.weight": 0.5139340433841122, + "blocks.6.w1.weight": 16.498928955337856, + "blocks.6.w1.bias": 17.341930385247757, + "blocks.6.w2.weight": 44.47264444577482, + "blocks.7.ln.weight": 0.5190564871460961, + "blocks.7.w1.weight": 15.921519106035, + "blocks.7.w1.bias": 13.554678640598476, + "blocks.7.w2.weight": 55.375944465972985, + "blocks.8.ln.weight": 0.5783217006624557, + "blocks.8.w1.weight": 15.308415657722076, + "blocks.8.w1.bias": 12.359797369765815, + "blocks.8.w2.weight": 60.55354023990087, + "blocks.9.ln.weight": 0.5265177438656736, + "blocks.9.w1.weight": 14.699539264919315, + "blocks.9.w1.bias": 12.199485985244364, + "blocks.9.w2.weight": 53.496852348666465, + "blocks.10.ln.weight": 0.5541756864436261, + "blocks.10.w1.weight": 14.732920668123558, + "blocks.10.w1.bias": 10.868519121164393, + "blocks.10.w2.weight": 61.24451839129346, + "blocks.11.ln.weight": 0.6274670342096703, + "blocks.11.w1.weight": 18.31665424290367, + "blocks.11.w1.bias": 18.6607903814988, + "blocks.11.w2.weight": 57.15975509546593, + "out_ln.weight": 0.317930900746446, + "out_head.weight": 6.3670164706750825, + "out_head.bias": 0.940628348025828 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 12, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_depth_scan_d512", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_depth_scan_d512_L2/results_cifar10.json b/results/fa_depth_scan_d512_L2/results_cifar10.json new file mode 100644 index 0000000..0c41380 --- /dev/null +++ b/results/fa_depth_scan_d512_L2/results_cifar10.json @@ -0,0 +1,389 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.061217526626587, + 1.9565438193130493, + 1.9224825448608398, + 1.9001438703155518, + 1.8821109796524047, + 1.8704035317230225, + 1.857736103477478, + 1.844048505897522, + 1.8342980657958985, + 1.8290129190063475, + 1.8224646506118773, + 1.8214716808319091, + 1.8136491051483155, + 1.8157705508041382, + 1.8199064575195312, + 1.816836177368164, + 1.8065774325180053, + 1.8125089934921266, + 1.8090396573638916, + 1.81023629032135, + 1.8082425650787353, + 1.8127569119644165, + 1.808132509727478, + 1.8062178915405274, + 1.8039376613998412, + 1.8006342431259155, + 1.797272424888611, + 1.795366856956482, + 1.795199264755249, + 1.796318228149414, + 1.7931722381973267, + 1.790380922317505, + 1.787453678970337, + 1.7863555522918702, + 1.7848892670059204, + 1.782602568359375, + 1.7844260739135742, + 1.7822325122833251, + 1.7821675496673584, + 1.7775094485092162, + 1.775276083984375, + 1.772415744934082, + 1.771381389389038, + 1.7729799353790283, + 1.7665789907073974, + 1.7650687371063232, + 1.7673343083953856, + 1.7628559505844117, + 1.7635160286712646, + 1.7610479801177978, + 1.757714970741272, + 1.7587621269989013, + 1.75955526802063, + 1.7605536008453369, + 1.7632213006973267, + 1.7597938927459718, + 1.764720836868286, + 1.7602048022460937, + 1.7601288507843018, + 1.7592250678253174, + 1.763182290611267, + 1.7588830117034913, + 1.7620349166107179, + 1.762400091934204, + 1.7608643816375733, + 1.7610327990341186, + 1.7634911743164063, + 1.761160830154419, + 1.7584447234344482, + 1.758816735496521, + 1.7580244549560546, + 1.758498070678711, + 1.757904008178711, + 1.7605663097381592, + 1.7611833373260497, + 1.7573552610778809, + 1.7571686492538452, + 1.7577348608016967, + 1.7563239881134034, + 1.7563734949493408, + 1.757859500465393, + 1.7567143533325196, + 1.7577434346771241, + 1.75511939453125, + 1.7555543961334228, + 1.756235351524353, + 1.7555060604095458, + 1.7561510982894897, + 1.7538081964874268, + 1.7542751317977905, + 1.7549605862426758, + 1.7509819832611084, + 1.755103257446289, + 1.7511512590789795, + 1.7497140390396118, + 1.755158496170044, + 1.7556276675033569, + 1.7520254708862304, + 1.7512920989227294, + 1.7526899521255492 + ], + "train_acc": [ + 0.24758, + 0.29054, + 0.30476, + 0.31646, + 0.3238, + 0.33244, + 0.334, + 0.33794, + 0.34238, + 0.34354, + 0.3469, + 0.34778, + 0.34922, + 0.34738, + 0.34754, + 0.3477, + 0.3522, + 0.35128, + 0.3501, + 0.34894, + 0.3502, + 0.34714, + 0.35396, + 0.35312, + 0.35266, + 0.35432, + 0.35642, + 0.35256, + 0.35418, + 0.35464, + 0.35604, + 0.35646, + 0.35824, + 0.35708, + 0.35856, + 0.35934, + 0.35824, + 0.35668, + 0.36006, + 0.36048, + 0.36242, + 0.3658, + 0.36572, + 0.36512, + 0.36582, + 0.36634, + 0.36442, + 0.36774, + 0.3675, + 0.36846, + 0.36932, + 0.36868, + 0.3692, + 0.36584, + 0.36652, + 0.3708, + 0.36522, + 0.36784, + 0.36796, + 0.37036, + 0.36874, + 0.37152, + 0.37144, + 0.37078, + 0.36992, + 0.37114, + 0.3711, + 0.37028, + 0.37068, + 0.37472, + 0.3726, + 0.372, + 0.37436, + 0.37286, + 0.37184, + 0.37572, + 0.37322, + 0.37538, + 0.37558, + 0.37514, + 0.37416, + 0.37244, + 0.37488, + 0.3776, + 0.37426, + 0.37688, + 0.37648, + 0.3746, + 0.37672, + 0.37686, + 0.3762, + 0.37536, + 0.3752, + 0.37636, + 0.37886, + 0.37486, + 0.3733, + 0.37878, + 0.3779, + 0.37676 + ], + "test_acc": [ + 0.3028, + 0.3196, + 0.3402, + 0.3524, + 0.3584, + 0.3591, + 0.3566, + 0.3543, + 0.3712, + 0.3705, + 0.3579, + 0.3373, + 0.3627, + 0.3557, + 0.3414, + 0.3688, + 0.3572, + 0.3632, + 0.354, + 0.3634, + 0.362, + 0.3549, + 0.363, + 0.3531, + 0.345, + 0.3459, + 0.3517, + 0.3421, + 0.3294, + 0.3398, + 0.3294, + 0.3409, + 0.3476, + 0.3318, + 0.3477, + 0.3302, + 0.3283, + 0.3309, + 0.3337, + 0.3497, + 0.3314, + 0.3292, + 0.3374, + 0.3361, + 0.3335, + 0.3461, + 0.3254, + 0.3366, + 0.3353, + 0.3288, + 0.3467, + 0.3374, + 0.3445, + 0.3415, + 0.3358, + 0.3474, + 0.3382, + 0.333, + 0.3356, + 0.3376, + 0.3356, + 0.3418, + 0.3358, + 0.3446, + 0.35, + 0.3412, + 0.3452, + 0.3429, + 0.3438, + 0.3421, + 0.3477, + 0.3497, + 0.3499, + 0.3473, + 0.3451, + 0.3504, + 0.3474, + 0.3503, + 0.344, + 0.3497, + 0.3487, + 0.3488, + 0.3521, + 0.3509, + 0.3455, + 0.3511, + 0.3469, + 0.3474, + 0.353, + 0.3472, + 0.3472, + 0.3501, + 0.3493, + 0.3492, + 0.349, + 0.3496, + 0.3496, + 0.3494, + 0.3494, + 0.3495 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.03441809490323067, + 0.955886721611023 + ], + "perturbation_rho": [ + 0.04754525423049927, + 0.05055658146739006 + ], + "nudging": { + "0.001": [ + -5.270587280392647e-06, + -7.06920400261879e-06 + ], + "0.003": [ + -1.57852191478014e-05, + -2.1289335563778877e-05 + ], + "0.01": [ + -5.261087790131569e-05, + -7.100868970155716e-05 + ] + }, + "hidden_norms_per_layer": [ + 4686.95556640625, + 94784.5546875, + 138362.359375 + ], + "bp_grad_norms_per_layer": [ + 1.9998093193862587e-05, + 1.4634815670433454e-06, + 1.188113401440205e-06 + ] + }, + "drift": { + "embed.weight": 28.586715545463985, + "embed.bias": 19.716378864552336, + "blocks.0.ln.weight": 1.4738119287890208, + "blocks.0.w1.weight": 26.045392429453724, + "blocks.0.w1.bias": 15.939058582542359, + "blocks.0.w2.weight": 59.12906119267325, + "blocks.1.ln.weight": 1.2124303196678357, + "blocks.1.w1.weight": 16.696964117465345, + "blocks.1.w1.bias": 8.535990814749171, + "blocks.1.w2.weight": 37.85731940298285, + "out_ln.weight": 0.5125606883018807, + "out_head.weight": 3.721893823331011, + "out_head.bias": 11.592811594248388 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 2, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_depth_scan_d512_L2", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_depth_scan_d512_L4/results_cifar10.json b/results/fa_depth_scan_d512_L4/results_cifar10.json new file mode 100644 index 0000000..87a1f92 --- /dev/null +++ b/results/fa_depth_scan_d512_L4/results_cifar10.json @@ -0,0 +1,411 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.0300785620117185, + 1.9446230431365967, + 1.923573359298706, + 1.9025392212677001, + 1.8942920899963378, + 1.8785306893157958, + 1.8725497973251344, + 1.8588989192962646, + 1.8521492172241212, + 1.8486995751571655, + 1.8387074700546264, + 1.8285302072906495, + 1.8229916802597046, + 1.8140362118911744, + 1.812324379043579, + 1.8051523593902588, + 1.8015315892791748, + 1.7916892041015624, + 1.7894653066253663, + 1.7864193521118163, + 1.7835033431243896, + 1.7791855911636352, + 1.7703584003448487, + 1.7708648685073853, + 1.7641783823394774, + 1.762984627342224, + 1.7599577627182006, + 1.7553836069107056, + 1.75658689201355, + 1.7510007577896118, + 1.7528975827026367, + 1.7468154891586303, + 1.7496728462982178, + 1.7414471523666382, + 1.7459090908050536, + 1.7412950765228272, + 1.7402052612304688, + 1.7362829974365235, + 1.734298306350708, + 1.7386894550323486, + 1.735369085960388, + 1.7248860055541992, + 1.7300375979232787, + 1.7265542502212525, + 1.72896113155365, + 1.724715986366272, + 1.723635640487671, + 1.7220598682022095, + 1.717545898704529, + 1.7209024740600587, + 1.7202868030548095, + 1.716822794113159, + 1.7148029877090454, + 1.7123187323760987, + 1.7099079937744142, + 1.7115138537597656, + 1.7089375814056396, + 1.710836911201477, + 1.7087473178863526, + 1.7052376666641236, + 1.7051094373321534, + 1.7055147225952147, + 1.7014220635986328, + 1.6999273017120362, + 1.6991235815811156, + 1.697499723777771, + 1.6986251516342163, + 1.6985239101409912, + 1.6973113412857055, + 1.694503396911621, + 1.6924839485549927, + 1.6907986538314819, + 1.6948804891204834, + 1.6908944356536866, + 1.694915372390747, + 1.692836050567627, + 1.690400951499939, + 1.68655061794281, + 1.6889203528213501, + 1.686811109046936, + 1.688070924682617, + 1.6837623129653931, + 1.684536149520874, + 1.6821202702713012, + 1.6863232943344115, + 1.6847787243652343, + 1.6850631763076782, + 1.6828089485931397, + 1.6857323441314698, + 1.683035474510193, + 1.6798774777221679, + 1.6814900134658815, + 1.6816223468017577, + 1.6811535089874268, + 1.681296463279724, + 1.6815302163696288, + 1.6837011975097655, + 1.6801897742843628, + 1.6809225379180908, + 1.6781013634490967 + ], + "train_acc": [ + 0.25312, + 0.29106, + 0.3008, + 0.31326, + 0.31492, + 0.3233, + 0.32324, + 0.33074, + 0.33364, + 0.33658, + 0.3413, + 0.34538, + 0.3485, + 0.34936, + 0.3499, + 0.3504, + 0.35564, + 0.35788, + 0.35942, + 0.36092, + 0.36132, + 0.36218, + 0.36636, + 0.36582, + 0.3677, + 0.3675, + 0.36926, + 0.37292, + 0.37044, + 0.37244, + 0.37438, + 0.37436, + 0.37378, + 0.37712, + 0.37622, + 0.37694, + 0.37552, + 0.37782, + 0.37806, + 0.37674, + 0.3797, + 0.38284, + 0.3805, + 0.38056, + 0.3796, + 0.38304, + 0.38512, + 0.38384, + 0.38492, + 0.3875, + 0.38534, + 0.38444, + 0.38538, + 0.38828, + 0.38914, + 0.3886, + 0.38882, + 0.38966, + 0.39026, + 0.39132, + 0.39252, + 0.38896, + 0.3926, + 0.39232, + 0.39498, + 0.394, + 0.39308, + 0.39394, + 0.39548, + 0.39628, + 0.39624, + 0.3982, + 0.39524, + 0.39714, + 0.39504, + 0.39736, + 0.39672, + 0.39934, + 0.397, + 0.40178, + 0.3973, + 0.4012, + 0.39838, + 0.40118, + 0.40006, + 0.40216, + 0.4005, + 0.39964, + 0.4008, + 0.40088, + 0.40064, + 0.40226, + 0.40106, + 0.40206, + 0.40272, + 0.4009, + 0.40232, + 0.40004, + 0.403, + 0.4021 + ], + "test_acc": [ + 0.2917, + 0.3201, + 0.3234, + 0.3265, + 0.3415, + 0.3452, + 0.341, + 0.3599, + 0.3605, + 0.3541, + 0.3715, + 0.3712, + 0.3656, + 0.3709, + 0.3852, + 0.3723, + 0.3799, + 0.3724, + 0.3741, + 0.3908, + 0.3881, + 0.3848, + 0.3868, + 0.3732, + 0.3985, + 0.3937, + 0.3954, + 0.3986, + 0.3997, + 0.399, + 0.3937, + 0.3925, + 0.4031, + 0.4036, + 0.4049, + 0.4084, + 0.4026, + 0.3988, + 0.3939, + 0.3946, + 0.4049, + 0.4057, + 0.4006, + 0.4033, + 0.402, + 0.4098, + 0.4052, + 0.4075, + 0.3987, + 0.4165, + 0.401, + 0.4107, + 0.3986, + 0.4164, + 0.4144, + 0.4063, + 0.4156, + 0.4166, + 0.4158, + 0.4173, + 0.4165, + 0.4143, + 0.4139, + 0.4156, + 0.4148, + 0.4128, + 0.4153, + 0.4186, + 0.4181, + 0.4137, + 0.4163, + 0.4148, + 0.418, + 0.421, + 0.4217, + 0.4192, + 0.4199, + 0.4217, + 0.4208, + 0.4219, + 0.4241, + 0.4226, + 0.4209, + 0.4193, + 0.4221, + 0.4205, + 0.4229, + 0.42, + 0.4224, + 0.425, + 0.4253, + 0.4246, + 0.4224, + 0.4247, + 0.4235, + 0.4238, + 0.4241, + 0.4242, + 0.4245, + 0.4244 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.007479430641978979, + 0.03817027807235718, + -0.1650964617729187, + 0.9953471422195435 + ], + "perturbation_rho": [ + -0.0640043318271637, + 0.045080721378326416, + -5.970895290374756e-05, + -0.07074473798274994 + ], + "nudging": { + "0.001": [ + 4.5995693653821945e-07, + -2.60770320892334e-07, + 3.650784492492676e-07, + -2.395769115537405e-06 + ], + "0.003": [ + 1.9447761587798595e-06, + -8.136266842484474e-07, + 1.4320830814540386e-06, + -8.26612813398242e-06 + ], + "0.01": [ + 6.193062290549278e-06, + -3.1341915018856525e-06, + 4.9326918087899685e-06, + -2.9165181331336498e-05 + ] + }, + "hidden_norms_per_layer": [ + 6685.263671875, + 69715.734375, + 615023.8125, + 1362937.75, + 466747.5 + ], + "bp_grad_norms_per_layer": [ + 3.627526166383177e-05, + 3.065474629693199e-06, + 1.1288288987998385e-06, + 1.124225377679977e-06, + 1.1187978543603094e-06 + ] + }, + "drift": { + "embed.weight": 41.806395160952334, + "embed.bias": 22.65316128534114, + "blocks.0.ln.weight": 1.0245623407955453, + "blocks.0.w1.weight": 14.934223802543457, + "blocks.0.w1.bias": 15.12866030040346, + "blocks.0.w2.weight": 53.87746375738443, + "blocks.1.ln.weight": 0.9157411148330525, + "blocks.1.w1.weight": 20.47192274795864, + "blocks.1.w1.bias": 17.83477261022895, + "blocks.1.w2.weight": 42.46486129463301, + "blocks.2.ln.weight": 0.609426201836372, + "blocks.2.w1.weight": 21.9031621576148, + "blocks.2.w1.bias": 25.87166178551529, + "blocks.2.w2.weight": 26.77134315272843, + "blocks.3.ln.weight": 0.6636105283840906, + "blocks.3.w1.weight": 21.84827836339499, + "blocks.3.w1.bias": 23.312782840496634, + "blocks.3.w2.weight": 37.76159247144327, + "out_ln.weight": 0.2876820859790667, + "out_head.weight": 5.782440518290608, + "out_head.bias": 1.2881523166930118 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 4, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_depth_scan_d512_L4", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_depth_scan_d512_L6/results_cifar10.json b/results/fa_depth_scan_d512_L6/results_cifar10.json new file mode 100644 index 0000000..f92f310 --- /dev/null +++ b/results/fa_depth_scan_d512_L6/results_cifar10.json @@ -0,0 +1,433 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.0374581846618653, + 1.9505271494293213, + 1.9189605682373048, + 1.8992357767105104, + 1.8916226636123656, + 1.8799470928955078, + 1.8715093560409546, + 1.8660212090301513, + 1.8614284765625, + 1.8615510483169555, + 1.858706633758545, + 1.8505436419296264, + 1.8469465209197997, + 1.84686426486969, + 1.8409993788909913, + 1.8422894380950927, + 1.8362944522094726, + 1.838251561012268, + 1.8342618326568603, + 1.8288902439117432, + 1.8288569342803955, + 1.8237011749649048, + 1.8190487328338623, + 1.8216550624847412, + 1.818442062072754, + 1.8108094896697997, + 1.8102704141998291, + 1.8097241515350342, + 1.8078967553710938, + 1.799188058128357, + 1.8009985346221924, + 1.8004494039535524, + 1.79110828830719, + 1.7962453186035157, + 1.7955271081924438, + 1.7904337732696534, + 1.788247091140747, + 1.7867638687133789, + 1.787053722076416, + 1.7837253421020507, + 1.7821894748687743, + 1.7816637888336182, + 1.7749454086685181, + 1.7757588800811768, + 1.7753548749542236, + 1.7709813816070556, + 1.7743957537078858, + 1.7711823197174072, + 1.7686557120132447, + 1.7698634561157227, + 1.76816160987854, + 1.7640094296646118, + 1.7590285652923583, + 1.759483937072754, + 1.7608595893096923, + 1.759725392189026, + 1.7576867937850953, + 1.7575671384429932, + 1.7513107781600952, + 1.7550161592102052, + 1.752431219062805, + 1.7487841637802124, + 1.7479721618652344, + 1.7504890463638305, + 1.7505805645751953, + 1.7466501081085206, + 1.7456124398422241, + 1.7472877283096313, + 1.740817502784729, + 1.7439257540893556, + 1.7465672635650635, + 1.744754091758728, + 1.7406636810684204, + 1.741213535194397, + 1.7423687073516845, + 1.742182956199646, + 1.742147672958374, + 1.7392435836029052, + 1.7430149713516236, + 1.7413112850952148, + 1.7365029236602783, + 1.739822388305664, + 1.7374979788208007, + 1.7375597472763062, + 1.7379079512786866, + 1.735270802078247, + 1.7385935315322876, + 1.7365304137420654, + 1.7351179248428346, + 1.7382376040649414, + 1.734423624343872, + 1.7308770125579833, + 1.7355197325897216, + 1.7343845403671265, + 1.734937107887268, + 1.7350866909408569, + 1.735432066116333, + 1.734759259376526, + 1.733669355697632, + 1.73626565574646 + ], + "train_acc": [ + 0.24744, + 0.28938, + 0.3027, + 0.31002, + 0.31422, + 0.32022, + 0.32548, + 0.32622, + 0.33078, + 0.32936, + 0.33574, + 0.33636, + 0.33724, + 0.33784, + 0.34, + 0.3413, + 0.34206, + 0.33976, + 0.3412, + 0.34588, + 0.34544, + 0.34566, + 0.34976, + 0.34872, + 0.35166, + 0.35328, + 0.3528, + 0.35174, + 0.35398, + 0.35788, + 0.35598, + 0.35664, + 0.35946, + 0.3579, + 0.35866, + 0.35922, + 0.36176, + 0.36192, + 0.36018, + 0.36312, + 0.36604, + 0.36232, + 0.36722, + 0.36558, + 0.36774, + 0.36944, + 0.367, + 0.3692, + 0.36964, + 0.37104, + 0.37188, + 0.37134, + 0.37254, + 0.37084, + 0.37208, + 0.3718, + 0.37112, + 0.3728, + 0.37378, + 0.37414, + 0.373, + 0.3766, + 0.37534, + 0.37602, + 0.37528, + 0.37748, + 0.37684, + 0.37644, + 0.3761, + 0.37702, + 0.37658, + 0.3781, + 0.37686, + 0.38036, + 0.3784, + 0.37902, + 0.3791, + 0.37944, + 0.37872, + 0.37962, + 0.3794, + 0.37758, + 0.37992, + 0.37976, + 0.38084, + 0.37792, + 0.37966, + 0.3806, + 0.38122, + 0.37914, + 0.38144, + 0.3834, + 0.38042, + 0.3804, + 0.38228, + 0.38292, + 0.38086, + 0.38258, + 0.38086, + 0.37848 + ], + "test_acc": [ + 0.2938, + 0.3207, + 0.3406, + 0.3392, + 0.3402, + 0.349, + 0.3619, + 0.3561, + 0.3742, + 0.3541, + 0.3632, + 0.3742, + 0.3679, + 0.3696, + 0.3734, + 0.3754, + 0.3854, + 0.3709, + 0.3747, + 0.3711, + 0.3773, + 0.3787, + 0.384, + 0.3688, + 0.386, + 0.3897, + 0.383, + 0.3852, + 0.3876, + 0.3857, + 0.3878, + 0.3886, + 0.3873, + 0.3929, + 0.3862, + 0.3864, + 0.3871, + 0.3947, + 0.3901, + 0.3942, + 0.3854, + 0.3841, + 0.3892, + 0.3876, + 0.3965, + 0.3924, + 0.382, + 0.3953, + 0.3896, + 0.3921, + 0.3975, + 0.3964, + 0.3916, + 0.3991, + 0.3928, + 0.4014, + 0.3993, + 0.4035, + 0.3844, + 0.3975, + 0.4034, + 0.4017, + 0.3952, + 0.3992, + 0.4025, + 0.4019, + 0.3993, + 0.3959, + 0.3993, + 0.4058, + 0.3945, + 0.4016, + 0.4055, + 0.4015, + 0.4036, + 0.4034, + 0.3976, + 0.4013, + 0.4024, + 0.4014, + 0.3967, + 0.398, + 0.3947, + 0.4056, + 0.4012, + 0.4001, + 0.402, + 0.4004, + 0.4031, + 0.4008, + 0.4025, + 0.402, + 0.4018, + 0.4026, + 0.4007, + 0.4012, + 0.4015, + 0.4021, + 0.4015, + 0.4014 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.018813904374837875, + 0.08222152292728424, + -0.07224605977535248, + -0.09462258964776993, + -0.09898842871189117, + 0.9958882927894592 + ], + "perturbation_rho": [ + -0.017015788704156876, + -0.0317268893122673, + -0.004754193127155304, + -0.006329755764454603, + 0.03243139386177063, + -0.012378348037600517 + ], + "nudging": { + "0.001": [ + -1.5720142982900143e-06, + -4.273606464266777e-07, + 7.217749953269958e-08, + 1.2828968465328217e-07, + 1.2386590242385864e-07, + -1.2737000361084938e-06 + ], + "0.003": [ + -4.96965367347002e-06, + -1.4296965673565865e-06, + 4.01865690946579e-07, + 4.814937710762024e-07, + 5.081528797745705e-07, + -4.469649866223335e-06 + ], + "0.01": [ + -1.6447564121335745e-05, + -4.997651558369398e-06, + 1.1826632544398308e-06, + 1.6157864592969418e-06, + 1.566135324537754e-06, + -1.586077269166708e-05 + ] + }, + "hidden_norms_per_layer": [ + 7316.48291015625, + 78755.234375, + 576686.625, + 1329546.125, + 1568668.375, + 1854290.875, + 1069435.5 + ], + "bp_grad_norms_per_layer": [ + 2.8450938771129586e-05, + 1.794567538127012e-06, + 6.647851478192024e-07, + 6.471778988270671e-07, + 6.503790359602135e-07, + 6.49630294446979e-07, + 6.414796303033654e-07 + ] + }, + "drift": { + "embed.weight": 48.85389529931473, + "embed.bias": 13.565561887190782, + "blocks.0.ln.weight": 1.0624637379045918, + "blocks.0.w1.weight": 16.554310656593483, + "blocks.0.w1.bias": 11.372855905617644, + "blocks.0.w2.weight": 53.87251555976206, + "blocks.1.ln.weight": 1.0837239573564923, + "blocks.1.w1.weight": 22.586400420670774, + "blocks.1.w1.bias": 15.377758862989662, + "blocks.1.w2.weight": 35.85773072344426, + "blocks.2.ln.weight": 0.9078778614654982, + "blocks.2.w1.weight": 23.522098095099405, + "blocks.2.w1.bias": 22.904554985601767, + "blocks.2.w2.weight": 42.07479881121452, + "blocks.3.ln.weight": 0.6724399561234441, + "blocks.3.w1.weight": 20.752863459155623, + "blocks.3.w1.bias": 20.37186903428861, + "blocks.3.w2.weight": 36.191664332701706, + "blocks.4.ln.weight": 0.6795979662627959, + "blocks.4.w1.weight": 20.607291873742884, + "blocks.4.w1.bias": 20.91900548314815, + "blocks.4.w2.weight": 42.61545291740059, + "blocks.5.ln.weight": 0.6835153473115325, + "blocks.5.w1.weight": 20.653168018914645, + "blocks.5.w1.bias": 21.158310771641304, + "blocks.5.w2.weight": 45.94200541763277, + "out_ln.weight": 0.34596348666722354, + "out_head.weight": 6.343616372534138, + "out_head.bias": 0.7709641239019656 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 6, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_depth_scan_d512_L6", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_depth_scan_d512_L8/results_cifar10.json b/results/fa_depth_scan_d512_L8/results_cifar10.json new file mode 100644 index 0000000..ae9217c --- /dev/null +++ b/results/fa_depth_scan_d512_L8/results_cifar10.json @@ -0,0 +1,455 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.043062113418579, + 1.95162189743042, + 1.920884204940796, + 1.9046970273590087, + 1.894838911705017, + 1.8837308880233765, + 1.8795231484603883, + 1.870696762161255, + 1.864269206237793, + 1.8619007619857788, + 1.8598490927505493, + 1.8526468254852295, + 1.8434244635009767, + 1.8405422570037842, + 1.8366851897811889, + 1.8321461069107057, + 1.8310893895721436, + 1.825379607810974, + 1.8238727295303345, + 1.8163351691055298, + 1.8104810161590577, + 1.8127811385726929, + 1.8086815365982056, + 1.804211870803833, + 1.8026433480453492, + 1.7963054836654664, + 1.7984822222518921, + 1.7911431057739258, + 1.7928030535888673, + 1.7889471031951905, + 1.7860312490081787, + 1.779016375427246, + 1.782450693588257, + 1.7774018466186523, + 1.7763366507339478, + 1.7730076089859008, + 1.7679658282852173, + 1.7679697444915772, + 1.7660620833969116, + 1.7650900997543335, + 1.7642925763320922, + 1.7597538912582398, + 1.7603528467178344, + 1.7629493494415283, + 1.7623634843826295, + 1.756051732444763, + 1.7566206322860718, + 1.7543605234146118, + 1.7578279922485351, + 1.7509302628326415, + 1.7494868576812743, + 1.751580044555664, + 1.7507692102813721, + 1.748738152809143, + 1.7474167670059204, + 1.7463805029678345, + 1.7446185354995727, + 1.742856813430786, + 1.7411311776351928, + 1.7384925448989867, + 1.7379512426376342, + 1.7373969628524781, + 1.7368676550674438, + 1.735899489364624, + 1.7348208691787719, + 1.7361816341400147, + 1.7328850330352783, + 1.7339306282806397, + 1.7350596377563476, + 1.729668268814087, + 1.7330098389053346, + 1.7316270180892945, + 1.7296399521255492, + 1.7288719012451172, + 1.7293606524658203, + 1.7286542098999023, + 1.7271976955795287, + 1.726504995765686, + 1.7264521031951905, + 1.7255205670547484, + 1.7223241259384154, + 1.7238075354385376, + 1.7262169020843505, + 1.7254022949981689, + 1.7218294512176513, + 1.7234743248748778, + 1.7214548685455322, + 1.7229200372695923, + 1.722522060546875, + 1.7214176669311523, + 1.721013473739624, + 1.7213758332061768, + 1.7200873915863037, + 1.7215943475341797, + 1.7193219763183594, + 1.7175406676483154, + 1.7181979109954835, + 1.7207658834075927, + 1.7238012107086182, + 1.7181251257705688 + ], + "train_acc": [ + 0.2481, + 0.28956, + 0.29888, + 0.30732, + 0.31422, + 0.31622, + 0.32282, + 0.3247, + 0.3269, + 0.33028, + 0.33286, + 0.33738, + 0.3392, + 0.34038, + 0.3437, + 0.34256, + 0.34454, + 0.34842, + 0.34672, + 0.35004, + 0.3535, + 0.35304, + 0.35456, + 0.35594, + 0.35742, + 0.35726, + 0.35818, + 0.35816, + 0.35896, + 0.36116, + 0.36062, + 0.36406, + 0.36038, + 0.36366, + 0.36602, + 0.3664, + 0.36858, + 0.36588, + 0.3697, + 0.3681, + 0.3694, + 0.37272, + 0.37256, + 0.3704, + 0.36982, + 0.3713, + 0.37258, + 0.37364, + 0.37452, + 0.37376, + 0.37334, + 0.37602, + 0.37386, + 0.37334, + 0.37724, + 0.37578, + 0.37558, + 0.3779, + 0.37994, + 0.37834, + 0.37772, + 0.38086, + 0.38184, + 0.38024, + 0.38166, + 0.38102, + 0.38296, + 0.38192, + 0.38094, + 0.3819, + 0.38064, + 0.38246, + 0.3825, + 0.38418, + 0.3848, + 0.3841, + 0.38336, + 0.38264, + 0.3835, + 0.3861, + 0.38616, + 0.38576, + 0.38372, + 0.38684, + 0.3885, + 0.3868, + 0.3877, + 0.38486, + 0.38744, + 0.38718, + 0.3855, + 0.38922, + 0.3886, + 0.38746, + 0.388, + 0.39044, + 0.38896, + 0.38804, + 0.3862, + 0.38788 + ], + "test_acc": [ + 0.296, + 0.3104, + 0.3372, + 0.3454, + 0.3476, + 0.3503, + 0.3516, + 0.3529, + 0.3615, + 0.3574, + 0.3669, + 0.3575, + 0.3654, + 0.3694, + 0.3666, + 0.3687, + 0.3662, + 0.3673, + 0.3741, + 0.3617, + 0.368, + 0.3787, + 0.3763, + 0.3839, + 0.3868, + 0.378, + 0.3857, + 0.3626, + 0.3829, + 0.3795, + 0.3808, + 0.3767, + 0.3809, + 0.3887, + 0.3878, + 0.3792, + 0.3945, + 0.384, + 0.3959, + 0.3955, + 0.3925, + 0.3985, + 0.3845, + 0.3945, + 0.3938, + 0.3921, + 0.3958, + 0.4002, + 0.3951, + 0.4002, + 0.4059, + 0.3946, + 0.3994, + 0.398, + 0.3924, + 0.3941, + 0.3965, + 0.3964, + 0.405, + 0.406, + 0.398, + 0.402, + 0.4056, + 0.3978, + 0.4048, + 0.4017, + 0.4051, + 0.4054, + 0.4028, + 0.4046, + 0.4031, + 0.4043, + 0.4074, + 0.4097, + 0.4072, + 0.4063, + 0.4068, + 0.4068, + 0.4075, + 0.4064, + 0.408, + 0.4072, + 0.4083, + 0.4043, + 0.4071, + 0.4075, + 0.407, + 0.4105, + 0.4073, + 0.4076, + 0.4071, + 0.4095, + 0.409, + 0.4099, + 0.4093, + 0.409, + 0.4087, + 0.4092, + 0.4093, + 0.4094 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.026475638151168823, + 0.0839557871222496, + -0.0336647555232048, + -0.057554133236408234, + -0.06538982689380646, + -0.06987135112285614, + -0.051756590604782104, + 0.9981189966201782 + ], + "perturbation_rho": [ + 0.021994909271597862, + -0.047078944742679596, + 0.022503621876239777, + 0.025055162608623505, + -0.025491345673799515, + -0.03935955837368965, + -0.002435870934277773, + 0.047669265419244766 + ], + "nudging": { + "0.001": [ + -2.411194145679474e-06, + -3.704335540533066e-07, + 4.307366907596588e-08, + 9.359791874885559e-08, + 1.1106021702289581e-07, + 7.031485438346863e-08, + 1.210719347000122e-08, + -1.080334186553955e-06 + ], + "0.003": [ + -7.076392648741603e-06, + -1.1655502021312714e-06, + 2.3515895009040833e-07, + 3.2177194952964783e-07, + 3.494787961244583e-07, + 2.796296030282974e-07, + 1.310836523771286e-07, + -4.032859578728676e-06 + ], + "0.01": [ + -2.3676970158703625e-05, + -4.229601472616196e-06, + 3.962777554988861e-07, + 9.336508810520172e-07, + 1.050299033522606e-06, + 1.1418014764785767e-06, + 7.725320756435394e-07, + -1.4404766261577606e-05 + ] + }, + "hidden_norms_per_layer": [ + 7406.1591796875, + 81019.6640625, + 737389.0, + 1286861.0, + 1535224.625, + 1916653.125, + 2282567.25, + 2428105.25, + 1218042.25 + ], + "bp_grad_norms_per_layer": [ + 3.065524288103916e-05, + 1.910160335683031e-06, + 6.406741590581078e-07, + 6.244783321562863e-07, + 6.195934929564828e-07, + 6.173827955535671e-07, + 6.186029395394144e-07, + 6.207166052263347e-07, + 6.07166782629065e-07 + ] + }, + "drift": { + "embed.weight": 47.55569593801599, + "embed.bias": 14.17447035910836, + "blocks.0.ln.weight": 1.0983379558592676, + "blocks.0.w1.weight": 16.250524799209163, + "blocks.0.w1.bias": 12.62261139589463, + "blocks.0.w2.weight": 61.59978351635402, + "blocks.1.ln.weight": 1.1319422288079952, + "blocks.1.w1.weight": 22.880541075681524, + "blocks.1.w1.bias": 17.383779962472858, + "blocks.1.w2.weight": 44.96085798621816, + "blocks.2.ln.weight": 0.8177694150936128, + "blocks.2.w1.weight": 22.134352628330692, + "blocks.2.w1.bias": 21.07281150651575, + "blocks.2.w2.weight": 35.61435775314368, + "blocks.3.ln.weight": 0.7311969207849126, + "blocks.3.w1.weight": 20.785325587610842, + "blocks.3.w1.bias": 19.03573291074914, + "blocks.3.w2.weight": 35.05588574995981, + "blocks.4.ln.weight": 0.6490724785036317, + "blocks.4.w1.weight": 21.354390651446167, + "blocks.4.w1.bias": 21.159821383377235, + "blocks.4.w2.weight": 28.795764202368506, + "blocks.5.ln.weight": 0.6339903970932835, + "blocks.5.w1.weight": 21.59073387917402, + "blocks.5.w1.bias": 21.94406712381077, + "blocks.5.w2.weight": 32.095039396227314, + "blocks.6.ln.weight": 0.6243319829763818, + "blocks.6.w1.weight": 19.49885333502707, + "blocks.6.w1.bias": 20.57317316179931, + "blocks.6.w2.weight": 32.88519516715278, + "blocks.7.ln.weight": 0.7128901839398082, + "blocks.7.w1.weight": 20.944827961922154, + "blocks.7.w1.bias": 21.447201189692905, + "blocks.7.w2.weight": 42.87255169680327, + "out_ln.weight": 0.37541884485095733, + "out_head.weight": 6.641485211520903, + "out_head.bias": 1.0438900292204085 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 512, + "num_blocks": 8, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_depth_scan_d512_L8", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_depth_sweep_redo.log b/results/fa_depth_sweep_redo.log new file mode 100644 index 0000000..cb304b6 --- /dev/null +++ b/results/fa_depth_sweep_redo.log @@ -0,0 +1,90 @@ +=== FA d=512 depth sweep REDO (separate dirs) === + L=2 (Wed Apr 22 11:16:28 PM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0612, train=0.2476, test=0.3028 + [FA] Epoch 10: loss=1.8290, train=0.3435, test=0.3705 + [FA] Epoch 20: loss=1.8102, train=0.3489, test=0.3634 + [FA] Epoch 30: loss=1.7963, train=0.3546, test=0.3398 + [FA] Epoch 40: loss=1.7775, train=0.3605, test=0.3497 + [FA] Epoch 50: loss=1.7610, train=0.3685, test=0.3288 + [FA] Epoch 60: loss=1.7592, train=0.3704, test=0.3376 + [FA] Epoch 70: loss=1.7588, train=0.3747, test=0.3421 + [FA] Epoch 80: loss=1.7564, train=0.3751, test=0.3497 + [FA] Epoch 90: loss=1.7543, train=0.3769, test=0.3472 + [FA] Epoch 100: loss=1.7527, train=0.3768, test=0.3495 + Final test acc: 0.3495 + +All results saved to results/fa_depth_scan_d512_L2/results_cifar10.json + L=4 (Wed Apr 22 11:22:52 PM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0301, train=0.2531, test=0.2917 + [FA] Epoch 10: loss=1.8487, train=0.3366, test=0.3541 + [FA] Epoch 20: loss=1.7864, train=0.3609, test=0.3908 + [FA] Epoch 30: loss=1.7510, train=0.3724, test=0.3990 + [FA] Epoch 40: loss=1.7387, train=0.3767, test=0.3946 + [FA] Epoch 50: loss=1.7209, train=0.3875, test=0.4165 + [FA] Epoch 60: loss=1.7052, train=0.3913, test=0.4173 + [FA] Epoch 70: loss=1.6945, train=0.3963, test=0.4137 + [FA] Epoch 80: loss=1.6868, train=0.4018, test=0.4219 + [FA] Epoch 90: loss=1.6830, train=0.4009, test=0.4250 + [FA] Epoch 100: loss=1.6781, train=0.4021, test=0.4244 + Final test acc: 0.4244 + +All results saved to results/fa_depth_scan_d512_L4/results_cifar10.json + L=6 (Wed Apr 22 11:29:10 PM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0375, train=0.2474, test=0.2938 + [FA] Epoch 10: loss=1.8616, train=0.3294, test=0.3541 + [FA] Epoch 20: loss=1.8289, train=0.3459, test=0.3711 + [FA] Epoch 30: loss=1.7992, train=0.3579, test=0.3857 + [FA] Epoch 40: loss=1.7837, train=0.3631, test=0.3942 + [FA] Epoch 50: loss=1.7699, train=0.3710, test=0.3921 + [FA] Epoch 60: loss=1.7550, train=0.3741, test=0.3975 + [FA] Epoch 70: loss=1.7439, train=0.3770, test=0.4058 + [FA] Epoch 80: loss=1.7413, train=0.3796, test=0.4014 + [FA] Epoch 90: loss=1.7382, train=0.3791, test=0.4008 + [FA] Epoch 100: loss=1.7363, train=0.3785, test=0.4014 + Final test acc: 0.4014 + +All results saved to results/fa_depth_scan_d512_L6/results_cifar10.json + L=8 (Wed Apr 22 11:36:14 PM CDT 2026) +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0431, train=0.2481, test=0.2960 + [FA] Epoch 10: loss=1.8619, train=0.3303, test=0.3574 + [FA] Epoch 20: loss=1.8163, train=0.3500, test=0.3617 + [FA] Epoch 30: loss=1.7889, train=0.3612, test=0.3795 + [FA] Epoch 40: loss=1.7651, train=0.3681, test=0.3955 + [FA] Epoch 50: loss=1.7509, train=0.3738, test=0.4002 + [FA] Epoch 60: loss=1.7385, train=0.3783, test=0.4060 + [FA] Epoch 70: loss=1.7297, train=0.3819, test=0.4046 + [FA] Epoch 80: loss=1.7255, train=0.3861, test=0.4064 + [FA] Epoch 90: loss=1.7214, train=0.3872, test=0.4076 + [FA] Epoch 100: loss=1.7181, train=0.3879, test=0.4094 + Final test acc: 0.4094 + +All results saved to results/fa_depth_scan_d512_L8/results_cifar10.json +=== DEPTH SWEEP REDO DONE (Wed Apr 22 11:44:47 PM CDT 2026) === diff --git a/results/fa_early_ckpts/results_cifar10.json b/results/fa_early_ckpts/results_cifar10.json new file mode 100644 index 0000000..aa046e3 --- /dev/null +++ b/results/fa_early_ckpts/results_cifar10.json @@ -0,0 +1,324 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.049124941177368, + 1.9718926105117798, + 1.9489588039398194, + 1.934002057762146, + 1.9262304685211182 + ], + "train_acc": [ + 0.2436, + 0.2745, + 0.29034, + 0.29818, + 0.301 + ], + "test_acc": [ + 0.2789, + 0.3111, + 0.3091, + 0.3094, + 0.3219 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.0644938200712204, + 0.00024983659386634827, + -0.006781474687159061, + 0.9952471852302551 + ], + "perturbation_rho": [ + 0.06215526908636093, + -0.005929573439061642, + 0.029685823246836662, + 0.183500275015831 + ], + "nudging": { + "0.001": [ + -2.8510112315416336e-06, + -6.868503987789154e-08, + -1.4668330550193787e-08, + -6.758375093340874e-06 + ], + "0.003": [ + -8.508563041687012e-06, + -1.6507692635059357e-07, + -4.912726581096649e-08, + -2.0218081772327423e-05 + ], + "0.01": [ + -2.8386013582348824e-05, + -6.488990038633347e-07, + -1.6367994248867035e-07, + -6.742705591022968e-05 + ] + }, + "hidden_norms_per_layer": [ + 960.3140869140625, + 11076.888671875, + 31092.83984375, + 34876.390625, + 29288.8984375 + ], + "bp_grad_norms_per_layer": [ + 2.04460939130513e-05, + 3.4170184335380327e-06, + 3.3133326269307872e-06, + 3.321988742754911e-06, + 3.197354772055405e-06 + ] + }, + "drift": { + "embed.weight": 9.203150246245913, + "embed.bias": 7.782631309228306, + "blocks.0.ln.weight": 0.5169422030448914, + "blocks.0.w1.weight": 7.173286797050866, + "blocks.0.w1.bias": 7.530177507927273, + "blocks.0.w2.weight": 20.753186824233353, + "blocks.1.ln.weight": 0.4266158640384674, + "blocks.1.w1.weight": 6.365489317176397, + "blocks.1.w1.bias": 7.642448312201989, + "blocks.1.w2.weight": 16.266082583917918, + "blocks.2.ln.weight": 0.366255521774292, + "blocks.2.w1.weight": 5.12781119461704, + "blocks.2.w1.bias": 5.784653325984481, + "blocks.2.w2.weight": 14.283055474835482, + "blocks.3.ln.weight": 0.3213387429714203, + "blocks.3.w1.weight": 4.3690188550450015, + "blocks.3.w1.bias": 4.601659627893413, + "blocks.3.w2.weight": 12.971053381902397, + "out_ln.weight": 0.05851004645228386, + "out_head.weight": 1.1675271811094783, + "out_head.bias": 0.45402486694117133 + } + } + }, + "123": { + "fa": { + "log": { + "train_loss": [ + 2.039442294845581, + 1.953437792930603, + 1.9257947838592528, + 1.9107846630096434, + 1.9030635565567016 + ], + "train_acc": [ + 0.25042, + 0.29012, + 0.30112, + 0.30812, + 0.31226 + ], + "test_acc": [ + 0.2905, + 0.3307, + 0.3419, + 0.3476, + 0.3478 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.06648196280002594, + -0.010407494381070137, + -0.06906092911958694, + 0.9917651414871216 + ], + "perturbation_rho": [ + 0.031062623485922813, + -0.02341378480195999, + -0.033602140843868256, + 0.2669849991798401 + ], + "nudging": { + "0.001": [ + -3.1692907214164734e-06, + -1.126900315284729e-07, + 7.380731403827667e-07, + -1.1092517524957657e-05 + ], + "0.003": [ + -9.447336196899414e-06, + -3.655441105365753e-07, + 2.216547727584839e-06, + -3.331666812300682e-05 + ], + "0.01": [ + -3.144703805446625e-05, + -1.210719347000122e-06, + 7.366761565208435e-06, + -0.00011098524555563927 + ] + }, + "hidden_norms_per_layer": [ + 863.883056640625, + 6307.2734375, + 12256.30859375, + 20627.75390625, + 15648.552734375 + ], + "bp_grad_norms_per_layer": [ + 3.119367829640396e-05, + 5.960608177701943e-06, + 5.3465173550648615e-06, + 5.451070592243923e-06, + 5.225469521974446e-06 + ] + }, + "drift": { + "embed.weight": 8.803298126163122, + "embed.bias": 6.680663743131999, + "blocks.0.ln.weight": 0.4969744086265564, + "blocks.0.w1.weight": 6.118721027737535, + "blocks.0.w1.bias": 5.453192795258035, + "blocks.0.w2.weight": 17.832094218833557, + "blocks.1.ln.weight": 0.4233635365962982, + "blocks.1.w1.weight": 5.605093419712821, + "blocks.1.w1.bias": 6.1996663705372885, + "blocks.1.w2.weight": 16.130887571112158, + "blocks.2.ln.weight": 0.38245585560798645, + "blocks.2.w1.weight": 5.4055954853998145, + "blocks.2.w1.bias": 6.487736894934294, + "blocks.2.w2.weight": 15.370923217597205, + "blocks.3.ln.weight": 0.39037927985191345, + "blocks.3.w1.weight": 5.164380800730237, + "blocks.3.w1.bias": 5.6053151435754565, + "blocks.3.w2.weight": 13.468106289535303, + "out_ln.weight": 0.04700668901205063, + "out_head.weight": 1.1282362053366835, + "out_head.bias": 1.0759554656539119 + } + } + }, + "456": { + "fa": { + "log": { + "train_loss": [ + 2.064060781517029, + 1.9901417289733887, + 1.9581991654205322, + 1.9443307501983642, + 1.935314490966797 + ], + "train_acc": [ + 0.24014, + 0.268, + 0.28416, + 0.29098, + 0.29644 + ], + "test_acc": [ + 0.2713, + 0.3048, + 0.313, + 0.3135, + 0.3242 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.06508420407772064, + -0.013459235429763794, + -0.00898228120058775, + 0.9861425161361694 + ], + "perturbation_rho": [ + 0.0332966186106205, + -0.03956230729818344, + 0.008942835032939911, + 0.18084916472434998 + ], + "nudging": { + "0.001": [ + -2.798624336719513e-06, + 1.7695128917694092e-08, + -1.3969838619232178e-08, + -6.395392119884491e-06 + ], + "0.003": [ + -8.38935375213623e-06, + 7.916241884231567e-09, + -2.60770320892334e-08, + -1.91426370292902e-05 + ], + "0.01": [ + -2.7919188141822815e-05, + 9.639188647270203e-08, + -4.912726581096649e-08, + -6.371899507939816e-05 + ] + }, + "hidden_norms_per_layer": [ + 1010.2516479492188, + 13113.5537109375, + 22355.0546875, + 36300.3671875, + 30980.419921875 + ], + "bp_grad_norms_per_layer": [ + 1.7613443560549058e-05, + 3.159134394081775e-06, + 3.121002691841568e-06, + 3.1120944186113775e-06, + 2.9537256978073856e-06 + ] + }, + "drift": { + "embed.weight": 9.613138255724964, + "embed.bias": 6.926131704202163, + "blocks.0.ln.weight": 0.548354983329773, + "blocks.0.w1.weight": 7.294897696552802, + "blocks.0.w1.bias": 5.312217412593903, + "blocks.0.w2.weight": 20.619584988017134, + "blocks.1.ln.weight": 0.3894560933113098, + "blocks.1.w1.weight": 5.488544569987622, + "blocks.1.w1.bias": 4.831823702146538, + "blocks.1.w2.weight": 15.40878297211816, + "blocks.2.ln.weight": 0.37565672397613525, + "blocks.2.w1.weight": 5.8266288518010825, + "blocks.2.w1.bias": 6.58294580181279, + "blocks.2.w2.weight": 15.720187648044627, + "blocks.3.ln.weight": 0.33931589126586914, + "blocks.3.w1.weight": 4.6498033455621615, + "blocks.3.w1.bias": 3.4624320285535566, + "blocks.3.w2.weight": 14.935221420670851, + "out_ln.weight": 0.04563162103295326, + "out_head.weight": 1.066836036120676, + "out_head.bias": 0.553680529428652 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 5, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42, + 123, + 456 + ], + "gpu": 0, + "output_dir": "results/fa_early_ckpts", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_extension_experiments.log b/results/fa_extension_experiments.log new file mode 100644 index 0000000..ffc2805 --- /dev/null +++ b/results/fa_extension_experiments.log @@ -0,0 +1,159 @@ +========================================== +FA EXTENSION EXPERIMENTS (I-M) +========================================== +Start: Wed Apr 22 10:26:41 PM CDT 2026 + +=== I: FA+pen lam=1e-4 (30ep, 3 seeds) === +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0346, train=0.2557, test=0.2909 + [FA] Epoch 10: loss=1.8700, train=0.3346, test=0.3635 + [FA] Epoch 20: loss=1.8495, train=0.3436, test=0.3682 + [FA] Epoch 30: loss=1.8430, train=0.3521, test=0.3759 + Final test acc: 0.3759 + +============================================================ +Seed 123 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0259, train=0.2600, test=0.3099 + [FA] Epoch 10: loss=1.8666, train=0.3358, test=0.3532 + [FA] Epoch 20: loss=1.8505, train=0.3472, test=0.3685 + [FA] Epoch 30: loss=1.8391, train=0.3530, test=0.3725 + Final test acc: 0.3725 + +============================================================ +Seed 456 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0371, train=0.2562, test=0.2999 + [FA] Epoch 10: loss=1.8573, train=0.3373, test=0.3567 + [FA] Epoch 20: loss=1.8335, train=0.3500, test=0.3831 + [FA] Epoch 30: loss=1.8207, train=0.3574, test=0.3837 + Final test acc: 0.3837 + +All results saved to results/fa_penalty_lam1e-4_30ep/results_cifar10.json + +=== L: FA d=512 depth sweep (100ep, s42) === + L=2 +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0612, train=0.2476, test=0.3028 + [FA] Epoch 10: loss=1.8290, train=0.3435, test=0.3705 + [FA] Epoch 20: loss=1.8102, train=0.3489, test=0.3634 + [FA] Epoch 30: loss=1.7963, train=0.3546, test=0.3398 + [FA] Epoch 40: loss=1.7775, train=0.3605, test=0.3497 + [FA] Epoch 50: loss=1.7610, train=0.3685, test=0.3288 + [FA] Epoch 60: loss=1.7592, train=0.3704, test=0.3376 + [FA] Epoch 70: loss=1.7588, train=0.3747, test=0.3421 + [FA] Epoch 80: loss=1.7564, train=0.3751, test=0.3497 + [FA] Epoch 90: loss=1.7543, train=0.3769, test=0.3472 + [FA] Epoch 100: loss=1.7527, train=0.3768, test=0.3495 + Final test acc: 0.3495 + +All results saved to results/fa_depth_scan_d512/results_cifar10.json + L=4 +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0301, train=0.2531, test=0.2917 + [FA] Epoch 10: loss=1.8487, train=0.3366, test=0.3541 + [FA] Epoch 20: loss=1.7864, train=0.3609, test=0.3908 + [FA] Epoch 30: loss=1.7510, train=0.3724, test=0.3990 + [FA] Epoch 40: loss=1.7387, train=0.3767, test=0.3946 + [FA] Epoch 50: loss=1.7209, train=0.3875, test=0.4165 + [FA] Epoch 60: loss=1.7052, train=0.3913, test=0.4173 + [FA] Epoch 70: loss=1.6945, train=0.3963, test=0.4137 + [FA] Epoch 80: loss=1.6868, train=0.4018, test=0.4219 + [FA] Epoch 90: loss=1.6830, train=0.4009, test=0.4250 + [FA] Epoch 100: loss=1.6781, train=0.4021, test=0.4244 + Final test acc: 0.4244 + +All results saved to results/fa_depth_scan_d512/results_cifar10.json + L=6 +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0375, train=0.2474, test=0.2938 + [FA] Epoch 10: loss=1.8616, train=0.3294, test=0.3541 + [FA] Epoch 20: loss=1.8289, train=0.3459, test=0.3711 + [FA] Epoch 30: loss=1.7992, train=0.3579, test=0.3857 + [FA] Epoch 40: loss=1.7837, train=0.3631, test=0.3942 + [FA] Epoch 50: loss=1.7699, train=0.3710, test=0.3921 + [FA] Epoch 60: loss=1.7550, train=0.3741, test=0.3975 + [FA] Epoch 70: loss=1.7439, train=0.3770, test=0.4058 + [FA] Epoch 80: loss=1.7413, train=0.3796, test=0.4014 + [FA] Epoch 90: loss=1.7382, train=0.3791, test=0.4008 + [FA] Epoch 100: loss=1.7363, train=0.3785, test=0.4014 + Final test acc: 0.4014 + +All results saved to results/fa_depth_scan_d512/results_cifar10.json + L=8 +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0431, train=0.2481, test=0.2960 + [FA] Epoch 10: loss=1.8619, train=0.3303, test=0.3574 + [FA] Epoch 20: loss=1.8163, train=0.3500, test=0.3617 + [FA] Epoch 30: loss=1.7889, train=0.3612, test=0.3795 + [FA] Epoch 40: loss=1.7651, train=0.3681, test=0.3955 + [FA] Epoch 50: loss=1.7509, train=0.3738, test=0.4002 + [FA] Epoch 60: loss=1.7385, train=0.3783, test=0.4060 + [FA] Epoch 70: loss=1.7297, train=0.3819, test=0.4046 + [FA] Epoch 80: loss=1.7255, train=0.3861, test=0.4064 + [FA] Epoch 90: loss=1.7214, train=0.3872, test=0.4076 + [FA] Epoch 100: loss=1.7181, train=0.3879, test=0.4094 + Final test acc: 0.4094 + +All results saved to results/fa_depth_scan_d512/results_cifar10.json + L=12 +Using device: cuda:0 + +============================================================ +Seed 42 +============================================================ + +--- FA --- + [FA] Epoch 1: loss=2.0427, train=0.2406, test=0.2963 + [FA] Epoch 10: loss=1.8510, train=0.3333, test=0.3712 + [FA] Epoch 20: loss=1.8069, train=0.3520, test=0.3747 + [FA] Epoch 30: loss=1.7837, train=0.3593, test=0.3827 + [FA] Epoch 40: loss=1.7677, train=0.3684, test=0.4088 + [FA] Epoch 50: loss=1.7521, train=0.3730, test=0.3905 + [FA] Epoch 60: loss=1.7441, train=0.3767, test=0.4042 + [FA] Epoch 70: loss=1.7356, train=0.3807, test=0.4046 + [FA] Epoch 80: loss=1.7307, train=0.3824, test=0.4037 + [FA] Epoch 90: loss=1.7268, train=0.3839, test=0.4014 + [FA] Epoch 100: loss=1.7295, train=0.3834, test=0.4035 + Final test acc: 0.4035 + +All results saved to results/fa_depth_scan_d512/results_cifar10.json + +========================================== +FA EXTENSION EXPERIMENTS DONE +End: Wed Apr 22 11:13:24 PM CDT 2026 +========================================== diff --git a/results/fa_main_audit/results_cifar10.json b/results/fa_main_audit/results_cifar10.json new file mode 100644 index 0000000..902df50 --- /dev/null +++ b/results/fa_main_audit/results_cifar10.json @@ -0,0 +1,1179 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.049124941177368, + 1.971964567337036, + 1.9483523839569092, + 1.9344041421508789, + 1.9224640267944335, + 1.9110223713684082, + 1.9013612536621094, + 1.8914563995742797, + 1.8888161966705321, + 1.8849480002212524, + 1.8800471724700927, + 1.8740126901245118, + 1.8716543948745727, + 1.8681818910217285, + 1.867931474647522, + 1.866554428138733, + 1.8656607852935791, + 1.85699001285553, + 1.855739754257202, + 1.8523859079360963, + 1.8532698722076415, + 1.8493807650375367, + 1.84840350189209, + 1.8443957958602906, + 1.842935036239624, + 1.8395105728912353, + 1.8337825148773192, + 1.8333892483520509, + 1.8321472442245483, + 1.8294915601348878, + 1.8279203462982179, + 1.826354044494629, + 1.8228275035858155, + 1.8226369649887084, + 1.8176784851837158, + 1.8198615470123292, + 1.8155909259796144, + 1.8137150632476806, + 1.8144374758911133, + 1.816846872291565, + 1.8122928824234008, + 1.8105264510345458, + 1.8050418474578858, + 1.810307405052185, + 1.8034298502349853, + 1.8004834436798096, + 1.8044248488616943, + 1.8068354278182983, + 1.8045337250518798, + 1.8001124160766602, + 1.7975274431991577, + 1.7963342412567138, + 1.7940989971923829, + 1.7987028061294557, + 1.7954791347503662, + 1.7969434225845338, + 1.794546351585388, + 1.790341936378479, + 1.7904112981414795, + 1.796010117111206, + 1.7922658501434325, + 1.7879667960357666, + 1.787852370376587, + 1.7872013766479493, + 1.7859069106292724, + 1.7856995489120484, + 1.7886004698944091, + 1.7876731681060791, + 1.7865260543823243, + 1.785853352394104, + 1.7849980658340454, + 1.7868319046020509, + 1.7825030782318114, + 1.7818804879379273, + 1.7835813956069946, + 1.7858538039779663, + 1.7806534545516968, + 1.7811814181900025, + 1.7809104974365235, + 1.7819213286972047, + 1.7821454078674317, + 1.782828011817932, + 1.779138667678833, + 1.778261279373169, + 1.7800345748901367, + 1.776551816444397, + 1.7768282559967041, + 1.7764475776672364, + 1.7771262689590455, + 1.7769524020767211, + 1.7762922687149048, + 1.7768514712524415, + 1.779178479270935, + 1.7777679949951173, + 1.7764834703826904, + 1.77460149269104, + 1.7745818335342407, + 1.7761323275375367, + 1.7797669164276122, + 1.7776144277954102 + ], + "train_acc": [ + 0.2436, + 0.27338, + 0.2898, + 0.29694, + 0.30116, + 0.3039, + 0.30988, + 0.31412, + 0.31788, + 0.31722, + 0.31988, + 0.32372, + 0.323, + 0.32762, + 0.32672, + 0.3276, + 0.32934, + 0.33266, + 0.3352, + 0.33332, + 0.33394, + 0.33564, + 0.33554, + 0.33832, + 0.33744, + 0.33804, + 0.34066, + 0.34208, + 0.34102, + 0.34436, + 0.34358, + 0.34454, + 0.34578, + 0.34732, + 0.34992, + 0.34824, + 0.3498, + 0.35008, + 0.35088, + 0.35076, + 0.35254, + 0.35354, + 0.35364, + 0.35176, + 0.35556, + 0.35626, + 0.35458, + 0.35408, + 0.35588, + 0.35876, + 0.35938, + 0.36056, + 0.36056, + 0.35998, + 0.36086, + 0.35718, + 0.35898, + 0.36128, + 0.36022, + 0.3575, + 0.36098, + 0.36178, + 0.36196, + 0.36508, + 0.36294, + 0.3645, + 0.36438, + 0.36428, + 0.36462, + 0.36086, + 0.3643, + 0.3637, + 0.36544, + 0.36562, + 0.36574, + 0.36544, + 0.3643, + 0.36654, + 0.36688, + 0.36584, + 0.36738, + 0.36538, + 0.3685, + 0.36706, + 0.3681, + 0.36614, + 0.36798, + 0.36682, + 0.36688, + 0.36772, + 0.3687, + 0.36828, + 0.36478, + 0.3676, + 0.36774, + 0.36956, + 0.36956, + 0.36768, + 0.3664, + 0.36848 + ], + "test_acc": [ + 0.2789, + 0.3105, + 0.3048, + 0.3148, + 0.323, + 0.3367, + 0.3399, + 0.3468, + 0.3417, + 0.3416, + 0.3511, + 0.3527, + 0.342, + 0.353, + 0.3624, + 0.349, + 0.354, + 0.3675, + 0.3543, + 0.3609, + 0.3613, + 0.3656, + 0.3683, + 0.3572, + 0.365, + 0.3682, + 0.3725, + 0.3578, + 0.3725, + 0.3714, + 0.3683, + 0.3666, + 0.3621, + 0.3689, + 0.3747, + 0.3784, + 0.368, + 0.383, + 0.3698, + 0.3823, + 0.379, + 0.3766, + 0.3793, + 0.3789, + 0.3879, + 0.3813, + 0.373, + 0.3839, + 0.3814, + 0.3842, + 0.3882, + 0.381, + 0.3834, + 0.3841, + 0.3851, + 0.387, + 0.3809, + 0.3833, + 0.3834, + 0.3806, + 0.3854, + 0.3907, + 0.3796, + 0.3892, + 0.3915, + 0.3907, + 0.391, + 0.3903, + 0.3921, + 0.3892, + 0.3916, + 0.3909, + 0.3912, + 0.3923, + 0.3912, + 0.3912, + 0.3937, + 0.3894, + 0.3946, + 0.3919, + 0.3897, + 0.3923, + 0.3973, + 0.392, + 0.3971, + 0.3938, + 0.3908, + 0.3947, + 0.3922, + 0.3913, + 0.3909, + 0.3931, + 0.3937, + 0.3931, + 0.3933, + 0.392, + 0.392, + 0.3926, + 0.3928, + 0.3929 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.028316663578152657, + 0.0230555459856987, + -0.10937326401472092, + 0.9985983967781067 + ], + "perturbation_rho": [ + -0.018305372446775436, + -0.05561023950576782, + 0.04369408264756203, + 0.0030625772196799517 + ], + "nudging": { + "0.001": [ + -1.237727701663971e-06, + -8.684583008289337e-08, + 2.3283064365386963e-09, + -1.1352822184562683e-06 + ], + "0.003": [ + -3.5907141864299774e-06, + -3.394670784473419e-07, + 4.407484084367752e-07, + -4.475703462958336e-06 + ], + "0.01": [ + -1.1818483471870422e-05, + -7.816124707460403e-07, + 1.9797589629888535e-06, + -1.6960781067609787e-05 + ] + }, + "hidden_norms_per_layer": [ + 5829.11083984375, + 96512.5703125, + 1033049.3125, + 1362033.875, + 493509.875 + ], + "bp_grad_norms_per_layer": [ + 1.7891748939291574e-05, + 1.2141871366111445e-06, + 8.851626489558839e-07, + 8.892926075532159e-07, + 8.88764532191999e-07 + ] + }, + "drift": { + "embed.weight": 61.065183980831414, + "embed.bias": 37.30479105851126, + "blocks.0.ln.weight": 1.813413381576538, + "blocks.0.w1.weight": 19.65702484459222, + "blocks.0.w1.bias": 18.43264760143014, + "blocks.0.w2.weight": 80.7069052062468, + "blocks.1.ln.weight": 1.5857534408569336, + "blocks.1.w1.weight": 30.581704417045803, + "blocks.1.w1.bias": 32.64715538359668, + "blocks.1.w2.weight": 51.04559112831385, + "blocks.2.ln.weight": 1.1075904369354248, + "blocks.2.w1.weight": 27.221029191801634, + "blocks.2.w1.bias": 31.23130799360825, + "blocks.2.w2.weight": 40.72448084896488, + "blocks.3.ln.weight": 1.233646035194397, + "blocks.3.w1.weight": 30.670509581621424, + "blocks.3.w1.bias": 34.319449767073166, + "blocks.3.w2.weight": 50.44485584477348, + "out_ln.weight": 0.31939876079559326, + "out_head.weight": 5.424347923413205, + "out_head.bias": 2.5070306021804907 + } + } + }, + "123": { + "fa": { + "log": { + "train_loss": [ + 2.039442294845581, + 1.953594453125, + 1.9264128439712525, + 1.9111677870941162, + 1.9046263507843018, + 1.8970699446868897, + 1.8877645287322997, + 1.8802119188690185, + 1.8676571554565429, + 1.8596859856796264, + 1.8576147148895263, + 1.8543415798950196, + 1.847902306289673, + 1.843341840286255, + 1.8415791506576538, + 1.8326364449691772, + 1.8275398804473877, + 1.825090225830078, + 1.8180097577667236, + 1.812808359451294, + 1.8083131935882568, + 1.806013644104004, + 1.8015524674224854, + 1.7926762564086913, + 1.7879434857177734, + 1.7820632962799072, + 1.781419140357971, + 1.7759368832397462, + 1.7757835355377198, + 1.7717899426651, + 1.7707528903579712, + 1.7696525680923463, + 1.7655191495895386, + 1.765033065109253, + 1.762537513771057, + 1.7627749864959716, + 1.7591946435546875, + 1.7589446157455444, + 1.7533805700683593, + 1.757075972213745, + 1.7514215006256104, + 1.7560174445343018, + 1.7500310370254517, + 1.7495898828125, + 1.7477490761566161, + 1.7467513186264039, + 1.7485883990859985, + 1.7423759655380249, + 1.7408935341644287, + 1.7417366515731811, + 1.7416620874023439, + 1.7395717266845703, + 1.735324585494995, + 1.7391719909667969, + 1.7375017526245118, + 1.738703854637146, + 1.739165567626953, + 1.7372404578018188, + 1.7375502878189086, + 1.740667166519165, + 1.7426861660766602, + 1.7352294305801392, + 1.7360000601577759, + 1.7363917623901368, + 1.7359804382705688, + 1.7362036960601808, + 1.7375494842529298, + 1.7357752603149414, + 1.7346405282211304, + 1.734609012107849, + 1.7384431924057007, + 1.7361973517227174, + 1.7336053954696655, + 1.7382006414794922, + 1.7355592263031006, + 1.7354922631072998, + 1.738483149909973, + 1.733718469810486, + 1.7319189379882813, + 1.7344261421203613, + 1.731631244506836, + 1.7336486722564697, + 1.7326741513061523, + 1.7326253693389893, + 1.7323636251449586, + 1.7339774234771728, + 1.73317182472229, + 1.7338639688491821, + 1.7331051165008544, + 1.7340733782196045, + 1.7288542190170288, + 1.732238768005371, + 1.7312932516860962, + 1.7317837997436523, + 1.7301695737075806, + 1.7338779425430297, + 1.7309321548461913, + 1.726007070274353, + 1.7297403479385376, + 1.7313430319213867 + ], + "train_acc": [ + 0.25042, + 0.28976, + 0.30134, + 0.30602, + 0.3089, + 0.31554, + 0.31428, + 0.3196, + 0.32306, + 0.32846, + 0.33142, + 0.33322, + 0.33642, + 0.33992, + 0.34194, + 0.34438, + 0.34432, + 0.34526, + 0.34954, + 0.35184, + 0.35138, + 0.35452, + 0.35486, + 0.35792, + 0.35968, + 0.36406, + 0.36522, + 0.36342, + 0.3657, + 0.36458, + 0.36718, + 0.36438, + 0.3672, + 0.37022, + 0.37054, + 0.36724, + 0.36996, + 0.37222, + 0.3746, + 0.37048, + 0.37312, + 0.37208, + 0.37302, + 0.37558, + 0.37478, + 0.37606, + 0.37594, + 0.37622, + 0.37692, + 0.37654, + 0.37762, + 0.37828, + 0.38134, + 0.37722, + 0.37784, + 0.37886, + 0.37892, + 0.37816, + 0.38044, + 0.37908, + 0.37866, + 0.37992, + 0.38136, + 0.38112, + 0.38068, + 0.37866, + 0.38062, + 0.38226, + 0.38224, + 0.38394, + 0.37886, + 0.38148, + 0.38066, + 0.37898, + 0.3827, + 0.38146, + 0.38008, + 0.38496, + 0.38278, + 0.3828, + 0.38456, + 0.38362, + 0.38398, + 0.38284, + 0.384, + 0.38284, + 0.38282, + 0.38388, + 0.38298, + 0.38384, + 0.3847, + 0.38326, + 0.38218, + 0.38316, + 0.38288, + 0.38104, + 0.3823, + 0.38442, + 0.38372, + 0.3814 + ], + "test_acc": [ + 0.2905, + 0.3333, + 0.3424, + 0.3445, + 0.3473, + 0.357, + 0.3561, + 0.3439, + 0.3547, + 0.3514, + 0.3624, + 0.3717, + 0.3659, + 0.3715, + 0.3653, + 0.3669, + 0.3677, + 0.375, + 0.3807, + 0.3801, + 0.3731, + 0.3794, + 0.3873, + 0.3816, + 0.3903, + 0.3866, + 0.3876, + 0.387, + 0.3915, + 0.3944, + 0.3878, + 0.391, + 0.3973, + 0.3913, + 0.3949, + 0.4015, + 0.3969, + 0.3971, + 0.3933, + 0.3937, + 0.4026, + 0.3957, + 0.3986, + 0.3994, + 0.4041, + 0.4033, + 0.3987, + 0.4023, + 0.3982, + 0.4046, + 0.4073, + 0.4061, + 0.3998, + 0.4057, + 0.4036, + 0.4066, + 0.4057, + 0.4056, + 0.4099, + 0.4087, + 0.4031, + 0.4052, + 0.4102, + 0.4108, + 0.4056, + 0.4083, + 0.4111, + 0.4113, + 0.4088, + 0.4048, + 0.4085, + 0.4104, + 0.4085, + 0.4067, + 0.4074, + 0.406, + 0.4059, + 0.4112, + 0.4093, + 0.407, + 0.4077, + 0.4078, + 0.4096, + 0.4083, + 0.4103, + 0.4093, + 0.4092, + 0.4094, + 0.4095, + 0.4093, + 0.41, + 0.409, + 0.4071, + 0.4087, + 0.4098, + 0.4095, + 0.4103, + 0.4095, + 0.4099, + 0.4099 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.024554381147027016, + 0.11956729739904404, + -0.05824935808777809, + 0.9959120750427246 + ], + "perturbation_rho": [ + 0.04051800072193146, + 0.000565994530916214, + -0.01674405112862587, + 0.03625570237636566 + ], + "nudging": { + "0.001": [ + -1.0242220014333725e-06, + -8.98377038538456e-07, + 1.8649734556674957e-07, + -3.294786438345909e-06 + ], + "0.003": [ + -2.9887305572628975e-06, + -2.6676570996642113e-06, + 6.059417501091957e-07, + -1.0517891496419907e-05 + ], + "0.01": [ + -1.013989094644785e-05, + -9.021139703691006e-06, + 1.9457656890153885e-06, + -3.530189860612154e-05 + ] + }, + "hidden_norms_per_layer": [ + 4442.23876953125, + 69271.0234375, + 112683.53125, + 377440.78125, + 101219.984375 + ], + "bp_grad_norms_per_layer": [ + 3.0517328923451714e-05, + 3.3021829040080775e-06, + 1.6336410908479593e-06, + 1.6025945797082386e-06, + 1.6025023796828464e-06 + ] + }, + "drift": { + "embed.weight": 42.23262803830075, + "embed.bias": 38.89359718847897, + "blocks.0.ln.weight": 1.485857367515564, + "blocks.0.w1.weight": 16.185781545723177, + "blocks.0.w1.bias": 19.38631244368407, + "blocks.0.w2.weight": 67.24000322735762, + "blocks.1.ln.weight": 1.1729528903961182, + "blocks.1.w1.weight": 16.924164447681708, + "blocks.1.w1.bias": 12.398522096155984, + "blocks.1.w2.weight": 66.57045868925029, + "blocks.2.ln.weight": 1.1481270790100098, + "blocks.2.w1.weight": 22.87352078284913, + "blocks.2.w1.bias": 15.551994064272593, + "blocks.2.w2.weight": 47.0174765021604, + "blocks.3.ln.weight": 1.3614349365234375, + "blocks.3.w1.weight": 24.802445502073887, + "blocks.3.w1.bias": 20.203860687494405, + "blocks.3.w2.weight": 27.95273009401102, + "out_ln.weight": 0.2712816894054413, + "out_head.weight": 3.4626060309461333, + "out_head.bias": 5.313953466040826 + } + } + }, + "456": { + "fa": { + "log": { + "train_loss": [ + 2.064060781517029, + 1.9899321492004394, + 1.9583305028533935, + 1.9413242259979249, + 1.9259093924713135, + 1.9170702835845947, + 1.9052985552215576, + 1.896912885131836, + 1.8825988665771485, + 1.8796403802490234, + 1.8680178726577759, + 1.866754577331543, + 1.8631378329849244, + 1.8581646213150025, + 1.8511347687530517, + 1.8506925588226317, + 1.8502174087905883, + 1.844791806640625, + 1.8447570569229126, + 1.8405440747833253, + 1.8362063940048219, + 1.8338028798675536, + 1.8322235885238647, + 1.8271099829483033, + 1.8259497284698487, + 1.8257087524032594, + 1.8204508585357666, + 1.8167411252593995, + 1.8126173525619507, + 1.8123734192276002, + 1.8087152871322631, + 1.808808218383789, + 1.8022823331069946, + 1.801023843383789, + 1.7965255283355712, + 1.7960807333755493, + 1.7931187894439697, + 1.7941291955947876, + 1.7897003564453124, + 1.7906446847534179, + 1.786394164466858, + 1.7862309146881103, + 1.780618618774414, + 1.7832628304672242, + 1.7785604644012452, + 1.7788331238174437, + 1.7817785488510132, + 1.7777097594833373, + 1.7799998865509032, + 1.775255680809021, + 1.7771116664886475, + 1.7775196334075927, + 1.7734333945465088, + 1.774940083694458, + 1.7707776821517944, + 1.7724536269378661, + 1.7712063220214844, + 1.769450744857788, + 1.7644718740463257, + 1.7703942450332641, + 1.7682801098251342, + 1.7650428762817383, + 1.766692326889038, + 1.7640820398330688, + 1.7642735467147828, + 1.7657033821868897, + 1.7645520711517333, + 1.7599635735702515, + 1.7630397510528564, + 1.7610415447616576, + 1.7589880530548097, + 1.759702002029419, + 1.758202407836914, + 1.7584852575683594, + 1.758759799156189, + 1.7612190747833252, + 1.756975512008667, + 1.7578368276977538, + 1.752756251296997, + 1.7538298551177978, + 1.7545775293350219, + 1.7533795291137695, + 1.7553060709381103, + 1.7546167678833007, + 1.7510459091186523, + 1.7534167892074586, + 1.7509777270889282, + 1.7524606351089478, + 1.753954182395935, + 1.7540196291732788, + 1.7530623047637939, + 1.754902084274292, + 1.7526578115844726, + 1.7541147797012329, + 1.7511920746612548, + 1.7532148589324952, + 1.7491467144393922, + 1.7507730041885377, + 1.7531463625335693, + 1.7500363696670531 + ], + "train_acc": [ + 0.24014, + 0.26772, + 0.28264, + 0.29302, + 0.29948, + 0.30404, + 0.30992, + 0.31306, + 0.3199, + 0.31706, + 0.32588, + 0.32474, + 0.32728, + 0.32896, + 0.33068, + 0.332, + 0.33344, + 0.3358, + 0.33592, + 0.33646, + 0.33824, + 0.34346, + 0.34112, + 0.3446, + 0.3442, + 0.34498, + 0.34916, + 0.35224, + 0.3488, + 0.35114, + 0.35106, + 0.35238, + 0.35462, + 0.35262, + 0.35602, + 0.35642, + 0.3596, + 0.35778, + 0.35798, + 0.35972, + 0.36258, + 0.36078, + 0.36174, + 0.36244, + 0.36306, + 0.3641, + 0.3633, + 0.361, + 0.36384, + 0.36622, + 0.36642, + 0.3651, + 0.36634, + 0.36576, + 0.3673, + 0.36726, + 0.36696, + 0.36972, + 0.3705, + 0.36848, + 0.36996, + 0.36794, + 0.37034, + 0.37048, + 0.37012, + 0.367, + 0.37118, + 0.37072, + 0.37212, + 0.37374, + 0.37076, + 0.37256, + 0.37136, + 0.37392, + 0.37394, + 0.3732, + 0.37276, + 0.37222, + 0.37412, + 0.3723, + 0.37526, + 0.37592, + 0.37682, + 0.37498, + 0.37556, + 0.37602, + 0.3759, + 0.37684, + 0.37452, + 0.37624, + 0.3739, + 0.37434, + 0.37596, + 0.37428, + 0.37492, + 0.3748, + 0.37806, + 0.37718, + 0.37634, + 0.37502 + ], + "test_acc": [ + 0.2713, + 0.3034, + 0.3137, + 0.3223, + 0.3204, + 0.3317, + 0.3372, + 0.3338, + 0.3448, + 0.3422, + 0.3584, + 0.3723, + 0.3648, + 0.363, + 0.3547, + 0.3679, + 0.3611, + 0.3653, + 0.3704, + 0.3666, + 0.3724, + 0.3691, + 0.3744, + 0.3743, + 0.379, + 0.375, + 0.3703, + 0.3688, + 0.3839, + 0.3808, + 0.3808, + 0.3799, + 0.3901, + 0.3821, + 0.3861, + 0.388, + 0.3872, + 0.3823, + 0.3851, + 0.3872, + 0.3816, + 0.3883, + 0.3914, + 0.3866, + 0.3824, + 0.3861, + 0.3912, + 0.3831, + 0.3882, + 0.39, + 0.3935, + 0.3951, + 0.3927, + 0.3947, + 0.3931, + 0.3904, + 0.3849, + 0.3965, + 0.389, + 0.3957, + 0.3993, + 0.3929, + 0.3939, + 0.3988, + 0.3972, + 0.3919, + 0.3956, + 0.3937, + 0.3964, + 0.3968, + 0.3954, + 0.3934, + 0.3957, + 0.3954, + 0.3983, + 0.399, + 0.3964, + 0.3988, + 0.3968, + 0.3999, + 0.3998, + 0.3988, + 0.3981, + 0.3991, + 0.3994, + 0.4001, + 0.3957, + 0.3991, + 0.3996, + 0.3992, + 0.4004, + 0.3989, + 0.3985, + 0.4004, + 0.4008, + 0.3998, + 0.4001, + 0.3995, + 0.3999, + 0.3996 + ] + }, + "diagnostics": { + "bp_cosine": [ + -0.005567306652665138, + 0.07410226762294769, + -0.08782696723937988, + 0.9968396425247192 + ], + "perturbation_rho": [ + 0.018982190638780594, + 0.020260581746697426, + -0.010329723358154297, + -0.0005609926301985979 + ], + "nudging": { + "0.001": [ + 1.0007061064243317e-06, + -3.9103906601667404e-07, + 1.1315569281578064e-07, + -1.5826663002371788e-06 + ], + "0.003": [ + 3.3027026802301407e-06, + -1.0563526302576065e-06, + 3.9511360228061676e-07, + -5.659414455294609e-06 + ], + "0.01": [ + 1.0776100680232048e-05, + -3.261142410337925e-06, + 1.505366526544094e-06, + -1.979817170649767e-05 + ] + }, + "hidden_norms_per_layer": [ + 7183.7138671875, + 54252.28125, + 433092.46875, + 853070.375, + 358308.375 + ], + "bp_grad_norms_per_layer": [ + 2.274825965287164e-05, + 2.2719700609741267e-06, + 1.0948514272968168e-06, + 1.0902091389652924e-06, + 1.0879763294724398e-06 + ] + }, + "drift": { + "embed.weight": 59.198994919141114, + "embed.bias": 39.04210511940659, + "blocks.0.ln.weight": 1.5643459558486938, + "blocks.0.w1.weight": 17.61955302910138, + "blocks.0.w1.bias": 14.864856243639734, + "blocks.0.w2.weight": 72.0693062056265, + "blocks.1.ln.weight": 1.3596134185791016, + "blocks.1.w1.weight": 23.46545214362022, + "blocks.1.w1.bias": 20.737719695924945, + "blocks.1.w2.weight": 57.54218956802357, + "blocks.2.ln.weight": 1.0050276517868042, + "blocks.2.w1.weight": 26.292893266385775, + "blocks.2.w1.bias": 30.330843617032436, + "blocks.2.w2.weight": 38.29821365284518, + "blocks.3.ln.weight": 1.097004771232605, + "blocks.3.w1.weight": 25.15473199291234, + "blocks.3.w1.bias": 27.34892428441179, + "blocks.3.w2.weight": 48.94715322109553, + "out_ln.weight": 0.27024972438812256, + "out_head.weight": 4.837433858659456, + "out_head.bias": 1.4275285869872485 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42, + 123, + 456 + ], + "gpu": 0, + "output_dir": "results/fa_main_audit", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_no_penalty_30ep/results_cifar10.json b/results/fa_no_penalty_30ep/results_cifar10.json new file mode 100644 index 0000000..b555be4 --- /dev/null +++ b/results/fa_no_penalty_30ep/results_cifar10.json @@ -0,0 +1,549 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.049124941177368, + 1.971962617225647, + 1.948333747253418, + 1.93432448513031, + 1.9224566204071045, + 1.911075694732666, + 1.9016322410583497, + 1.891839779663086, + 1.888196748046875, + 1.8842478870391846, + 1.8793783734512328, + 1.873243000831604, + 1.8691495156478881, + 1.8672306230163573, + 1.8655509171295166, + 1.8641471955108642, + 1.8627385038757325, + 1.8566293866348267, + 1.8572382849121094, + 1.854542728805542, + 1.853432930831909, + 1.853121408996582, + 1.8526418188476563, + 1.850091189842224, + 1.8493881246566772, + 1.8482141668319703, + 1.8446919596099853, + 1.8451624103164672, + 1.8452935827255248, + 1.842558771057129 + ], + "train_acc": [ + 0.2436, + 0.27344, + 0.28986, + 0.29662, + 0.30162, + 0.30408, + 0.31004, + 0.31504, + 0.31764, + 0.31756, + 0.32138, + 0.32436, + 0.32408, + 0.32572, + 0.33018, + 0.32902, + 0.3305, + 0.33264, + 0.33294, + 0.3324, + 0.3352, + 0.33392, + 0.33506, + 0.33476, + 0.33738, + 0.33864, + 0.3368, + 0.3421, + 0.34022, + 0.34144 + ], + "test_acc": [ + 0.2789, + 0.3106, + 0.3051, + 0.3149, + 0.3232, + 0.3345, + 0.3394, + 0.3473, + 0.3416, + 0.3427, + 0.3494, + 0.3508, + 0.3543, + 0.3612, + 0.3602, + 0.3571, + 0.3605, + 0.3616, + 0.3652, + 0.363, + 0.3649, + 0.3599, + 0.3637, + 0.3606, + 0.3631, + 0.3672, + 0.3653, + 0.3664, + 0.3659, + 0.3655 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.039796918630599976, + -0.0086748618632555, + -0.05908923223614693, + 0.9973034262657166 + ], + "perturbation_rho": [ + 0.014420752413570881, + 0.010810372419655323, + -0.0031975782476365566, + 0.018547791987657547 + ], + "nudging": { + "0.001": [ + -1.748558133840561e-06, + 4.516914486885071e-08, + 1.9744038581848145e-07, + -3.418419510126114e-06 + ], + "0.003": [ + -5.089212208986282e-06, + -9.313225746154785e-09, + 6.165355443954468e-07, + -1.0619405657052994e-05 + ], + "0.01": [ + -1.7355196177959442e-05, + -1.9045546650886536e-07, + 2.042856067419052e-06, + -3.545219078660011e-05 + ] + }, + "hidden_norms_per_layer": [ + 3135.737060546875, + 27729.080078125, + 142235.453125, + 221539.90625, + 98576.6796875 + ], + "bp_grad_norms_per_layer": [ + 2.2086309400037862e-05, + 2.296714228577912e-06, + 1.807363787520444e-06, + 1.8235258494314621e-06, + 1.7601513491172227e-06 + ] + }, + "drift": { + "embed.weight": 27.4249662858567, + "embed.bias": 13.120389146639068, + "blocks.0.ln.weight": 1.030566692352295, + "blocks.0.w1.weight": 12.544252287230748, + "blocks.0.w1.bias": 11.915807637034495, + "blocks.0.w2.weight": 40.35155370112267, + "blocks.1.ln.weight": 0.8229038715362549, + "blocks.1.w1.weight": 13.580795814003102, + "blocks.1.w1.bias": 13.535923105255597, + "blocks.1.w2.weight": 26.893514153884407, + "blocks.2.ln.weight": 0.7509615421295166, + "blocks.2.w1.weight": 12.776503039551056, + "blocks.2.w1.bias": 13.791977535501145, + "blocks.2.w2.weight": 24.833686221953496, + "blocks.3.ln.weight": 0.7571015954017639, + "blocks.3.w1.weight": 12.83601203126826, + "blocks.3.w1.bias": 14.312451854403847, + "blocks.3.w2.weight": 27.882819964995566, + "out_ln.weight": 0.14845088124275208, + "out_head.weight": 2.4655734611964566, + "out_head.bias": 1.0148059705009767 + } + } + }, + "123": { + "fa": { + "log": { + "train_loss": [ + 2.039442294845581, + 1.9535883404159546, + 1.926400708694458, + 1.911022697982788, + 1.9043336145401002, + 1.8967472591400147, + 1.8879544982147216, + 1.880874624710083, + 1.8684730854034424, + 1.8600735509490967, + 1.856403991165161, + 1.8502347243499755, + 1.8449485485458375, + 1.8409737787246705, + 1.8416496984100341, + 1.834242384109497, + 1.8294944548797607, + 1.8298376247406005, + 1.826137925453186, + 1.8233746087646485, + 1.82085874168396, + 1.8194894942855835, + 1.8188257317733765, + 1.8151497143936157, + 1.813930884475708, + 1.8120647540283203, + 1.813437055130005, + 1.8121181848526, + 1.8119350430297851, + 1.8114447354888916 + ], + "train_acc": [ + 0.25042, + 0.2898, + 0.30126, + 0.3061, + 0.3093, + 0.31594, + 0.31488, + 0.31876, + 0.3224, + 0.32746, + 0.3304, + 0.33324, + 0.33658, + 0.34018, + 0.34052, + 0.34376, + 0.34472, + 0.34468, + 0.34946, + 0.34906, + 0.34908, + 0.35074, + 0.35076, + 0.35142, + 0.35312, + 0.35506, + 0.3535, + 0.35112, + 0.35142, + 0.35268 + ], + "test_acc": [ + 0.2905, + 0.3334, + 0.3427, + 0.3456, + 0.347, + 0.3554, + 0.355, + 0.344, + 0.3537, + 0.3514, + 0.3589, + 0.3694, + 0.361, + 0.3679, + 0.3707, + 0.3645, + 0.3766, + 0.3737, + 0.3685, + 0.3753, + 0.3749, + 0.3762, + 0.3779, + 0.3776, + 0.3771, + 0.3778, + 0.3785, + 0.3791, + 0.3788, + 0.3793 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.04084893688559532, + 0.03916510194540024, + -0.05586155131459236, + 0.9959583282470703 + ], + "perturbation_rho": [ + 0.037160176783800125, + 0.010217813774943352, + 0.009163782931864262, + 0.09005501866340637 + ], + "nudging": { + "0.001": [ + -3.252178430557251e-06, + -4.302710294723511e-07, + 2.551823854446411e-07, + -5.400157533586025e-06 + ], + "0.003": [ + -9.553623385727406e-06, + -1.2621749192476273e-06, + 7.917406037449837e-07, + -1.6256701201200485e-05 + ], + "0.01": [ + -3.1867995858192444e-05, + -4.215282388031483e-06, + 2.7384376153349876e-06, + -5.423836410045624e-05 + ] + }, + "hidden_norms_per_layer": [ + 2199.5390625, + 23238.517578125, + 33747.62109375, + 60688.390625, + 42941.73828125 + ], + "bp_grad_norms_per_layer": [ + 3.650465077953413e-05, + 3.7466418234544108e-06, + 2.5758349693205673e-06, + 2.5369565719302045e-06, + 2.5125191314145923e-06 + ] + }, + "drift": { + "embed.weight": 22.579546770636362, + "embed.bias": 11.99376839559686, + "blocks.0.ln.weight": 0.9874183535575867, + "blocks.0.w1.weight": 11.111551574509331, + "blocks.0.w1.bias": 10.99245593612019, + "blocks.0.w2.weight": 37.55172206485303, + "blocks.1.ln.weight": 0.8036943674087524, + "blocks.1.w1.weight": 11.271966509968653, + "blocks.1.w1.bias": 9.571436867060473, + "blocks.1.w2.weight": 31.072497431661937, + "blocks.2.ln.weight": 0.7531070709228516, + "blocks.2.w1.weight": 12.21459668569711, + "blocks.2.w1.bias": 10.794010887239482, + "blocks.2.w2.weight": 27.176995851908416, + "blocks.3.ln.weight": 0.7162689566612244, + "blocks.3.w1.weight": 10.48127263461816, + "blocks.3.w1.bias": 11.865909016175292, + "blocks.3.w2.weight": 20.70822324996214, + "out_ln.weight": 0.11810445040464401, + "out_head.weight": 2.0129756552130713, + "out_head.bias": 2.429219803928322 + } + } + }, + "456": { + "fa": { + "log": { + "train_loss": [ + 2.064060781517029, + 1.9899379243469237, + 1.9583307926940918, + 1.9413945766830445, + 1.9259979123687745, + 1.917475108642578, + 1.9057074185180665, + 1.8973549193572998, + 1.8837018002319337, + 1.880822947654724, + 1.8696432660675049, + 1.8690105139541626, + 1.864671703262329, + 1.860279491958618, + 1.8527126348114014, + 1.852656520614624, + 1.8518596002197265, + 1.8463533575820923, + 1.8458245706939698, + 1.8417340802764892, + 1.8398325180053712, + 1.8388891097640991, + 1.8387157396697997, + 1.834438217163086, + 1.835012681427002, + 1.835493402671814, + 1.8353626998138428, + 1.8342380188751222, + 1.8330915405654906, + 1.8353773220443725 + ], + "train_acc": [ + 0.24014, + 0.26754, + 0.28264, + 0.29278, + 0.29936, + 0.30364, + 0.3098, + 0.3125, + 0.3179, + 0.31672, + 0.32526, + 0.32416, + 0.32716, + 0.32892, + 0.33012, + 0.3305, + 0.33536, + 0.33526, + 0.3355, + 0.33674, + 0.3366, + 0.33948, + 0.3374, + 0.34192, + 0.34176, + 0.33942, + 0.33996, + 0.34336, + 0.34066, + 0.34118 + ], + "test_acc": [ + 0.2713, + 0.3036, + 0.3138, + 0.3217, + 0.3206, + 0.3309, + 0.3341, + 0.3322, + 0.3431, + 0.3426, + 0.3568, + 0.3643, + 0.359, + 0.3584, + 0.3556, + 0.3673, + 0.3669, + 0.366, + 0.3678, + 0.3677, + 0.3664, + 0.3691, + 0.369, + 0.3699, + 0.3699, + 0.3695, + 0.3706, + 0.3701, + 0.3698, + 0.3699 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.005443892907351255, + 0.056010693311691284, + -0.02994656004011631, + 0.993930459022522 + ], + "perturbation_rho": [ + 0.041760556399822235, + 0.011946016922593117, + 0.03664176166057587, + 0.019783515483140945 + ], + "nudging": { + "0.001": [ + -4.184548743069172e-07, + -1.57160684466362e-07, + 1.1589145287871361e-07, + -2.3843604139983654e-06 + ], + "0.003": [ + -1.4954712241888046e-06, + -6.176414899528027e-07, + 2.6955967769026756e-07, + -7.573107723146677e-06 + ], + "0.01": [ + -4.984147381037474e-06, + -2.076325472444296e-06, + 7.987837307155132e-07, + -2.5291461497545242e-05 + ] + }, + "hidden_norms_per_layer": [ + 2585.784423828125, + 30518.099609375, + 125489.96875, + 234551.46875, + 135467.90625 + ], + "bp_grad_norms_per_layer": [ + 2.3132513888413087e-05, + 1.7084096270991722e-06, + 1.3069517308395007e-06, + 1.3104119034323958e-06, + 1.2661037089856109e-06 + ] + }, + "drift": { + "embed.weight": 25.036403923789077, + "embed.bias": 16.467870087010475, + "blocks.0.ln.weight": 1.1242742538452148, + "blocks.0.w1.weight": 13.107427774588935, + "blocks.0.w1.bias": 10.049346359618044, + "blocks.0.w2.weight": 42.48070715823368, + "blocks.1.ln.weight": 0.858128547668457, + "blocks.1.w1.weight": 13.379689964096187, + "blocks.1.w1.bias": 11.437785495003597, + "blocks.1.w2.weight": 32.57709335591351, + "blocks.2.ln.weight": 0.724311351776123, + "blocks.2.w1.weight": 13.542630387751261, + "blocks.2.w1.bias": 15.41909845322556, + "blocks.2.w2.weight": 27.566931719808178, + "blocks.3.ln.weight": 0.7379146218299866, + "blocks.3.w1.weight": 12.288692457217095, + "blocks.3.w1.bias": 12.33989285228832, + "blocks.3.w2.weight": 30.803534080173662, + "out_ln.weight": 0.14593681693077087, + "out_head.weight": 2.504909261880854, + "out_head.bias": 1.0678192714646333 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 30, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42, + 123, + 456 + ], + "gpu": 0, + "output_dir": "results/fa_no_penalty_30ep", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_penalty_30ep/results_cifar10.json b/results/fa_penalty_30ep/results_cifar10.json new file mode 100644 index 0000000..bd94af9 --- /dev/null +++ b/results/fa_penalty_30ep/results_cifar10.json @@ -0,0 +1,549 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.001119369430542, + 1.9333542645263673, + 1.9240801361083983, + 1.9216663201141357, + 1.918965964050293, + 1.9119908702850341, + 1.9097738315582276, + 1.9049057181167603, + 1.9038385245513916, + 1.901488224105835, + 1.8975307934951782, + 1.8936708712768555, + 1.8894188814544677, + 1.8864306787872314, + 1.8895101571655273, + 1.8835669344329835, + 1.8832266155624389, + 1.8791197660064698, + 1.8797098063659667, + 1.8761708308410645, + 1.875095842552185, + 1.8748220810317993, + 1.8743982402801513, + 1.8738871404266357, + 1.8735411545562743, + 1.8693548073577881, + 1.8659294427871704, + 1.8685178174591064, + 1.8670922566986083, + 1.8683236782455444 + ], + "train_acc": [ + 0.27484, + 0.30662, + 0.31156, + 0.3128, + 0.31684, + 0.3173, + 0.32102, + 0.32226, + 0.3248, + 0.32594, + 0.32582, + 0.33134, + 0.33202, + 0.33344, + 0.33274, + 0.33626, + 0.33556, + 0.33774, + 0.33918, + 0.33762, + 0.34082, + 0.34052, + 0.34092, + 0.34432, + 0.34294, + 0.3436, + 0.34578, + 0.34654, + 0.34636, + 0.3472 + ], + "test_acc": [ + 0.3237, + 0.3505, + 0.3372, + 0.3327, + 0.3511, + 0.3455, + 0.3449, + 0.344, + 0.3413, + 0.3462, + 0.35, + 0.3511, + 0.3604, + 0.3569, + 0.3565, + 0.3584, + 0.3634, + 0.3651, + 0.3684, + 0.362, + 0.3674, + 0.3586, + 0.3704, + 0.3688, + 0.3729, + 0.3715, + 0.3702, + 0.3713, + 0.3715, + 0.3713 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.154661625623703, + 0.148232564330101, + 0.11714619398117065, + 0.9808143377304077 + ], + "perturbation_rho": [ + 0.2043638527393341, + 0.03451352193951607, + 0.20099492371082306, + 0.6849091649055481 + ], + "nudging": { + "0.001": [ + -7.456343155354261e-06, + -7.0015667006373405e-06, + -6.1238533817231655e-06, + -4.935957258567214e-05 + ], + "0.003": [ + -2.2464897483587265e-05, + -2.095731906592846e-05, + -1.8515565898269415e-05, + -0.00014814879978075624 + ], + "0.01": [ + -7.48702441342175e-05, + -6.984942592680454e-05, + -6.168842082843184e-05, + -0.0004938761703670025 + ] + }, + "hidden_norms_per_layer": [ + 12005.7001953125, + 12007.416015625, + 12011.4248046875, + 12003.9306640625, + 12006.5830078125 + ], + "bp_grad_norms_per_layer": [ + 1.7161242794827558e-05, + 1.712461198621895e-05, + 1.7262802430195734e-05, + 1.6996016711345874e-05, + 1.602559132152237e-05 + ] + }, + "drift": { + "embed.weight": 78.51231332812849, + "embed.bias": 102.12148052673881, + "blocks.0.ln.weight": 0.2991768419742584, + "blocks.0.w1.weight": 3.288291856838975, + "blocks.0.w1.bias": 5.396373593768907, + "blocks.0.w2.weight": 18.07176198506393, + "blocks.1.ln.weight": 0.30484291911125183, + "blocks.1.w1.weight": 3.363888651672944, + "blocks.1.w1.bias": 5.6524749125654825, + "blocks.1.w2.weight": 19.430774882635927, + "blocks.2.ln.weight": 0.3102824091911316, + "blocks.2.w1.weight": 3.4139876185805202, + "blocks.2.w1.bias": 5.452178486528747, + "blocks.2.w2.weight": 20.125989565290535, + "blocks.3.ln.weight": 0.3048917055130005, + "blocks.3.w1.weight": 3.480676314799551, + "blocks.3.w1.bias": 5.215887111238424, + "blocks.3.w2.weight": 18.601435796851646, + "out_ln.weight": 0.20746967196464539, + "out_head.weight": 2.830159264066497, + "out_head.bias": 1.3353853268796754 + } + } + }, + "123": { + "fa": { + "log": { + "train_loss": [ + 1.9927266687774658, + 1.9300689645767213, + 1.92596540309906, + 1.9193654559326172, + 1.9174729382324218, + 1.914334167137146, + 1.9108343372344971, + 1.9098404777526856, + 1.9070983071517944, + 1.9046820580291748, + 1.9035546282958984, + 1.902319673461914, + 1.9007709796524048, + 1.8978776998901368, + 1.896000676651001, + 1.8922497060394288, + 1.889219735031128, + 1.8912379706573486, + 1.8899934759902954, + 1.886697685585022, + 1.8858332043457031, + 1.8851475219726563, + 1.8850083080673217, + 1.8836419734191894, + 1.8818570189666748, + 1.880578341293335, + 1.8828362002563477, + 1.8791149730682373, + 1.881042135925293, + 1.8783651885986328 + ], + "train_acc": [ + 0.28066, + 0.30852, + 0.31142, + 0.3131, + 0.31726, + 0.31894, + 0.31884, + 0.3222, + 0.32346, + 0.325, + 0.32962, + 0.32534, + 0.32718, + 0.32974, + 0.33498, + 0.33296, + 0.33498, + 0.33642, + 0.33698, + 0.33676, + 0.3368, + 0.33892, + 0.33858, + 0.34048, + 0.34248, + 0.34264, + 0.34458, + 0.34442, + 0.34196, + 0.3428 + ], + "test_acc": [ + 0.3339, + 0.344, + 0.3447, + 0.3537, + 0.3523, + 0.3448, + 0.3569, + 0.3548, + 0.3564, + 0.3513, + 0.3558, + 0.3555, + 0.3578, + 0.3565, + 0.3511, + 0.3617, + 0.3621, + 0.3603, + 0.3619, + 0.36, + 0.3663, + 0.3679, + 0.3665, + 0.3645, + 0.363, + 0.3632, + 0.3656, + 0.3666, + 0.3666, + 0.366 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.10381253063678741, + 0.11509518325328827, + 0.17561045289039612, + 0.9849518537521362 + ], + "perturbation_rho": [ + -0.013579179532825947, + 0.098294198513031, + 0.1665695607662201, + 0.7168847322463989 + ], + "nudging": { + "0.001": [ + -3.704102709889412e-06, + -5.607143975794315e-06, + -8.841510862112045e-06, + -5.108967889100313e-05 + ], + "0.003": [ + -1.0923598892986774e-05, + -1.6769510693848133e-05, + -2.6412191800773144e-05, + -0.00015316426288336515 + ], + "0.01": [ + -3.6393641494214535e-05, + -5.586724728345871e-05, + -8.803850505501032e-05, + -0.0005103998119011521 + ] + }, + "hidden_norms_per_layer": [ + 12513.1455078125, + 12514.5771484375, + 12517.3583984375, + 12520.568359375, + 12517.57421875 + ], + "bp_grad_norms_per_layer": [ + 1.9153141693095677e-05, + 1.9135366528644226e-05, + 1.8911112420028076e-05, + 1.892356522148475e-05, + 1.7506923541077413e-05 + ] + }, + "drift": { + "embed.weight": 89.53228672658916, + "embed.bias": 152.8584432062178, + "blocks.0.ln.weight": 0.2654326856136322, + "blocks.0.w1.weight": 3.3634291127986065, + "blocks.0.w1.bias": 5.280393391584056, + "blocks.0.w2.weight": 19.28885452796699, + "blocks.1.ln.weight": 0.2610284984111786, + "blocks.1.w1.weight": 3.3605846097214807, + "blocks.1.w1.bias": 4.889820303164699, + "blocks.1.w2.weight": 20.281310673104024, + "blocks.2.ln.weight": 0.2743469476699829, + "blocks.2.w1.weight": 3.391098035259505, + "blocks.2.w1.bias": 4.733926864945133, + "blocks.2.w2.weight": 20.84238400209855, + "blocks.3.ln.weight": 0.28759220242500305, + "blocks.3.w1.weight": 3.4831818014183473, + "blocks.3.w1.bias": 4.539792496168549, + "blocks.3.w2.weight": 17.922103982945213, + "out_ln.weight": 0.23056229948997498, + "out_head.weight": 2.947940047368198, + "out_head.bias": 0.999532168893943 + } + } + }, + "456": { + "fa": { + "log": { + "train_loss": [ + 2.007471420669556, + 1.9475741510772706, + 1.935076278114319, + 1.9286792455673218, + 1.9233443587493897, + 1.921987562637329, + 1.9133965182113648, + 1.9098651378631593, + 1.9032709611511232, + 1.9005253164291382, + 1.8951593602752685, + 1.8959063681793213, + 1.8941862697982788, + 1.8903014281845092, + 1.887538351173401, + 1.8868630823516845, + 1.8855092962646485, + 1.882849036026001, + 1.88272844581604, + 1.880844307899475, + 1.8779692416000366, + 1.8777563387298584, + 1.8754584911346435, + 1.8720009698104858, + 1.8718394606018067, + 1.871922555580139, + 1.8742619204330444, + 1.871095755996704, + 1.8686067379379272, + 1.8697797116470336 + ], + "train_acc": [ + 0.27346, + 0.30166, + 0.30968, + 0.31384, + 0.3171, + 0.31864, + 0.32234, + 0.32546, + 0.32768, + 0.32608, + 0.33286, + 0.33212, + 0.33646, + 0.33228, + 0.33466, + 0.33496, + 0.33834, + 0.34088, + 0.33898, + 0.34018, + 0.3417, + 0.34182, + 0.3446, + 0.34544, + 0.34776, + 0.34688, + 0.3461, + 0.34764, + 0.3487, + 0.34986 + ], + "test_acc": [ + 0.3312, + 0.3461, + 0.3375, + 0.3554, + 0.3417, + 0.349, + 0.3492, + 0.3315, + 0.3524, + 0.3561, + 0.3452, + 0.3612, + 0.3675, + 0.361, + 0.3588, + 0.3671, + 0.3652, + 0.3621, + 0.3608, + 0.3682, + 0.3607, + 0.3563, + 0.3698, + 0.3701, + 0.3677, + 0.373, + 0.3704, + 0.3674, + 0.37, + 0.3695 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.1374516785144806, + 0.16781684756278992, + 0.13626089692115784, + 0.9774131774902344 + ], + "perturbation_rho": [ + 0.0486995093524456, + 0.06349372863769531, + 0.007340744137763977, + 0.6342363357543945 + ], + "nudging": { + "0.001": [ + -5.1066745072603226e-06, + -7.127760909497738e-06, + -6.226240657269955e-06, + -4.646868910640478e-05 + ], + "0.003": [ + -1.5374505892395973e-05, + -2.130062784999609e-05, + -1.874461304396391e-05, + -0.0001394655555486679 + ], + "0.01": [ + -5.139666609466076e-05, + -7.112661842256784e-05, + -6.259605288505554e-05, + -0.0004648104077205062 + ] + }, + "hidden_norms_per_layer": [ + 12142.6181640625, + 12149.4814453125, + 12152.4228515625, + 12154.3623046875, + 12153.6318359375 + ], + "bp_grad_norms_per_layer": [ + 1.7370659406878985e-05, + 1.7375436073052697e-05, + 1.7465752534917556e-05, + 1.755404082359746e-05, + 1.6847530787345022e-05 + ] + }, + "drift": { + "embed.weight": 81.50708429789378, + "embed.bias": 90.80998898870095, + "blocks.0.ln.weight": 0.30019262433052063, + "blocks.0.w1.weight": 3.2393405099782573, + "blocks.0.w1.bias": 5.11706466812565, + "blocks.0.w2.weight": 17.26999421209171, + "blocks.1.ln.weight": 0.29391026496887207, + "blocks.1.w1.weight": 3.2524768916503883, + "blocks.1.w1.bias": 5.375142102957966, + "blocks.1.w2.weight": 17.822565033735142, + "blocks.2.ln.weight": 0.29711613059043884, + "blocks.2.w1.weight": 3.347203060867532, + "blocks.2.w1.bias": 5.458224100319586, + "blocks.2.w2.weight": 19.038612136675116, + "blocks.3.ln.weight": 0.3008574843406677, + "blocks.3.w1.weight": 3.407494069111075, + "blocks.3.w1.bias": 5.482950458082574, + "blocks.3.w2.weight": 17.98442064571867, + "out_ln.weight": 0.21060113608837128, + "out_head.weight": 2.7473203281973158, + "out_head.bias": 0.9871044007414218 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 30, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42, + 123, + 456 + ], + "gpu": 0, + "output_dir": "results/fa_penalty_30ep", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.01, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_penalty_lam1e-4_30ep/results_cifar10.json b/results/fa_penalty_lam1e-4_30ep/results_cifar10.json new file mode 100644 index 0000000..d0738b9 --- /dev/null +++ b/results/fa_penalty_lam1e-4_30ep/results_cifar10.json @@ -0,0 +1,549 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.0345889838027955, + 1.9560137537384032, + 1.9360403113555908, + 1.9143632999420166, + 1.8943112557601929, + 1.883546539993286, + 1.87789137008667, + 1.8733353225708007, + 1.8758181957626343, + 1.8700244534301758, + 1.8686021829605102, + 1.8640803202056884, + 1.8597132386779784, + 1.8604257183837891, + 1.860362275390625, + 1.8570359252166748, + 1.8561881994247436, + 1.8504637549209595, + 1.8508178236389161, + 1.8495385042572021, + 1.8523688412475585, + 1.8491829050445556, + 1.849259608154297, + 1.8475904892730712, + 1.846731012802124, + 1.8426731698989869, + 1.8409621490097046, + 1.8425804473876952, + 1.8410273541259765, + 1.8429513037109375 + ], + "train_acc": [ + 0.25574, + 0.29268, + 0.30346, + 0.31094, + 0.31784, + 0.3242, + 0.32714, + 0.3321, + 0.33104, + 0.3346, + 0.33454, + 0.33844, + 0.33788, + 0.33896, + 0.3398, + 0.34038, + 0.3428, + 0.34448, + 0.34504, + 0.34356, + 0.34638, + 0.34678, + 0.3455, + 0.34862, + 0.34636, + 0.35102, + 0.34858, + 0.35316, + 0.3505, + 0.3521 + ], + "test_acc": [ + 0.2909, + 0.3244, + 0.3269, + 0.3335, + 0.3455, + 0.3577, + 0.3581, + 0.3473, + 0.359, + 0.3635, + 0.3513, + 0.3583, + 0.3735, + 0.3642, + 0.3646, + 0.364, + 0.3653, + 0.3734, + 0.3717, + 0.3682, + 0.3792, + 0.3722, + 0.3728, + 0.3747, + 0.3749, + 0.3751, + 0.3748, + 0.375, + 0.3766, + 0.3759 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.03903631120920181, + 0.014399020932614803, + -0.033061157912015915, + 0.8762983083724976 + ], + "perturbation_rho": [ + 0.04046167433261871, + 0.029874827712774277, + 0.018399983644485474, + 0.5844764113426208 + ], + "nudging": { + "0.001": [ + -3.4736585803329945e-06, + -1.6225967556238174e-06, + 8.612405508756638e-07, + -4.002035711891949e-05 + ], + "0.003": [ + -1.0443473001942039e-05, + -4.940724465996027e-06, + 2.589426003396511e-06, + -0.00012009820784442127 + ], + "0.01": [ + -3.477800055406988e-05, + -1.6514735762029886e-05, + 8.722592610865831e-06, + -0.00040041320607997477 + ] + }, + "hidden_norms_per_layer": [ + 9302.052734375, + 9300.826171875, + 9304.12890625, + 9388.5966796875, + 9324.71484375 + ], + "bp_grad_norms_per_layer": [ + 1.8842416466213763e-05, + 1.7771730199456215e-05, + 1.7083899365388788e-05, + 1.6848869563546032e-05, + 1.1624123544606846e-05 + ] + }, + "drift": { + "embed.weight": 59.712949390275256, + "embed.bias": 109.7242290933353, + "blocks.0.ln.weight": 0.4963775873184204, + "blocks.0.w1.weight": 6.181720741000805, + "blocks.0.w1.bias": 4.020397499075911, + "blocks.0.w2.weight": 27.582674728139626, + "blocks.1.ln.weight": 0.5040290951728821, + "blocks.1.w1.weight": 6.5836738314372525, + "blocks.1.w1.bias": 3.6793793299266104, + "blocks.1.w2.weight": 30.666596548839138, + "blocks.2.ln.weight": 0.5469347834587097, + "blocks.2.w1.weight": 6.732229120538098, + "blocks.2.w1.bias": 3.7361472645162905, + "blocks.2.w2.weight": 32.57792259935943, + "blocks.3.ln.weight": 0.5613139867782593, + "blocks.3.w1.weight": 6.581817915516039, + "blocks.3.w1.bias": 3.910112736502136, + "blocks.3.w2.weight": 30.460380132293686, + "out_ln.weight": 0.13369733095169067, + "out_head.weight": 2.149734268892375, + "out_head.bias": 1.8324297050244958 + } + } + }, + "123": { + "fa": { + "log": { + "train_loss": [ + 2.02592181602478, + 1.9371592224121095, + 1.8974695124053955, + 1.8743731897735596, + 1.86856899684906, + 1.866699522743225, + 1.8631246627807616, + 1.8699150874328614, + 1.8676455249404906, + 1.8666385611724854, + 1.8668112030792237, + 1.8644982721710206, + 1.8628531475448609, + 1.8602032104492188, + 1.859476708946228, + 1.8544428533554078, + 1.8534636458969116, + 1.8551011206054688, + 1.854274546432495, + 1.850538496170044, + 1.8506593602752686, + 1.849616947631836, + 1.8474625555419921, + 1.8434547743988037, + 1.8423766018676757, + 1.8429830319976808, + 1.8445982418823241, + 1.842038772354126, + 1.8427192529296874, + 1.8390978100585937 + ], + "train_acc": [ + 0.26002, + 0.30074, + 0.31632, + 0.32608, + 0.33058, + 0.3327, + 0.33324, + 0.3331, + 0.33306, + 0.3358, + 0.33708, + 0.33474, + 0.33836, + 0.33936, + 0.34254, + 0.34292, + 0.34406, + 0.34572, + 0.34532, + 0.34724, + 0.34686, + 0.35018, + 0.34834, + 0.35062, + 0.35238, + 0.35146, + 0.35322, + 0.3507, + 0.35174, + 0.35298 + ], + "test_acc": [ + 0.3099, + 0.3371, + 0.35, + 0.3548, + 0.3568, + 0.3543, + 0.3595, + 0.3565, + 0.3605, + 0.3532, + 0.3572, + 0.3574, + 0.3587, + 0.3635, + 0.3598, + 0.3716, + 0.3664, + 0.3641, + 0.3588, + 0.3685, + 0.3704, + 0.3738, + 0.3678, + 0.3716, + 0.3726, + 0.3706, + 0.3715, + 0.3711, + 0.3718, + 0.3725 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.12851236760616302, + 0.0212344229221344, + 0.06850548088550568, + 0.912018895149231 + ], + "perturbation_rho": [ + 0.1286340206861496, + 0.05542958900332451, + 0.095004141330719, + 0.5912511944770813 + ], + "nudging": { + "0.001": [ + -1.2503878679126501e-05, + -2.949964255094528e-06, + -4.535817424766719e-06, + -4.0970538975670934e-05 + ], + "0.003": [ + -3.755756188184023e-05, + -8.84209293872118e-06, + -1.3517230399884284e-05, + -0.00012284089461900294 + ], + "0.01": [ + -0.00012522244651336223, + -2.9436778277158737e-05, + -4.5000226236879826e-05, + -0.0004094060859642923 + ] + }, + "hidden_norms_per_layer": [ + 8440.2509765625, + 8675.947265625, + 8919.455078125, + 9151.62890625, + 9001.966796875 + ], + "bp_grad_norms_per_layer": [ + 2.2785976398154162e-05, + 1.9475846784189343e-05, + 1.76876437762985e-05, + 1.6414065612480044e-05, + 1.2445364518498536e-05 + ] + }, + "drift": { + "embed.weight": 58.570146595416965, + "embed.bias": 129.66757192630527, + "blocks.0.ln.weight": 0.506428599357605, + "blocks.0.w1.weight": 6.100287737701914, + "blocks.0.w1.bias": 4.352361447004298, + "blocks.0.w2.weight": 29.053039362183725, + "blocks.1.ln.weight": 0.48886165022850037, + "blocks.1.w1.weight": 6.289397911525807, + "blocks.1.w1.bias": 4.491060429115057, + "blocks.1.w2.weight": 31.966204529299326, + "blocks.2.ln.weight": 0.47751307487487793, + "blocks.2.w1.weight": 6.165341672963252, + "blocks.2.w1.bias": 4.204746402092476, + "blocks.2.w2.weight": 31.785153220803842, + "blocks.3.ln.weight": 0.522526741027832, + "blocks.3.w1.weight": 6.2976831221121925, + "blocks.3.w1.bias": 3.31165977931123, + "blocks.3.w2.weight": 29.399820039375125, + "out_ln.weight": 0.12725675106048584, + "out_head.weight": 2.171575181004019, + "out_head.bias": 1.7024717076770008 + } + } + }, + "456": { + "fa": { + "log": { + "train_loss": [ + 2.0370823514556884, + 1.9467987035751342, + 1.912523896484375, + 1.8945044506072999, + 1.881802406349182, + 1.8781717306137085, + 1.8688158060073852, + 1.8661776677703859, + 1.8603027975845337, + 1.8572819149017334, + 1.8532151040267943, + 1.8534327938461304, + 1.8522534008407592, + 1.8507964714050293, + 1.8465819549179077, + 1.8432143920135498, + 1.8410576669311522, + 1.8371334344863892, + 1.8372844079208375, + 1.8335014197540282, + 1.8299405459213256, + 1.8304111400985719, + 1.8290630575561524, + 1.8247105368041991, + 1.8245959228515625, + 1.825667812576294, + 1.8260319690322877, + 1.8243235848999024, + 1.8212057236480712, + 1.8206781775283813 + ], + "train_acc": [ + 0.25624, + 0.29612, + 0.3112, + 0.31836, + 0.32406, + 0.32822, + 0.33338, + 0.3358, + 0.3376, + 0.33728, + 0.34174, + 0.34368, + 0.34628, + 0.34296, + 0.34426, + 0.34552, + 0.34848, + 0.35006, + 0.34902, + 0.34998, + 0.35384, + 0.35252, + 0.35216, + 0.35542, + 0.35508, + 0.35538, + 0.35446, + 0.35524, + 0.3556, + 0.35744 + ], + "test_acc": [ + 0.2999, + 0.3434, + 0.3434, + 0.364, + 0.3546, + 0.3612, + 0.3581, + 0.3492, + 0.367, + 0.3567, + 0.3655, + 0.3723, + 0.3761, + 0.3771, + 0.3778, + 0.3763, + 0.3825, + 0.3744, + 0.3812, + 0.3831, + 0.3751, + 0.3821, + 0.3809, + 0.3833, + 0.3812, + 0.3832, + 0.3835, + 0.3836, + 0.3841, + 0.3837 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.11049995571374893, + 0.014905610121786594, + -0.0519559383392334, + 0.861151933670044 + ], + "perturbation_rho": [ + 0.058121487498283386, + 0.011935144662857056, + -0.056115295737981796, + 0.5436498522758484 + ], + "nudging": { + "0.001": [ + -7.875027222326025e-06, + -1.205597072839737e-06, + 2.052285708487034e-06, + -3.838281554635614e-05 + ], + "0.003": [ + -2.359108839300461e-05, + -3.5760422179009765e-06, + 6.082700565457344e-06, + -0.00011519025429151952 + ], + "0.01": [ + -7.864644430810586e-05, + -1.1975058441748843e-05, + 2.0268842490622774e-05, + -0.000383934035198763 + ] + }, + "hidden_norms_per_layer": [ + 8074.6318359375, + 8344.6298828125, + 8543.16796875, + 8806.9365234375, + 8809.208984375 + ], + "bp_grad_norms_per_layer": [ + 2.0788113033631817e-05, + 1.6527425032109022e-05, + 1.6617635992588475e-05, + 1.6509349734405987e-05, + 1.1822515261883382e-05 + ] + }, + "drift": { + "embed.weight": 55.0795347886669, + "embed.bias": 104.1791057670674, + "blocks.0.ln.weight": 0.5131628513336182, + "blocks.0.w1.weight": 6.423288157104268, + "blocks.0.w1.bias": 5.260214874942604, + "blocks.0.w2.weight": 28.84901365790228, + "blocks.1.ln.weight": 0.5011720657348633, + "blocks.1.w1.weight": 6.239148515891604, + "blocks.1.w1.bias": 3.694106675391347, + "blocks.1.w2.weight": 28.607867363928534, + "blocks.2.ln.weight": 0.46569541096687317, + "blocks.2.w1.weight": 6.112045116014977, + "blocks.2.w1.bias": 4.730623150261222, + "blocks.2.w2.weight": 28.99578369272475, + "blocks.3.ln.weight": 0.5072412490844727, + "blocks.3.w1.weight": 6.376723566598171, + "blocks.3.w1.bias": 4.743548803408704, + "blocks.3.w2.weight": 30.777217385288502, + "out_ln.weight": 0.1257992684841156, + "out_head.weight": 2.0103689615464178, + "out_head.bias": 1.8804179129019218 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 30, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42, + 123, + 456 + ], + "gpu": 0, + "output_dir": "results/fa_penalty_lam1e-4_30ep", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0001, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_random_targets_s42/results_cifar10.json b/results/fa_random_targets_s42/results_cifar10.json new file mode 100644 index 0000000..33fa066 --- /dev/null +++ b/results/fa_random_targets_s42/results_cifar10.json @@ -0,0 +1,411 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.3210656842803954, + 2.3150660639953613, + 2.31394060256958, + 2.3137241358184815, + 2.313226750564575, + 2.312898317718506, + 2.3115276292419433, + 2.3112998097229003, + 2.312373881149292, + 2.3112880519104, + 2.3110179233551027, + 2.3094827233886717, + 2.3097055834197997, + 2.309138665313721, + 2.308755371398926, + 2.309352218093872, + 2.3081418705749512, + 2.3084949157714845, + 2.307779687347412, + 2.3071102055358885, + 2.3067180699920655, + 2.306768290481567, + 2.306924085235596, + 2.3063941722869874, + 2.306285163803101, + 2.3056750212860107, + 2.3055725598907473, + 2.3055327081298826, + 2.3053424534606934, + 2.3050209490966798, + 2.3051358866882325, + 2.304873356246948, + 2.3045516327667235, + 2.3042913507080076, + 2.304434557800293, + 2.3042074402618407, + 2.3039160931396485, + 2.304433411102295, + 2.3036191528320313, + 2.3038957733154297, + 2.303670624084473, + 2.303965150299072, + 2.303804924468994, + 2.3035577531433105, + 2.303584457550049, + 2.303571176300049, + 2.30339740234375, + 2.303557454452515, + 2.30346072013855, + 2.3033780017852785, + 2.303559310379028, + 2.3033258779144288, + 2.3033027320098878, + 2.303322297897339, + 2.30284891708374, + 2.3031199270629883, + 2.302974216156006, + 2.3031179604339598, + 2.303074387893677, + 2.3032363694763185, + 2.3030751721191405, + 2.3031637412261965, + 2.3029984189605712, + 2.302961895904541, + 2.3029372912597657, + 2.30282003036499, + 2.3028841104888915, + 2.302927931060791, + 2.30283846572876, + 2.3028817950439455, + 2.302871095428467, + 2.302848546066284, + 2.3027211415863036, + 2.302883378829956, + 2.3028133658599854, + 2.302789750442505, + 2.30266556968689, + 2.3026871818542483, + 2.3026958183288575, + 2.3027067826080323, + 2.302647667160034, + 2.3026781449890135, + 2.3026719428253175, + 2.302752258377075, + 2.3025668229675293, + 2.3027340324401857, + 2.3026700817871095, + 2.3026870487976074, + 2.3026517832946776, + 2.302639422836304, + 2.3026284925842284, + 2.302678219604492, + 2.3025743629455566, + 2.302645079803467, + 2.3026254082489013, + 2.3025657051086426, + 2.302568560562134, + 2.302596629562378, + 2.3026467196655274, + 2.30262910446167 + ], + "train_acc": [ + 0.09916, + 0.09876, + 0.10146, + 0.09992, + 0.10196, + 0.09982, + 0.09806, + 0.10266, + 0.10208, + 0.10162, + 0.10026, + 0.09852, + 0.10096, + 0.09824, + 0.09968, + 0.09992, + 0.10076, + 0.09998, + 0.09878, + 0.10038, + 0.1028, + 0.1005, + 0.09974, + 0.10234, + 0.10008, + 0.10142, + 0.10072, + 0.09942, + 0.09986, + 0.09736, + 0.09932, + 0.09888, + 0.10114, + 0.10054, + 0.09992, + 0.09942, + 0.09802, + 0.09746, + 0.09936, + 0.10052, + 0.09982, + 0.0983, + 0.09972, + 0.0996, + 0.09914, + 0.09954, + 0.1011, + 0.09934, + 0.09846, + 0.10064, + 0.09888, + 0.10058, + 0.10134, + 0.10024, + 0.10182, + 0.1007, + 0.10096, + 0.10038, + 0.1018, + 0.09818, + 0.10196, + 0.10002, + 0.10066, + 0.1021, + 0.09886, + 0.10106, + 0.09934, + 0.0975, + 0.10084, + 0.09938, + 0.1002, + 0.09842, + 0.1023, + 0.09958, + 0.10186, + 0.1014, + 0.10052, + 0.09892, + 0.09954, + 0.10054, + 0.09992, + 0.09962, + 0.10002, + 0.10092, + 0.10188, + 0.099, + 0.10124, + 0.09706, + 0.10042, + 0.10008, + 0.09796, + 0.09966, + 0.10016, + 0.09854, + 0.0991, + 0.10094, + 0.10254, + 0.09996, + 0.09714, + 0.09882 + ], + "test_acc": [ + 0.0948, + 0.0811, + 0.1058, + 0.0925, + 0.1061, + 0.0856, + 0.1171, + 0.1042, + 0.0998, + 0.0894, + 0.0985, + 0.0732, + 0.0858, + 0.0861, + 0.1274, + 0.0628, + 0.1051, + 0.104, + 0.1127, + 0.1045, + 0.1051, + 0.1183, + 0.0754, + 0.0837, + 0.0762, + 0.1063, + 0.1, + 0.0952, + 0.0933, + 0.1226, + 0.1009, + 0.1043, + 0.117, + 0.0946, + 0.0931, + 0.1133, + 0.0955, + 0.0995, + 0.0942, + 0.0837, + 0.109, + 0.1136, + 0.0937, + 0.0749, + 0.1135, + 0.1174, + 0.086, + 0.0883, + 0.108, + 0.0945, + 0.1034, + 0.0948, + 0.0933, + 0.1122, + 0.0924, + 0.0801, + 0.0856, + 0.0858, + 0.1102, + 0.1195, + 0.1005, + 0.0939, + 0.0985, + 0.0976, + 0.091, + 0.0991, + 0.0951, + 0.1015, + 0.0989, + 0.1104, + 0.0934, + 0.0977, + 0.1057, + 0.0996, + 0.0998, + 0.0893, + 0.1136, + 0.1194, + 0.1184, + 0.0927, + 0.1011, + 0.1005, + 0.0986, + 0.0947, + 0.1008, + 0.1149, + 0.1071, + 0.0997, + 0.0718, + 0.117, + 0.1241, + 0.1006, + 0.1178, + 0.0988, + 0.0993, + 0.1143, + 0.117, + 0.1217, + 0.1193, + 0.12 + ] + }, + "diagnostics": { + "bp_cosine": [ + -0.25835472345352173, + -0.36130592226982117, + -0.2657647430896759, + 0.9999802708625793 + ], + "perturbation_rho": [ + -0.043529167771339417, + 0.022320769727230072, + 0.011527110822498798, + 0.0071525974199175835 + ], + "nudging": { + "0.001": [ + 1.0617077350616455e-07, + 1.210719347000122e-07, + 9.499490261077881e-08, + -3.259629011154175e-07 + ], + "0.003": [ + 3.129243850708008e-07, + 3.7997961044311523e-07, + 2.682209014892578e-07, + -1.0058283805847168e-06 + ], + "0.01": [ + 1.0654330253601074e-06, + 1.2777745723724365e-06, + 8.754432201385498e-07, + -3.335997462272644e-06 + ] + }, + "hidden_norms_per_layer": [ + 1076.538818359375, + 51151.71875, + 101311.3046875, + 119784.4453125, + 128818.7890625 + ], + "bp_grad_norms_per_layer": [ + 2.6915265038951475e-07, + 2.3872169663263776e-07, + 2.386778135132772e-07, + 2.389696476257086e-07, + 2.3893238676464534e-07 + ] + }, + "drift": { + "embed.weight": 9.80184818135659, + "embed.bias": 19.49465769896296, + "blocks.0.ln.weight": 0.26961779594421387, + "blocks.0.w1.weight": 9.358735259494848, + "blocks.0.w1.bias": 12.749895988211446, + "blocks.0.w2.weight": 19.870291739972853, + "blocks.1.ln.weight": 0.3701913058757782, + "blocks.1.w1.weight": 9.472528827416744, + "blocks.1.w1.bias": 11.79483149750118, + "blocks.1.w2.weight": 14.954104550351216, + "blocks.2.ln.weight": 0.3377458453178406, + "blocks.2.w1.weight": 7.33340319144077, + "blocks.2.w1.bias": 9.089785454164343, + "blocks.2.w2.weight": 12.405169546116843, + "blocks.3.ln.weight": 0.49049311876296997, + "blocks.3.w1.weight": 5.387752188715276, + "blocks.3.w1.bias": 5.9034627839649305, + "blocks.3.w2.weight": 7.60319140703279, + "out_ln.weight": 0.5721396803855896, + "out_head.weight": 0.9433164083950665, + "out_head.bias": 0.4188211351080402 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 100, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_random_targets_s42", + "methods": [ + "fa" + ], + "random_targets": true, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file diff --git a/results/fa_smoke_test/results_cifar10.json b/results/fa_smoke_test/results_cifar10.json new file mode 100644 index 0000000..a8b563b --- /dev/null +++ b/results/fa_smoke_test/results_cifar10.json @@ -0,0 +1,120 @@ +{ + "42": { + "fa": { + "log": { + "train_loss": [ + 2.049124941177368, + 1.9718045804214477, + 1.9505524127578735 + ], + "train_acc": [ + 0.2436, + 0.27518, + 0.2894 + ], + "test_acc": [ + 0.2789, + 0.3087, + 0.3122 + ] + }, + "diagnostics": { + "bp_cosine": [ + 0.07161466032266617, + -0.008746136911213398, + -0.016568297520279884, + 0.9941877722740173 + ], + "perturbation_rho": [ + 0.041674911975860596, + 0.0022312882356345654, + -0.008362723514437675, + 0.2924357056617737 + ], + "nudging": { + "0.001": [ + -3.0086375772953033e-06, + 5.8673322200775146e-08, + 7.078051567077637e-08, + -9.350478649139404e-06 + ], + "0.003": [ + -8.966773748397827e-06, + 1.1548399925231934e-07, + 1.7695128917694092e-07, + -2.79964879155159e-05 + ], + "0.01": [ + -2.9983464628458023e-05, + 3.6461278796195984e-07, + 5.96512109041214e-07, + -9.332690387964249e-05 + ] + }, + "hidden_norms_per_layer": [ + 828.2960815429688, + 8464.98046875, + 19738.599609375, + 21368.41796875, + 19217.69140625 + ], + "bp_grad_norms_per_layer": [ + 1.9956107280449942e-05, + 4.579987034958322e-06, + 4.49442450189963e-06, + 4.525154963630484e-06, + 4.308007646613987e-06 + ] + }, + "drift": { + "embed.weight": 7.785380853670397, + "embed.bias": 6.819826161992119, + "blocks.0.ln.weight": 0.42689943313598633, + "blocks.0.w1.weight": 6.155192994774994, + "blocks.0.w1.bias": 6.637409518031749, + "blocks.0.w2.weight": 17.833861612088054, + "blocks.1.ln.weight": 0.352157860994339, + "blocks.1.w1.weight": 5.231153715209586, + "blocks.1.w1.bias": 6.610786582785788, + "blocks.1.w2.weight": 13.997490142949763, + "blocks.2.ln.weight": 0.3031489849090576, + "blocks.2.w1.weight": 4.218717880513288, + "blocks.2.w1.bias": 4.851412113278458, + "blocks.2.w2.weight": 12.191107541748561, + "blocks.3.ln.weight": 0.260187029838562, + "blocks.3.w1.weight": 3.545270787599183, + "blocks.3.w1.bias": 3.6891700957298967, + "blocks.3.w2.weight": 10.80288177600079, + "out_ln.weight": 0.04471425712108612, + "out_head.weight": 0.9764490646917799, + "out_head.bias": 0.41938112118479187 + } + } + }, + "config": { + "dataset": "cifar10", + "d_hidden": 256, + "num_blocks": 4, + "batch_size": 128, + "epochs": 3, + "lr": 0.001, + "lr_fb": 0.001, + "wd": 0.01, + "lam": 0.1, + "K": 4, + "sigma_bridge": 0.05, + "ema_momentum": 0.995, + "term_grad_weight": 1.0, + "seeds": [ + 42 + ], + "gpu": 0, + "output_dir": "results/fa_smoke_test", + "methods": [ + "fa" + ], + "random_targets": false, + "penalty_lam": 0.0, + "num_classes": 10 + } +}
\ No newline at end of file |
