diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 05:57:53 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 05:57:53 -0500 |
| commit | be39c2b5ebec37f993b1a862459455a98cf39eb2 (patch) | |
| tree | 0b373ccfd983ae866f12c9029db3bfd863a8e2fd /experiments | |
| parent | 52693a9be4349c2820ac79e3e3d9af53813a7412 (diff) | |
Round 35: SB and CB also show data-agnostic Mode 1 growth on random targets
- experiments/cifar_resmlp.py: add --methods filter and --random_targets flag;
extend compute_diagnostics to log hidden_norms_per_layer and bp_grad_norms_per_layer
- paper/main.tex §3 ¶1: broaden random-target finding to all 3 fixed-feedback methods
(DFA: ||h_L||=14510, SB: ||h_L||=6225, CB: ||h_L||=19974 at ep 3, all at chance acc)
- paper/main.tex Appendix J: extended with cross-method smoke-test table
This generalizes the §3 mechanism story from 'DFA-specific' to 'all 3 audited
fixed-feedback local-credit methods'. Combined with rounds 32-34, the proximate
cause of Mode 1 (a) is now well-localized:
- Does not require a residual skip (round 33 H2 walkback)
- Does not require a task signal (round 34 random targets, DFA)
- Is not DFA-specific (round 35 random targets, SB+CB)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cifar_resmlp.py | 110 |
1 file changed, 67 insertions, 43 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 1582f6d..4324e9e 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -99,6 +99,8 @@ def train_bp(model, train_loader, test_loader, device, args): for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) logits = model(x) loss = F.cross_entropy(logits, y) optimizer.zero_grad() @@ -160,6 +162,8 @@ def train_dfa(model, train_loader, test_loader, device, args): for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) batch = x.size(0) # Forward pass (no grad for hidden states) @@ -262,6 +266,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args): for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) batch = x.size(0) with torch.no_grad(): @@ -418,6 +424,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args): for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) + if getattr(args, 'random_targets', False): + y = torch.randint(0, args.num_classes, y.shape, device=device) batch = x.size(0) with torch.no_grad(): @@ -595,10 +603,16 @@ def compute_diagnostics(model, method_name, test_loader, device, args, e_T[torch.arange(batch), y] -= 1 s = e_T.detach() + # Per-layer hidden norms (median across batch) and BP grad norms (per-sample L2, median) + hidden_norms_per_layer = [float(hiddens[l].detach().norm(dim=-1).median().item()) for l in range(L + 1)] + bp_grad_norms_per_layer = [float(bp_grads[l].norm(dim=-1).median().item()) for l in range(L + 1)] + results = { 'bp_cosine': [], 'perturbation_rho': [], 'nudging': 
{'0.001': [], '0.003': [], '0.01': []}, + 'hidden_norms_per_layer': hidden_norms_per_layer, + 'bp_grad_norms_per_layer': bp_grad_norms_per_layer, } for l in range(L): @@ -673,56 +687,62 @@ def run_experiment(args): seed_results = {} + methods_to_run = getattr(args, 'methods', ['bp', 'dfa', 'state_bridge', 'credit_bridge']) + # ---- BP ---- - print("\n--- BP ---") - model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) - init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()} - bp_log = train_bp(model_bp, train_loader, test_loader, device, args) - bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) - bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()}) - seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift} - print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}") + if 'bp' in methods_to_run: + print("\n--- BP ---") + model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()} + bp_log = train_bp(model_bp, train_loader, test_loader, device, args) + bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) + bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()}) + seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift} + print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}") # ---- DFA ---- - print("\n--- DFA ---") - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) - init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()} - dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) - dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) - dfa_drift = 
feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()}) - seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} - print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") + if 'dfa' in methods_to_run: + print("\n--- DFA ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()} + dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) + dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) + dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()}) + seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift} + print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}") # ---- State Bridge ---- - print("\n--- State Bridge ---") - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) - init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()} - sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) - sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, - state_predictor=state_pred) - sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()}) - seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift} - print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}") + if 'state_bridge' in methods_to_run: + print("\n--- State Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_sb = {n: p.clone().detach() 
for n, p in model_sb.named_parameters()} + sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) + sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, + state_predictor=state_pred) + sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()}) + seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift} + print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}") # ---- Credit Bridge ---- - print("\n--- Credit Bridge ---") - torch.manual_seed(seed) - np.random.seed(seed) - torch.cuda.manual_seed_all(seed) - model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) - init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()} - cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) - cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, - value_net=vnet) - cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()}) - seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift} - print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}") + if 'credit_bridge' in methods_to_run: + print("\n--- Credit Bridge ---") + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()} + cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) + cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, + value_net=vnet) + cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()}) + seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift} + print(f" Final test acc: 
{cb_log['test_acc'][-1]:.4f}") all_results[seed] = seed_results @@ -767,6 +787,10 @@ def main(): parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--output_dir', type=str, default='results/cifar10') + parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'state_bridge', 'credit_bridge'], + help='Subset of methods to run.') + parser.add_argument('--random_targets', action='store_true', + help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') args = parser.parse_args() run_experiment(args) |
