summaryrefslogtreecommitdiff
path: root/diag/ptrm_color.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/ptrm_color.py')
-rw-r--r--diag/ptrm_color.py33
1 files changed, 21 insertions, 12 deletions
diff --git a/diag/ptrm_color.py b/diag/ptrm_color.py
index 4004297..b24097f 100644
--- a/diag/ptrm_color.py
+++ b/diag/ptrm_color.py
@@ -3,7 +3,7 @@
deterministic / pass@K (conflict-min, ground truth) / lambda-select (min lambda1) / random.
-Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_full_...pt
+Run: PYTHONPATH=/home/yurenh2/rrog python3 diag/ptrm_color.py --ckpt runs/ckpt_color_rrog_trm_gin_full_...pt
"""
import argparse, json, os
import numpy as np
@@ -18,20 +18,24 @@ OUT = '/home/yurenh2/rrog/runs'
def rollout(model, xin, ei, sigma, n_sup, T, dev, seed):
gen = torch.Generator(device=dev).manual_seed(seed)
- h0 = model.lin_in(xin)
- z = torch.zeros_like(h0)
- v = torch.randn(h0.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
- def step(zz):
- return model.block(zz + h0, ei)
+ ctx = model.aggregate(xin, ei)
+ y, z = ctx, torch.zeros_like(ctx)
+ state = torch.cat([y, z], dim=-1)
+ v = torch.randn(state.shape, generator=gen, device=dev); v = v / (v.norm() + 1e-12)
+ def step(ss):
+ yy, zz = ss.chunk(2, dim=-1)
+ yy, zz = model.recurse(yy, zz, ctx, noise=False)
+ return torch.cat([yy, zz], dim=-1)
lam = 0.0
for _ in range(n_sup * T):
- z_det, Jv = torch.autograd.functional.jvp(step, z, v)
+ state_det, Jv = torch.autograd.functional.jvp(step, state, v)
nv = Jv.norm(); lam += torch.log(nv + 1e-12).item(); v = (Jv / (nv + 1e-12)).detach()
- z = z_det.detach()
+ state = state_det.detach()
if sigma > 0:
- z = z + sigma * torch.randn(z.shape, generator=gen, device=dev)
- lam /= (n_sup * T)
- col = model.head(z).argmax(-1)
+ state = state + sigma * torch.randn(state.shape, generator=gen, device=dev)
+ lam /= max(n_sup * T, 1)
+ y, _ = state.chunk(2, dim=-1)
+ col = model.head(y).argmax(-1)
conf = (col[ei[0]] == col[ei[1]]).sum().item() // 2
return conf, lam
@@ -47,7 +51,11 @@ def main():
ck = torch.load(args.ckpt, weights_only=False); c = ck['cfg']
deg = torch.tensor(c['deg']) if c.get('deg') else None
model = RecGINColor(c['in_dim'], c['hidden'], c['k'], c['T'], c['n_sup'],
- grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg).to(dev)
+ grad_mode=c['grad_mode'], conv=c.get('conv', 'gin'), deg=deg,
+ agg_layers=c.get('agg_layers', 1),
+ compute_layers=c.get('compute_layers', 2),
+ compute=(c.get('compute') if c.get('compute') == 'trm' else 'trm'),
+ attn_heads=c.get('attn_heads', 4)).to(dev)
model.load_state_dict(ck['state']); model.eval()
nsup, T = c['n_sup'], c['T']
te = featurize(make_split('test', 50, 3, 0.2, 8, 500, 100000), c.get('pe', 'none'), c.get('rwse_k', 16))
@@ -56,6 +64,7 @@ def main():
det = sum(rollout(model, r['xin'].to(dev), r['edge_index'].to(dev), 0.0, nsup, T, dev, 0)[0] == 0
for r in te) / n
out = {'conv': c.get('conv', 'gin'), 'pe': c.get('pe', 'none'), 'seed': c.get('seed'),
+ 'arch': c.get('arch', 'legacy'),
'grad_mode': c['grad_mode'], 'contract': c.get('contract', False), 'det': det, 'sigmas': {}}
print(f"[pe={out['pe']} s{out['seed']}] deterministic solve_rate = {det:.3f} (n={n}, K={args.K})")
print(f"{'sigma':>6} {'pass@K':>8} {'lam-sel':>8} {'random':>8} {'perRoll':>8} {'AUROC(s|-lam)':>14}")