summaryrefslogtreecommitdiff
path: root/protocol/examples/verify_pitfalls.py
blob: d329331f71b451776066cfe254fd1bfcc42588f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Pipeline pitfalls verifier: empirically demonstrate bugs 1-3 from
`protocol/CHECKLIST.md` so the catalog is grounded in reproducible
synthetic evidence rather than in-vivo anecdote.

Bug 1: `tensor.norm(-1)` is the L_{-1} 'norm' of the entire tensor,
       NOT 'L_2 along dim=-1'. The correct call is `tensor.norm(dim=-1)`.

Bug 2: `F.cosine_similarity(a, b)` clamps each vector's norm from below
       at eps=1e-8 by default. When ||b|| ~ 1e-10 (which is the regime BP
       grads land in on DFA-trained pre-LN ResMLPs), the divisor becomes
       ||a|| * 1e-8 instead of ||a|| * 1e-10, shrinking the reported
       cosine toward zero by a factor of ~100 (= 1e-8 / ||b||).

Bug 3: fp16 mixed precision underflows BP grads at hidden layers when
       they sit at ~5e-10 (well below fp16's smallest subnormal of
       ~6e-8). bf16 works because it has the same exponent range as fp32.

This script does NOT use GPU and runs in <1 second.

Run:
    python -m protocol.examples.verify_pitfalls
"""
import math

import torch
import torch.nn.functional as F


def banner(title):
    """Print *title* between two 72-character rules of '=' characters."""
    rule = "=" * 72
    # One print call, with the parts separated by newlines; the bytes emitted
    # are the same as printing the rule, the title and the rule separately.
    print(rule, title, rule, sep="\n")


def bug1_norm_minus_one():
    """Show that ``tensor.norm(-1)`` means ``p=-1``, not ``dim=-1``.

    ``Tensor.norm`` takes the norm order ``p`` as its first positional
    argument. So ``x.norm(-1)`` reduces the whole tensor with the p=-1
    "norm" ``(sum_i |x_i|^-1)^-1`` instead of taking row-wise L2 norms.
    This prints both results next to a hand-computed p=-1 value and
    returns None.
    """
    banner("BUG 1: tensor.norm(-1) is NOT 'L_2 along dim=-1'")
    torch.manual_seed(0)  # resets the global RNG; nothing below draws random numbers
    rows = torch.tensor([[3.0, 4.0], [6.0, 8.0]])  # row-wise L2 norms: 5 and 10
    intended = rows.norm(dim=-1)  # what callers usually mean
    actual = rows.norm(-1)        # positional -1 binds to p, not dim

    # The p=-1 quantity computed by hand over every element of the tensor.
    # It behaves like a harmonic mean: (sum of reciprocals) ** -1.
    reciprocal_sum = rows.flatten().abs().pow(-1).sum()
    by_hand = reciprocal_sum.pow(-1).item()
    bug_value = actual.item()
    matches = abs(bug_value - by_hand) < 1e-6

    print(f"  x = {rows.tolist()}")
    print(f"  x.norm(dim=-1) (correct, L_2 along last dim): {intended.tolist()}")
    print(f"  x.norm(-1)     (bug, L_{{-1}} of whole tensor): {bug_value:.6f}")
    print(f"  hand-computed L_{{-1}} of flat tensor:           {by_hand:.6f}")
    print(f"  -> the two values match: {matches}")
    print("  -> the bug version is unrelated to per-row L_2 norms.")
    print()


def bug2_cosine_eps_clamp():
    """Show how ``F.cosine_similarity``'s default eps distorts a tiny-vector cosine.

    Builds ``b`` as a unit vector rescaled to 5e-10, which is below the
    default ``eps=1e-8``. It then compares PyTorch's cosine against an
    unclamped cosine computed by hand. Everything is float64, so
    floating-point underflow cannot be confused with the eps effect.
    Prints the report and returns None.
    """
    banner("BUG 2: F.cosine_similarity(a, b) clamps divisor by eps=1e-8")
    torch.manual_seed(0)
    # Draw order matters for reproducibility: `a` first, then the direction.
    a = torch.randn(1, 100, dtype=torch.float64)
    unit = torch.randn(100, dtype=torch.float64)
    unit = unit / unit.norm()

    tiny = 5e-10  # the magnitude DFA-trained nets give for BP grads at hidden layers
    b = (unit * tiny).unsqueeze(0)
    default_eps = 1e-8  # F.cosine_similarity's default eps

    # Unclamped reference cosine vs. what PyTorch reports with the default eps.
    norm_a = a.norm().item()
    norm_b = b.norm().item()
    true_cos = (a @ b.T).item() / (norm_a * norm_b)
    reported_cos = F.cosine_similarity(a, b, dim=-1).item()

    # Guard against dividing by a (near-)zero reference cosine.
    if abs(true_cos) <= 1e-30:
        ratio = float('nan')
    else:
        ratio = reported_cos / true_cos

    print(f"  ||a|| = {norm_a:.4e}")
    print(f"  ||b|| = {norm_b:.4e}  (intentionally below eps=1e-8)")
    print(f"  true cosine     (no clamp):       {true_cos:+.6f}")
    print(f"  F.cosine_similarity (default eps): {reported_cos:+.6f}")
    print(f"  ratio reported/true: {ratio:.6e}  (should be 1.0)")
    print(f"  scaling distortion: {tiny / default_eps:.4e}x  (i.e. PyTorch divides by")
    print(f"    ||a||*1e-8 instead of ||a||*{tiny:.0e}, off by ~{default_eps/tiny:.0e}x)")
    print()


def bug3_fp16_underflow():
    """Show that a ~5e-10 BP-grad magnitude vanishes in fp16 but not in bf16/fp32.

    Prints two demonstrations to stdout and returns None:

    1. The scalar 5e-10 cast to fp16, bf16 and fp32. fp16 rounds it to 0
       because it is below half of fp16's smallest positive subnormal
       (2**-24 ~= 5.96e-8). bf16 shares fp32's 8-bit exponent, so it
       keeps the value (at reduced mantissa precision).
    2. The cosine between a random vector and a unit direction scaled to
       that magnitude, computed in each precision. The cosine is taken
       with ``eps=0.0``. With the default ``eps=1e-8``,
       ``F.cosine_similarity`` would clamp the tiny norm (bug 2) and shrink
       the fp32/bf16 results too, so they could not serve as a "correct"
       baseline. A float64 unclamped reference is printed alongside. In
       fp16 the scaled vector is all zeros, so its cosine is 0/0 and
       prints as nan.
    """
    banner("BUG 3: fp16 mixed precision underflows BP grads at ~5e-10")
    # Smallest positive fp16 subnormal (~5.96e-8). Under round-to-nearest,
    # values at or below half of it (2**-25) round to 0.
    fp16_min_subnormal = 2.0 ** -24
    bp_grad_magnitude = 5e-10  # typical for DFA-trained pre-LN ResMLPs

    # Try to represent the magnitude in each precision.
    val_fp16 = torch.tensor(bp_grad_magnitude, dtype=torch.float16)
    val_bf16 = torch.tensor(bp_grad_magnitude, dtype=torch.bfloat16)
    val_fp32 = torch.tensor(bp_grad_magnitude, dtype=torch.float32)

    print(f"  BP grad magnitude on DFA-trained ResMLP: {bp_grad_magnitude:.0e}")
    print(f"  fp16 smallest positive subnormal:        {fp16_min_subnormal:.2e}")
    print(f"  fp16 representation:  {val_fp16.item():.4e}  (-> 0 = UNDERFLOW)")
    print(f"  bf16 representation:  {val_bf16.item():.4e}  (works, same exp range as fp32)")
    print(f"  fp32 representation:  {val_fp32.item():.4e}  (works)")

    def unclamped_cosine(x, y):
        # eps=0.0 disables the norm clamp from bug 2, so only precision effects remain.
        cos = F.cosine_similarity(x.unsqueeze(0), y.unsqueeze(0), eps=0.0)
        return cos.float().item()

    # Show what happens to a downstream cosine computation. Seed locally so
    # this function is reproducible when called on its own.
    torch.manual_seed(0)
    a = torch.randn(100)
    direction = torch.randn(100)
    direction = direction / direction.norm()
    b32 = direction * bp_grad_magnitude
    b16 = b32.half()  # every element is <= 5e-10 < 2**-25, so all become 0
    bbf = b32.bfloat16()

    # Ground truth: float64 cosine of the same vectors, with no eps clamp.
    a64, b64 = a.double(), b32.double()
    reference = ((a64 @ b64) / (a64.norm() * b64.norm())).item()

    print()
    print("  cosine of random vector with the BP-grad-magnitude direction, by precision:")
    print(f"    fp64 reference (no clamp): {reference:+.4f}")
    print(f"    fp32 cosine: {unclamped_cosine(a, b32):+.4f}  (correct)")
    print(f"    fp16 cosine: {unclamped_cosine(a.half(), b16):+.4f}  (corrupt — b underflowed to 0, so 0/0)")
    print(f"    bf16 cosine: {unclamped_cosine(a.bfloat16(), bbf):+.4f}  (correct up to bf16 rounding)")
    print()


def main():
    """Run the three synthetic reproducers in order, then print a summary."""
    reproducers = (bug1_norm_minus_one, bug2_cosine_eps_clamp, bug3_fp16_underflow)
    for run in reproducers:
        run()
    # A single print with embedded newlines emits the same three summary lines.
    print(
        "All 3 reproducers ran. Each demonstrates the documented bug from\n"
        "protocol/CHECKLIST.md. Bugs 4-6 require a trained network and are\n"
        "verified inside the audit_table and ablation_decision_utility scripts."
    )


if __name__ == "__main__":
    main()