summaryrefslogtreecommitdiff
path: root/ep_run/floss_smoke.py
blob: 371f98a3283ed9e31917aa5da9bdc817fedc0bb6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""floss-ept mechanics smoke: (a) default-off path unchanged, (b) floss fires at random init
(rho_hat > target expected there) with finite grads, (c) the floss contribution is a bounded
perturbation of the task gradient (cos(g_off, g_floss) stays high — the 0.2 task-norm cap)."""
import math, torch
import lt_ep_train as L

# Fixed seed ahead of block construction (presumably so parameter init is
# reproducible -- confirm against EQBlock).
torch.manual_seed(0)
blk = L.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True

# Separate fixed seed ahead of drawing the training batch.
torch.manual_seed(7)
_batch = L.get_batch('train', 8, 256)
idx, y = _batch

# Baseline step with no floss arguments: g0 is the reference gradient dict,
# r0 is the second return value, reported below as `res`.
_baseline = L.ep_step(blk, idx, y, 30, 8, 0.1, 0.02, jacreg=0.0)
g0, r0 = _baseline
print(f"[off  ] res={r0:.3e} n_grads={len(g0)}", flush=True)

def _all_finite(grads):
    """Return True when every non-None tensor in the gradient dict *grads* is finite."""
    return all(torch.isfinite(v).all().item() for v in grads.values() if v is not None)


def _grad_cosine(params, g_ref, g_new):
    """Return the cosine similarity between two gradient dicts keyed by ``id(param)``.

    Only parameters that have a non-None entry in both dicts contribute.
    The 1e-20 term under the square root avoids a division by zero (the
    result is 0.0) when no parameter overlaps or a gradient is all-zero.
    """
    dot = n_ref = n_new = 0.0
    for p in params:
        a, b = g_ref.get(id(p)), g_new.get(id(p))
        if a is None or b is None:
            continue
        dot += float((a * b).sum())
        n_ref += float((a * a).sum())
        n_new += float((b * b).sum())
    return dot / math.sqrt(n_ref * n_new + 1e-20)


def _fmt_rho(rho):
    """Format a rho estimate to 4 decimals, or ``n/a`` when none was recorded.

    Formatting ``None`` with ``:.4f`` raises TypeError, which would abort the
    smoke run on a missing ``_floss_rho`` attribute before the remaining
    checks are printed.
    """
    return "n/a" if rho is None else f"{rho:.4f}"


# floss enabled without overriding floss_rho: per the expectation printed
# below, the gradient should match the default-off baseline g0.
g1, r1 = L.ep_step(blk, idx, y, 30, 8, 0.1, 0.02, jacreg=0.0, floss=0.2, floss_q=8, floss_bsub=4)
rho = getattr(blk, '_floss_rho', None)
fin = _all_finite(g1)
print(f"[floss] res={r1:.3e} rho_hat={_fmt_rho(rho)} finite={fin} n_grads={len(g1)}", flush=True)

cos1 = _grad_cosine(blk.block, g0, g1)
print(f"cos(g_off, g_floss)={cos1:.4f}  (below-target: should be 1.0 = untouched)", flush=True)

# force-fire the penalty path (target below the measured rho) — exercises grad/lam/accumulate end-to-end
g2, r2 = L.ep_step(blk, idx, y, 30, 8, 0.1, 0.02, jacreg=0.0, floss=0.2, floss_q=8, floss_bsub=4, floss_rho=0.90)
fin2 = _all_finite(g2)
cos2 = _grad_cosine(blk.block, g0, g2)
# Same guarded read as for the first floss run, so a missing attribute is
# reported as n/a instead of raising AttributeError.
rho2 = getattr(blk, '_floss_rho', None)
print(f"[fire ] rho_hat={_fmt_rho(rho2)} finite={fin2} cos(g_off, g_fire)={cos2:.4f} (should be <1 but >0.9: capped perturbation)", flush=True)
print("FLOSS_SMOKE_DONE", flush=True)