summaryrefslogtreecommitdiff
path: root/ep_run/track_probe.py
blob: 2846745a3119927c4548a85cec72dc4df547a764 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch, math
import lt_ep_train as M
from pathlib import Path
import pickle
# Override lt_ep_train's data directory and vocabulary size at module scope so that
# the helpers imported from it below see these values (presumably used by get_batch
# and the model constructor -- confirm against lt_ep_train).
M.DD = Path('/tmp/lt_ep/data/tinystories')
# NOTE: unpickling can execute arbitrary code; meta.pkl must come from a trusted source.
# The context manager closes the file handle deterministically instead of leaking it.
with open(M.DD / 'meta.pkl', 'rb') as meta_file:
    M.vocab = pickle.load(meta_file)['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, relax
from holo_ep import holo_a_select2, rforce, rgrad_ce
import torch.func as tf

def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """Estimate the adjoint-like state ``a`` from two oppositely nudged phases.

    The state ``zs`` is duplicated into two phases stacked along dim 0. Both phases are
    driven by ``rforce``; the first half additionally receives ``-r * rgrad_ce`` and the
    second half ``+r * rgrad_ce``. At every step the antisymmetric correction
    ``(J - J^T) v`` is evaluated with ``blk.nc_force`` linearized at the current mean of
    the two phases (the "common mode"), where ``v`` is each phase's deviation from that
    mean, and is subtracted from the force before an explicit update of size ``eps``.

    Every ``K`` steps (and on the last step) the estimate
    ``a_t = (second_half - first_half) / (2 * r)`` is formed. The estimate whose change
    from the previous checkpoint has the smallest norm is kept.

    Args:
        blk: model block providing ``nc_force`` and accepted by ``rforce`` / ``rgrad_ce``.
        zs: starting state; assumed 3-D with batch on dim 0 (the nudge sign tensor is
            built with shape ``(B, 1, 1)``) -- TODO confirm.
        xin: input passed to ``rforce`` alongside the state.
        y: targets passed to ``rgrad_ce``; ``y.numel()`` is used as its ``denom``.
        r: nudge magnitude.
        T2max: maximum number of update steps.
        eps: step size of each update.
        K: number of steps between checkpoints of the estimate.
        exit_mult: stop once a checkpoint increment exceeds ``exit_mult`` times the
            smallest increment seen so far (only from step ``3 * K`` onwards).

    Returns:
        ``(a_best, t_best)``: the detached selected estimate and the step at which it was
        recorded. If no checkpoint pair was ever compared, the most recent estimate (or
        one computed from the final state) is returned with ``t_best = T2max``.
    """
    n = zs.size(0)
    device = zs.device
    state = torch.cat([zs, zs], 0)
    x_pair = torch.cat([xin, xin], 0)
    y_pair = torch.cat([y, y], 0)
    # +r for the first phase, -r for the second; broadcast over the trailing dims.
    nudge = torch.cat(
        [torch.full((n, 1, 1), r, device=device), torch.full((n, 1, 1), -r, device=device)],
        0,
    )

    def nc(zz):
        return blk.nc_force(zz)

    last_est = best_est = None
    smallest_inc = float('inf')
    best_step = 0
    for step in range(1, T2max + 1):
        with torch.no_grad():
            # Common mode of the two phases, repeated so it lines up with `state`.
            mean = 0.5 * (state[:n] + state[n:])
            mean_pair = torch.cat([mean, mean], 0)
            force = rforce(blk, state, x_pair) - nudge * rgrad_ce(blk, state, y_pair, denom=y.numel())
            offset = (state - mean_pair).contiguous()
            # J v and J^T v of nc_force, both taken at the common mode.
            _, jac_v = tf.jvp(nc, (mean_pair,), (offset,))
            _, pullback = tf.vjp(nc, mean_pair)
            jac_t_v = pullback(offset)[0]
            state = state + eps * (force - (jac_v - jac_t_v))

        # Only inspect the estimate at checkpoints.
        if step % K != 0 and step != T2max:
            continue

        estimate = (state[n:] - state[:n]) / (2 * r)
        if not torch.isfinite(estimate).all():
            break  # diverged; fall back to whatever was selected so far
        if last_est is not None:
            increment = (estimate - last_est).norm().item()
            if increment < smallest_inc:
                smallest_inc, best_est, best_step = increment, estimate, step
            elif increment > exit_mult * smallest_inc and step >= 3 * K:
                break  # increments are growing again; stop early
        last_est = estimate

    if best_est is None:
        if last_est is not None:
            best_est = last_est
        else:
            best_est = (state[n:] - state[:n]) / (2 * r)
        best_step = T2max
    return best_est.detach(), best_step

if __name__ == '__main__':
    # Probe: compare the parameter gradients obtained from the 'frozen' and 'track'
    # estimators against a bptt_step reference, at several relaxation depths.
    torch.manual_seed(0)
    B, T, C, H = 8, 256, 256, 8
    blk = EQBlock(C, H, 256, T, attn_mode='thick')

    # Overwrite the freshly constructed parameters with the checkpointed ones.
    ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt')
    with torch.no_grad():
        for param, saved in zip(blk.allp, ck['allp']):
            param.copy_(saved.to('cuda'))

    idx, y = get_batch('train', B, T)
    xin = blk.embed(idx).detach()

    def flat(g):
        """Concatenate, in blk.block order, every non-None entry of g (keyed by id(param))."""
        pieces = [g[id(p)].reshape(-1) for p in blk.block if g.get(id(p)) is not None]
        return torch.cat(pieces)

    def gfrom(zs, a_):
        """Return {id(param): grad} of sum(a_ * force(zs)) with respect to blk.block."""
        with torch.enable_grad():
            x2 = blk.embed(idx)
            f = blk.force(zs.detach(), x2, cg=True)
            objective = (a_ * f).sum()
            grads = torch.autograd.grad(objective, blk.block, allow_unused=True)
            return {id(p): g for p, g in zip(blk.block, grads)}

    # Reference gradient (bptt_step called with 150 and 0.1), flattened once.
    ref150 = bptt_step(blk, idx, y, 150, 0.1)
    v150 = flat(ref150)

    estimators = (('frozen', holo_a_select2), ('track', holo_a_track))
    z = xin.clone()
    prev = 0
    for T1 in (75, 150, 600):
        # Continue relaxing from the previous depth rather than restarting.
        z = relax(blk, z, xin, T1 - prev, 0.1)
        prev = T1
        # Relative size of one further relax step, reported as the residual.
        one_step_change = (relax(blk, z, xin, 1, 0.1) - z).norm().item()
        res = one_step_change / z.norm().item()
        for name, fn in estimators:
            for T2m in (120, 300):
                a, tb = fn(blk, z, xin, y, 0.02, T2m, 0.1)
                va = flat(gfrom(z, a))
                # Cosine similarity against the reference, guarded against zero norms.
                dot = va @ v150
                scale = va.norm() * v150.norm() + 1e-12
                c = (dot / scale).item()
                print(f"T1={T1:>4} res={res:.1e} {name:>6} T2max={T2m:>3}: t_best={tb:>3} cos_vs150={c:.3f}", flush=True)