diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:29:00 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 22:29:00 -0500 |
| commit | 111bab56e2d49c9fb1f3bfb9e55ea2028da4d008 (patch) | |
| tree | 5963a8171c383023b3bd19ed3a86e460ebe99615 | |
| parent | 7b64702ad970c16171142665365e16a8e1737190 (diff) | |
Add audit table example: protocol applied to BP/DFA/SB/CB/EP
5-method audit table on 4-block d=256 ResMLP CIFAR-10 seed 42:
- BP: trustworthy (acc 0.615, h_L=2e2, g_L=4e-4, stab 0.099)
- DFA: walked back via (a)+(b)+(d) — h_L=4e8, g_L=4e-9, undercuts frozen
- State Bridge: walked back via all 4 diagnostics — stability 0.992 is the
cleanest possible drift-dominated case
- Credit Bridge: walked back via all 4 — stability 0.352, also drift mode
- EP: trustworthy (acc 0.359, h_L=3e3, g_L=2e-4, stab -0.036) — paper's
internal control case
This is the §2 audit evidence for the main-track paper. Confirms that
standard headline acc + Γ silently fails on 3 of 5 methods on this
architecture, while the 4-diagnostic protocol catches all three.
| -rw-r--r-- | protocol/examples/__init__.py | 0 | ||||
| -rw-r--r-- | protocol/examples/audit_table.py | 162 | ||||
| -rw-r--r-- | results/protocol_audit/audit_table_s42.json | 196 |
3 files changed, 358 insertions, 0 deletions
diff --git a/protocol/examples/__init__.py b/protocol/examples/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/protocol/examples/__init__.py diff --git a/protocol/examples/audit_table.py b/protocol/examples/audit_table.py new file mode 100644 index 0000000..1a75d96 --- /dev/null +++ b/protocol/examples/audit_table.py @@ -0,0 +1,162 @@ +""" +Reproduce the §2 audit table: apply the diagnostic protocol to BP / DFA / +State Bridge / Credit Bridge / EP checkpoints on the 4-block d=256 ResMLP / +CIFAR-10 setup. Single seed 42 for the table; the paper uses 3-seed means +elsewhere. + +Output is a per-method tabular summary that lists, for each diagnostic, +the per-layer values and the verdict. This is the audit evidence behind the +paper claim *"standard FA evaluation reports headline accuracy + Γ as +evidence of training, but on modern pre-LN residual networks both signals +silently fail for non-BP methods."* + +Run: + CUDA_VISIBLE_DEVICES=2 python -m protocol.examples.audit_table +""" +import os +import sys +import json + +import torch +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, REPO_ROOT) + +from models.residual_mlp import ResidualMLP # noqa: E402 +from protocol import diagnose # noqa: E402 + +CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/confirmatory/checkpoints_A2") +EP_CHECKPOINT_DIR = os.path.join(REPO_ROOT, "results/ep_baseline") +OUT_DIR = os.path.join(REPO_ROOT, "results/protocol_audit") +os.makedirs(OUT_DIR, exist_ok=True) + + +def load_eval_batches(n_batches=10, batch_size=128, device="cuda:0"): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0) + batches = [] + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batches.append((x, y)) + if len(batches) >= n_batches: + break + return batches + + +def evaluate(model, device): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + model.eval() + correct = total = 0 + with torch.no_grad(): + for x, y in loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + preds = model(x).argmax(-1) + correct += (preds == y).sum().item() + total += x.size(0) + return correct / total + + +def load_model(method: str, seed: int, device): + if method == "ep": + path = os.path.join(EP_CHECKPOINT_DIR, f"ep_s{seed}.pt") + else: + path = os.path.join(CHECKPOINT_DIR, f"{method}_s{seed}.pt") + ckpt = torch.load(path, map_location=device, weights_only=False) + sd = ckpt if not hasattr(ckpt, "state_dict") else ckpt.state_dict() + if isinstance(sd, dict) and "state_dict" in sd: + sd = sd["state_dict"] + model = ResidualMLP(input_dim=3072, d_hidden=256, num_classes=10, num_blocks=4).to(device) + model.load_state_dict(sd) + return model + + +# 3-seed mean shallow / frozen baseline accuracies (from +# project_resmlp_walkback_dfa_destroys_value memory entry — these are the +# same number for the DFA condition by design: the "deep blocks frozen at +# random init" is informationally equivalent to "no deep blocks"). +FROZEN_BASELINE_ACC = { + "bp": None, # BP-frozen is 34.6%; not the right comparator for BP-trainable + "dfa": 0.349, # DFA-frozen / DFA-shallow 3-seed mean + "state_bridge": 0.349, # uses the same architecture-matched control + "credit_bridge": 0.349, + "ep": None, # EP frozen-control not run yet +} + + +def main(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + eval_batches = load_eval_batches(n_batches=10, batch_size=128, device=device) + + methods = ["bp", "dfa", "state_bridge", "credit_bridge", "ep"] + rows = [] + reports = {} + for method in methods: + print(f"\n### {method.upper()} (seed 42)") + model = load_model(method, 42, device) + acc = evaluate(model, device) + report = diagnose( + model=model, + eval_batches=eval_batches, + headline_acc=acc, + frozen_baseline_acc=FROZEN_BASELINE_ACC.get(method), + method_name=method.upper(), + notes="4-block d=256 ResMLP, CIFAR-10, seed 42", + ) + print(report) + reports[method] = report.to_dict() + rows.append({ + "method": method, + "acc": acc, + "h_L": report.residual_norms[-1], + "g_L": report.bp_grad_norms[-1], + "stability": report.cross_batch_stability, + "frozen_acc": report.frozen_baseline_acc, + "verdict": report.verdict, + }) + + # Compact summary table + print("\n\n" + "=" * 100) + print("AUDIT SUMMARY (single seed 42, 4-block d=256 ResMLP, CIFAR-10)") + print("=" * 100) + header = ( + f"{'method':<16}{'acc':>8}{'||h_L||':>14}{'||g_L||':>14}" + f"{'stab(L/2)':>12}{'frozen':>10} verdict" + ) + print(header) + print("-" * 100) + for r in rows: + frozen = "n/a" if r["frozen_acc"] is None else f"{r['frozen_acc']:.4f}" + print( + f"{r['method']:<16}" + f"{r['acc']:>8.4f}" + f"{r['h_L']:>14.3e}" + f"{r['g_L']:>14.3e}" + f"{r['stability']:>12.3f}" + f"{frozen:>10} {r['verdict']}" + ) + + out_path = os.path.join(OUT_DIR, "audit_table_s42.json") + with open(out_path, "w") as f: + json.dump({"reports": reports, "summary": rows}, f, indent=2) + print(f"\nSaved {out_path}") + + +if __name__ == "__main__": + main() diff --git a/results/protocol_audit/audit_table_s42.json b/results/protocol_audit/audit_table_s42.json new file mode 100644 index 0000000..d1c1f84 --- /dev/null +++ b/results/protocol_audit/audit_table_s42.json @@ -0,0 +1,196 @@ +{ + "reports": { + "bp": { + "method_name": "BP", + "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42", + "residual_norms": [ + 251.83087158203125, + 226.57342529296875, + 212.16461181640625, + 205.60723876953125, + 205.75946044921875 + ], + "bp_grad_norms": [ + 0.0004396044823806733, + 0.0004709330096375197, + 0.0004792391264345497, + 0.00045345001854002476, + 0.0003701267414726317 + ], + "stability_layer": 2, + "cross_batch_stability": 0.09898398886952135, + "headline_acc": 0.6149, + "frozen_baseline_acc": null, + "verdict": "trustworthy", + "thresholds": { + "g_norm_floor": 1e-07, + "h_norm_explosion_ratio": 50.0, + "stability_drift_ceiling": 0.3, + "frozen_acc_margin_pp": 2.0 + } + }, + "dfa": { + "method_name": "DFA", + "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42", + "residual_norms": [ + 35824.796875, + 73202040.0, + 174312304.0, + 339040960.0, + 435299520.0 + ], + "bp_grad_norms": [ + 4.39066155877299e-07, + 4.1912620041273385e-09, + 4.183721813433294e-09, + 4.174094847542165e-09, + 4.174704582027289e-09 + ], + "stability_layer": 2, + "cross_batch_stability": 0.047060725092887876, + "headline_acc": 0.3107, + "frozen_baseline_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; deep blocks fail to beat frozen-random baseline", + "thresholds": { + "g_norm_floor": 1e-07, + "h_norm_explosion_ratio": 50.0, + "stability_drift_ceiling": 0.3, + "frozen_acc_margin_pp": 2.0 + } + }, + "state_bridge": { + "method_name": "STATE_BRIDGE", + "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42", + "residual_norms": [ + 906.3201293945312, + 11583499.0, + 34872504.0, + 208111168.0, + 228665568.0 + ], + "bp_grad_norms": [ + 8.369566785404459e-06, + 1.996277365634569e-09, + 1.9812380624983916e-09, + 1.8405569290891322e-09, + 1.8411722146893794e-09 + ], + "stability_layer": 2, + "cross_batch_stability": 0.99180050028695, + "headline_acc": 0.1695, + "frozen_baseline_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline", + "thresholds": { + "g_norm_floor": 1e-07, + "h_norm_explosion_ratio": 50.0, + "stability_drift_ceiling": 0.3, + "frozen_acc_margin_pp": 2.0 + } + }, + "credit_bridge": { + "method_name": "CREDIT_BRIDGE", + "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42", + "residual_norms": [ + 13249.662109375, + 24119914.0, + 554824896.0, + 548816832.0, + 606231552.0 + ], + "bp_grad_norms": [ + 7.185065555859182e-07, + 1.1024462454045647e-09, + 9.061909000962487e-10, + 9.013046420314197e-10, + 9.011226209665324e-10 + ], + "stability_layer": 2, + "cross_batch_stability": 0.3516695586343606, + "headline_acc": 0.2562, + "frozen_baseline_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline", + "thresholds": { + "g_norm_floor": 1e-07, + "h_norm_explosion_ratio": 50.0, + "stability_drift_ceiling": 0.3, + "frozen_acc_margin_pp": 2.0 + } + }, + "ep": { + "method_name": "EP", + "notes": "4-block d=256 ResMLP, CIFAR-10, seed 42", + "residual_norms": [ + 518.3867797851562, + 579.6542358398438, + 680.764892578125, + 1145.8692626953125, + 3286.841064453125 + ], + "bp_grad_norms": [ + 0.00022257285309024155, + 0.00022327345504891127, + 0.00021209640544839203, + 0.00021204684162512422, + 0.00016422539192717522 + ], + "stability_layer": 2, + "cross_batch_stability": -0.03589460700750351, + "headline_acc": 0.359, + "frozen_baseline_acc": null, + "verdict": "trustworthy", + "thresholds": { + "g_norm_floor": 1e-07, + "h_norm_explosion_ratio": 50.0, + "stability_drift_ceiling": 0.3, + "frozen_acc_margin_pp": 2.0 + } + } + }, + "summary": [ + { + "method": "bp", + "acc": 0.6149, + "h_L": 205.75946044921875, + "g_L": 0.0003701267414726317, + "stability": 0.09898398886952135, + "frozen_acc": null, + "verdict": "trustworthy" + }, + { + "method": "dfa", + "acc": 0.3107, + "h_L": 435299520.0, + "g_L": 4.174704582027289e-09, + "stability": 0.047060725092887876, + "frozen_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; deep blocks fail to beat frozen-random baseline" + }, + { + "method": "state_bridge", + "acc": 0.1695, + "h_L": 228665568.0, + "g_L": 1.8411722146893794e-09, + "stability": 0.99180050028695, + "frozen_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline" + }, + { + "method": "credit_bridge", + "acc": 0.2562, + "h_L": 606231552.0, + "g_L": 9.011226209665324e-10, + "stability": 0.3516695586343606, + "frozen_acc": 0.349, + "verdict": "needs walk-back: residual stream exploded; BP grad at numerical floor; BP grad direction is drift-dominated; deep blocks fail to beat frozen-random baseline" + }, + { + "method": "ep", + "acc": 0.359, + "h_L": 3286.841064453125, + "g_L": 0.00016422539192717522, + "stability": -0.03589460700750351, + "frozen_acc": null, + "verdict": "trustworthy" + } + ] +}
\ No newline at end of file |
