"""Track individual test problems across HRM/TRM checkpoint diagnostics. The HRM and TRM diagnostic runs in this workspace reuse the same 512 test indices across checkpoints. This script turns those panel diagnostics into a per-problem learning-order view: first success, stable success, regressions, and spectrum changes around fail->success / success->fail transitions. """ from __future__ import annotations import argparse import csv import math import re from pathlib import Path import matplotlib.pyplot as plt import numpy as np ROOT = Path("/home/yurenh2/rrm/research/flossing") DATA_ROOT = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test") def parse_step(path: Path, kind: str) -> int: if kind == "HRM": return int(re.search(r"step_(\d+)_512", path.name).group(1)) return int(re.search(r"step(\d+)_512", path.name).group(1)) def step_to_epoch(step: int, kind: str) -> int: if kind == "HRM": return int(round(step * 20000 / 26040)) # TRM file names correspond to 5k, 10k, ... 50k epochs. return int(round(step * 5000 / 26041)) def discover(kind: str) -> list[Path]: if kind == "HRM": files = sorted(ROOT.glob("diag_hrm_step_*_512.npz"), key=lambda p: parse_step(p, kind)) else: files = sorted(ROOT.glob("diag_trm_singleGPU_step*_512.npz"), key=lambda p: parse_step(p, kind)) if not files: raise FileNotFoundError(f"No {kind} diagnostics found") return files def sorted_features(lyap: np.ndarray) -> dict[str, np.ndarray]: s = np.sort(lyap, axis=1)[:, ::-1] pos = np.clip(s, 0, None) return { "lambda_max": s[:, 0], "lambda_2": s[:, 1], "lambda_min8": s[:, -1], "mean8": s.mean(axis=1), "tail_mean_5_8": s[:, 4:].mean(axis=1), "positive_sum": pos.sum(axis=1), "positive_count": (s > 0).sum(axis=1).astype(float), "spread": s[:, 0] - s[:, -1], } def write_csv(path: Path, rows: list[dict]) -> None: if not rows: return keys: list[str] = [] for row in rows: for key in row: if key not in keys: keys.append(key) with path.open("w", newline="") as f: writer = csv.DictWriter(f, fieldnames=keys) writer.writeheader() writer.writerows(rows) def classify(success: np.ndarray) -> tuple[str, int | None, int | None, int, int]: """Return group, first_success_idx, stable_success_idx, n_success, transitions.""" success = np.asarray(success, dtype=bool) n_success = int(success.sum()) transitions = int(np.abs(np.diff(success.astype(int))).sum()) first_success_idx = int(np.argmax(success)) if success.any() else None stable_success_idx = None for i in range(len(success)): if success[i:].all(): stable_success_idx = i break if n_success == 0: group = "never" elif n_success == len(success): group = "always" elif stable_success_idx is not None: if transitions <= 1: group = "stable_learned" else: group = "flaky_then_stable" else: group = "transient_or_regressed" return group, first_success_idx, stable_success_idx, n_success, transitions def load_panel(kind: str): files = discover(kind) steps = np.array([parse_step(p, kind) for p in files], dtype=int) epochs = np.array([step_to_epoch(s, kind) for s in steps], dtype=int) idx0 = np.load(files[0])["idx"] exact_list = [] token_list = [] feature_time: dict[str, list[np.ndarray]] = {} spectra = [] for path in files: data = np.load(path) idx = data["idx"] if not np.array_equal(idx0, idx): raise ValueError(f"{kind} diagnostics do not share the same idx array: {path}") exact_list.append(data["exact_correct"].astype(bool)) token_list.append(data["token_acc"].astype(float)) lyap = data["lyap_spec"].astype(float) spectra.append(np.sort(lyap, axis=1)[:, ::-1]) for name, arr in sorted_features(lyap).items(): feature_time.setdefault(name, []).append(arr) exact = np.stack(exact_list, axis=1) # (N, T) token_acc = np.stack(token_list, axis=1) # (N, T) spectrum = np.stack(spectra, axis=1) # (N, T, K) features = {k: np.stack(v, axis=1) for k, v in feature_time.items()} return files, idx0.astype(int), steps, epochs, exact, token_acc, spectrum, features def add_problem_features(rows: list[dict], idx: np.ndarray) -> None: inputs = np.load(DATA_ROOT / "all__inputs.npy", mmap_mode="r") labels = np.load(DATA_ROOT / "all__labels.npy", mmap_mode="r") sample_inputs = inputs[idx] sample_labels = labels[idx] # In these Sudoku tensors, 0 is padding, 1 is the blank-cell token, and # digits are represented by 2..10. givens = (sample_inputs > 1).sum(axis=1) blanks = (sample_inputs == 1).sum(axis=1) label_nonzero = (sample_labels > 0).sum(axis=1) for row, g, b, lnz in zip(rows, givens, blanks, label_nonzero): row["givens"] = int(g) row["blanks"] = int(b) row["label_nonzero"] = int(lnz) def analyze_kind(kind: str, out_dir: Path) -> dict: _, idx, steps, epochs, exact, token_acc, spectrum, features = load_panel(kind) n, t = exact.shape problem_rows: list[dict] = [] for i in range(n): group, first_i, stable_i, n_success, transitions = classify(exact[i]) row = { "kind": kind, "panel_row": i, "test_idx": int(idx[i]), "group": group, "n_success_checkpoints": n_success, "transitions": transitions, "first_success_step": int(steps[first_i]) if first_i is not None else "", "first_success_epoch": int(epochs[first_i]) if first_i is not None else "", "stable_success_step": int(steps[stable_i]) if stable_i is not None else "", "stable_success_epoch": int(epochs[stable_i]) if stable_i is not None else "", "final_success": bool(exact[i, -1]), "final_token_acc": float(token_acc[i, -1]), "final_lambda_max": float(features["lambda_max"][i, -1]), "final_mean8": float(features["mean8"][i, -1]), "final_tail_mean_5_8": float(features["tail_mean_5_8"][i, -1]), "final_positive_count": float(features["positive_count"][i, -1]), "final_positive_sum": float(features["positive_sum"][i, -1]), } for j, step in enumerate(steps): row[f"success@{step}"] = int(exact[i, j]) row[f"lambda_max@{step}"] = float(features["lambda_max"][i, j]) row[f"mean8@{step}"] = float(features["mean8"][i, j]) row[f"positive_count@{step}"] = float(features["positive_count"][i, j]) problem_rows.append(row) add_problem_features(problem_rows, idx) write_csv(out_dir / f"{kind.lower()}_problem_tracks.csv", problem_rows) event_rows: list[dict] = [] for j in range(1, t): prev = exact[:, j - 1] cur = exact[:, j] for event_name, mask in [ ("learned_fail_to_success", (~prev) & cur), ("lost_success_to_fail", prev & (~cur)), ("stayed_failure", (~prev) & (~cur)), ("stayed_success", prev & cur), ]: if not mask.any(): continue row = { "kind": kind, "from_step": int(steps[j - 1]), "to_step": int(steps[j]), "from_epoch": int(epochs[j - 1]), "to_epoch": int(epochs[j]), "event": event_name, "n": int(mask.sum()), "from_token_acc_mean": float(token_acc[mask, j - 1].mean()), "to_token_acc_mean": float(token_acc[mask, j].mean()), } for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]: before = features[feat_name][mask, j - 1] after = features[feat_name][mask, j] row[f"{feat_name}_before"] = float(before.mean()) row[f"{feat_name}_after"] = float(after.mean()) row[f"{feat_name}_delta"] = float((after - before).mean()) event_rows.append(row) write_csv(out_dir / f"{kind.lower()}_learning_events.csv", event_rows) group_rows: list[dict] = [] groups = sorted(set(r["group"] for r in problem_rows)) for group in groups: mask = np.array([r["group"] == group for r in problem_rows]) row = { "kind": kind, "group": group, "n": int(mask.sum()), "fraction": float(mask.mean()), "givens_mean": float(np.mean([r["givens"] for r, m in zip(problem_rows, mask) if m])), "final_success_rate": float(exact[mask, -1].mean()), "final_token_acc_mean": float(token_acc[mask, -1].mean()), } for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]: row[f"initial_{feat_name}"] = float(features[feat_name][mask, 0].mean()) row[f"final_{feat_name}"] = float(features[feat_name][mask, -1].mean()) group_rows.append(row) write_csv(out_dir / f"{kind.lower()}_learning_groups.csv", group_rows) plot_problem_heatmap(kind, steps, epochs, exact, features, problem_rows, out_dir) plot_group_dynamics(kind, steps, epochs, exact, features, problem_rows, out_dir) plot_event_deltas(kind, event_rows, out_dir) return { "kind": kind, "steps": steps.tolist(), "epochs": epochs.tolist(), "n": n, "group_rows": group_rows, "event_rows": event_rows, } def sort_order(problem_rows: list[dict], exact: np.ndarray) -> np.ndarray: def key(i: int): r = problem_rows[i] first = r["first_success_epoch"] stable = r["stable_success_epoch"] first_val = int(first) if first != "" else 10**9 stable_val = int(stable) if stable != "" else 10**9 return (stable_val, first_val, -r["n_success_checkpoints"], r["transitions"], r["test_idx"]) return np.array(sorted(range(len(problem_rows)), key=key), dtype=int) def plot_problem_heatmap( kind: str, steps: np.ndarray, epochs: np.ndarray, exact: np.ndarray, features: dict[str, np.ndarray], problem_rows: list[dict], out_dir: Path, ) -> None: order = sort_order(problem_rows, exact) fig, axes = plt.subplots(1, 4, figsize=(14, 8), sharey=True) matrices = [ ("success", exact.astype(float), "Greens", 0, 1), ("λmax", features["lambda_max"], "coolwarm", -0.12 if kind == "TRM" else -0.25, 0.12), ("mean8", features["mean8"], "coolwarm", -0.16 if kind == "TRM" else -0.28, 0.08), ("# positive", features["positive_count"], "magma", 0, 8), ] for ax, (title, mat, cmap, vmin, vmax) in zip(axes, matrices): im = ax.imshow(mat[order], aspect="auto", interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) ax.set_title(title) ax.set_xticks(range(len(epochs))) ax.set_xticklabels(epochs, rotation=45, ha="right") ax.set_xlabel("epoch") fig.colorbar(im, ax=ax, fraction=0.045, pad=0.02) axes[0].set_ylabel("test problems sorted by stable learning time") fig.suptitle(f"{kind}: individual problem learning tracks on matched 512-test panel") fig.tight_layout(rect=[0, 0, 1, 0.96]) fig.savefig(out_dir / f"{kind.lower()}_problem_track_heatmap.png", dpi=150) plt.close(fig) def plot_group_dynamics( kind: str, steps: np.ndarray, epochs: np.ndarray, exact: np.ndarray, features: dict[str, np.ndarray], problem_rows: list[dict], out_dir: Path, ) -> None: groups = ["always", "stable_learned", "flaky_then_stable", "transient_or_regressed", "never"] colors = { "always": "C2", "stable_learned": "C0", "flaky_then_stable": "C4", "transient_or_regressed": "C1", "never": "C3", } fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True) for group in groups: mask = np.array([r["group"] == group for r in problem_rows]) if not mask.any(): continue label = f"{group} n={mask.sum()}" axes[0, 0].plot(epochs, exact[mask].mean(axis=0), "-o", color=colors[group], label=label) axes[0, 1].plot(epochs, features["lambda_max"][mask].mean(axis=0), "-o", color=colors[group]) axes[1, 0].plot(epochs, features["mean8"][mask].mean(axis=0), "-o", color=colors[group]) axes[1, 1].plot(epochs, features["positive_count"][mask].mean(axis=0), "-o", color=colors[group]) axes[0, 0].set_title("success rate") axes[0, 1].set_title("λmax") axes[1, 0].set_title("mean top-8") axes[1, 1].set_title("# positive among top 8") for ax in axes.flat: ax.axhline(0, color="k", lw=0.8, alpha=0.25) ax.grid(alpha=0.3) ax.set_xlabel("epoch") axes[0, 0].legend(fontsize=8, loc="best") fig.suptitle(f"{kind}: dynamics by per-problem learning class") fig.tight_layout(rect=[0, 0, 1, 0.96]) fig.savefig(out_dir / f"{kind.lower()}_learning_group_dynamics.png", dpi=150) plt.close(fig) def plot_event_deltas(kind: str, event_rows: list[dict], out_dir: Path) -> None: learn = [r for r in event_rows if r["event"] == "learned_fail_to_success"] if not learn: return epochs = [r["to_epoch"] for r in learn] fig, ax = plt.subplots(figsize=(10, 4)) for feat in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count"]: ax.plot(epochs, [r[f"{feat}_delta"] for r in learn], "-o", label=feat) ax.axhline(0, color="k", lw=0.8, alpha=0.35) ax.set_title(f"{kind}: mean spectrum change on fail→success transitions") ax.set_xlabel("to epoch") ax.set_ylabel("after - before") ax.legend(ncol=2, fontsize=8) ax.grid(alpha=0.3) fig.tight_layout() fig.savefig(out_dir / f"{kind.lower()}_learn_event_deltas.png", dpi=150) plt.close(fig) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--out-dir", default=str(ROOT / "problem_tracks")) args = parser.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) for kind in ["HRM", "TRM"]: result = analyze_kind(kind, out_dir) print(f"\n{kind}: N={result['n']} checkpoints={len(result['steps'])}") for row in result["group_rows"]: print( f" {row['group']:<22s} n={row['n']:>3d} " f"final_acc={row['final_success_rate']:.3f} " f"final_mean8={row['final_mean8']:+.4f} " f"final_poscnt={row['final_positive_count']:.2f}" ) print(f"\nWrote {out_dir}") if __name__ == "__main__": main()