summary | refs | log | tree | commit | diff
path: root/experiments
diff options
context:
space:
mode:
author: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-08 05:57:53 -0500
committer: YurenHao0426 <Blackhao0426@gmail.com> 2026-04-08 05:57:53 -0500
commit: be39c2b5ebec37f993b1a862459455a98cf39eb2 (patch)
tree: 0b373ccfd983ae866f12c9029db3bfd863a8e2fd /experiments
parent: 52693a9be4349c2820ac79e3e3d9af53813a7412 (diff)
Round 35: SB and CB also show data-agnostic Mode 1 growth on random targets
- experiments/cifar_resmlp.py: add --methods filter and --random_targets flag; extend compute_diagnostics to log hidden_norms_per_layer and bp_grad_norms_per_layer
- paper/main.tex §3 ¶1: broaden the random-target finding to all 3 fixed-feedback methods (DFA: ||h_L||=14510, SB: ||h_L||=6225, CB: ||h_L||=19974 at ep 3, all at chance acc)
- paper/main.tex Appendix J: extended with a cross-method smoke-test table

This generalizes the §3 mechanism story from 'DFA-specific' to 'all 3 audited fixed-feedback local-credit methods'. Combined with rounds 32-34, the proximate cause of Mode 1 (a) is now well-localized:
- It does not require a residual skip (round 33 H2 walkback)
- It does not require a task signal (round 34 random targets, DFA)
- It is not DFA-specific (round 35 random targets, SB+CB)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r-- experiments/cifar_resmlp.py 110
1 files changed, 67 insertions, 43 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 1582f6d..4324e9e 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -99,6 +99,8 @@ def train_bp(model, train_loader, test_loader, device, args):
for x, y in train_loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
+ if getattr(args, 'random_targets', False):
+ y = torch.randint(0, args.num_classes, y.shape, device=device)
logits = model(x)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
@@ -160,6 +162,8 @@ def train_dfa(model, train_loader, test_loader, device, args):
for x, y in train_loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
+ if getattr(args, 'random_targets', False):
+ y = torch.randint(0, args.num_classes, y.shape, device=device)
batch = x.size(0)
# Forward pass (no grad for hidden states)
@@ -262,6 +266,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args):
for x, y in train_loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
+ if getattr(args, 'random_targets', False):
+ y = torch.randint(0, args.num_classes, y.shape, device=device)
batch = x.size(0)
with torch.no_grad():
@@ -418,6 +424,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args):
for x, y in train_loader:
x = x.view(x.size(0), -1).to(device)
y = y.to(device)
+ if getattr(args, 'random_targets', False):
+ y = torch.randint(0, args.num_classes, y.shape, device=device)
batch = x.size(0)
with torch.no_grad():
@@ -595,10 +603,16 @@ def compute_diagnostics(model, method_name, test_loader, device, args,
e_T[torch.arange(batch), y] -= 1
s = e_T.detach()
+ # Per-layer hidden norms (median across batch) and BP grad norms (per-sample L2, median)
+ hidden_norms_per_layer = [float(hiddens[l].detach().norm(dim=-1).median().item()) for l in range(L + 1)]
+ bp_grad_norms_per_layer = [float(bp_grads[l].norm(dim=-1).median().item()) for l in range(L + 1)]
+
results = {
'bp_cosine': [],
'perturbation_rho': [],
'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ 'hidden_norms_per_layer': hidden_norms_per_layer,
+ 'bp_grad_norms_per_layer': bp_grad_norms_per_layer,
}
for l in range(L):
@@ -673,56 +687,62 @@ def run_experiment(args):
seed_results = {}
+ methods_to_run = getattr(args, 'methods', ['bp', 'dfa', 'state_bridge', 'credit_bridge'])
+
# ---- BP ----
- print("\n--- BP ---")
- model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
- init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()}
- bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
- bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
- bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()})
- seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift}
- print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}")
+ if 'bp' in methods_to_run:
+ print("\n--- BP ---")
+ model_bp = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_bp = {n: p.clone().detach() for n, p in model_bp.named_parameters()}
+ bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
+ bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
+ bp_drift = feature_drift(init_bp, {n: p.detach() for n, p in model_bp.named_parameters()})
+ seed_results['bp'] = {'log': bp_log, 'diagnostics': bp_diag, 'drift': bp_drift}
+ print(f" Final test acc: {bp_log['test_acc'][-1]:.4f}")
# ---- DFA ----
- print("\n--- DFA ---")
- torch.manual_seed(seed)
- np.random.seed(seed)
- torch.cuda.manual_seed_all(seed)
- model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
- init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()}
- dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
- dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
- dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()})
- seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
- print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
+ if 'dfa' in methods_to_run:
+ print("\n--- DFA ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_dfa = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_dfa = {n: p.clone().detach() for n, p in model_dfa.named_parameters()}
+ dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
+ dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
+ dfa_drift = feature_drift(init_dfa, {n: p.detach() for n, p in model_dfa.named_parameters()})
+ seed_results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag, 'drift': dfa_drift}
+ print(f" Final test acc: {dfa_log['test_acc'][-1]:.4f}")
# ---- State Bridge ----
- print("\n--- State Bridge ---")
- torch.manual_seed(seed)
- np.random.seed(seed)
- torch.cuda.manual_seed_all(seed)
- model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
- init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()}
- sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
- sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
- state_predictor=state_pred)
- sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()})
- seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift}
- print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}")
+ if 'state_bridge' in methods_to_run:
+ print("\n--- State Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_sb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_sb = {n: p.clone().detach() for n, p in model_sb.named_parameters()}
+ sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
+ sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
+ state_predictor=state_pred)
+ sb_drift = feature_drift(init_sb, {n: p.detach() for n, p in model_sb.named_parameters()})
+ seed_results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag, 'drift': sb_drift}
+ print(f" Final test acc: {sb_log['test_acc'][-1]:.4f}")
# ---- Credit Bridge ----
- print("\n--- Credit Bridge ---")
- torch.manual_seed(seed)
- np.random.seed(seed)
- torch.cuda.manual_seed_all(seed)
- model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
- init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()}
- cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
- cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
- value_net=vnet)
- cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()})
- seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift}
- print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}")
+ if 'credit_bridge' in methods_to_run:
+ print("\n--- Credit Bridge ---")
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ model_cb = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
+ init_cb = {n: p.clone().detach() for n, p in model_cb.named_parameters()}
+ cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
+ cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ cb_drift = feature_drift(init_cb, {n: p.detach() for n, p in model_cb.named_parameters()})
+ seed_results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag, 'drift': cb_drift}
+ print(f" Final test acc: {cb_log['test_acc'][-1]:.4f}")
all_results[seed] = seed_results
@@ -767,6 +787,10 @@ def main():
parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
parser.add_argument('--gpu', type=int, default=1)
parser.add_argument('--output_dir', type=str, default='results/cifar10')
+ parser.add_argument('--methods', type=str, nargs='+', default=['bp', 'dfa', 'state_bridge', 'credit_bridge'],
+ help='Subset of methods to run.')
+ parser.add_argument('--random_targets', action='store_true',
+ help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).')
args = parser.parse_args()
run_experiment(args)