summaryrefslogtreecommitdiff
path: root/research/flossing/track_problem_learning.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/track_problem_learning.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/track_problem_learning.py')
-rw-r--r--research/flossing/track_problem_learning.py375
1 files changed, 375 insertions, 0 deletions
diff --git a/research/flossing/track_problem_learning.py b/research/flossing/track_problem_learning.py
new file mode 100644
index 0000000..d811bfd
--- /dev/null
+++ b/research/flossing/track_problem_learning.py
@@ -0,0 +1,375 @@
+"""Track individual test problems across HRM/TRM checkpoint diagnostics.
+
+The HRM and TRM diagnostic runs in this workspace reuse the same 512 test
+indices across checkpoints. This script turns those panel diagnostics into a
+per-problem learning-order view: first success, stable success, regressions,
+and spectrum changes around fail->success / success->fail transitions.
+"""
+from __future__ import annotations
+
+import argparse
+import csv
+import math
+import re
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+ROOT = Path("/home/yurenh2/rrm/research/flossing")
+DATA_ROOT = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test")
+
+
+def parse_step(path: Path, kind: str) -> int:
+ if kind == "HRM":
+ return int(re.search(r"step_(\d+)_512", path.name).group(1))
+ return int(re.search(r"step(\d+)_512", path.name).group(1))
+
+
+def step_to_epoch(step: int, kind: str) -> int:
+ if kind == "HRM":
+ return int(round(step * 20000 / 26040))
+ # TRM file names correspond to 5k, 10k, ... 50k epochs.
+ return int(round(step * 5000 / 26041))
+
+
+def discover(kind: str) -> list[Path]:
+ if kind == "HRM":
+ files = sorted(ROOT.glob("diag_hrm_step_*_512.npz"), key=lambda p: parse_step(p, kind))
+ else:
+ files = sorted(ROOT.glob("diag_trm_singleGPU_step*_512.npz"), key=lambda p: parse_step(p, kind))
+ if not files:
+ raise FileNotFoundError(f"No {kind} diagnostics found")
+ return files
+
+
+def sorted_features(lyap: np.ndarray) -> dict[str, np.ndarray]:
+ s = np.sort(lyap, axis=1)[:, ::-1]
+ pos = np.clip(s, 0, None)
+ return {
+ "lambda_max": s[:, 0],
+ "lambda_2": s[:, 1],
+ "lambda_min8": s[:, -1],
+ "mean8": s.mean(axis=1),
+ "tail_mean_5_8": s[:, 4:].mean(axis=1),
+ "positive_sum": pos.sum(axis=1),
+ "positive_count": (s > 0).sum(axis=1).astype(float),
+ "spread": s[:, 0] - s[:, -1],
+ }
+
+
+def write_csv(path: Path, rows: list[dict]) -> None:
+ if not rows:
+ return
+ keys: list[str] = []
+ for row in rows:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def classify(success: np.ndarray) -> tuple[str, int | None, int | None, int, int]:
+ """Return group, first_success_idx, stable_success_idx, n_success, transitions."""
+ success = np.asarray(success, dtype=bool)
+ n_success = int(success.sum())
+ transitions = int(np.abs(np.diff(success.astype(int))).sum())
+ first_success_idx = int(np.argmax(success)) if success.any() else None
+
+ stable_success_idx = None
+ for i in range(len(success)):
+ if success[i:].all():
+ stable_success_idx = i
+ break
+
+ if n_success == 0:
+ group = "never"
+ elif n_success == len(success):
+ group = "always"
+ elif stable_success_idx is not None:
+ if transitions <= 1:
+ group = "stable_learned"
+ else:
+ group = "flaky_then_stable"
+ else:
+ group = "transient_or_regressed"
+ return group, first_success_idx, stable_success_idx, n_success, transitions
+
+
+def load_panel(kind: str):
+ files = discover(kind)
+ steps = np.array([parse_step(p, kind) for p in files], dtype=int)
+ epochs = np.array([step_to_epoch(s, kind) for s in steps], dtype=int)
+
+ idx0 = np.load(files[0])["idx"]
+ exact_list = []
+ token_list = []
+ feature_time: dict[str, list[np.ndarray]] = {}
+ spectra = []
+
+ for path in files:
+ data = np.load(path)
+ idx = data["idx"]
+ if not np.array_equal(idx0, idx):
+ raise ValueError(f"{kind} diagnostics do not share the same idx array: {path}")
+ exact_list.append(data["exact_correct"].astype(bool))
+ token_list.append(data["token_acc"].astype(float))
+ lyap = data["lyap_spec"].astype(float)
+ spectra.append(np.sort(lyap, axis=1)[:, ::-1])
+ for name, arr in sorted_features(lyap).items():
+ feature_time.setdefault(name, []).append(arr)
+
+ exact = np.stack(exact_list, axis=1) # (N, T)
+ token_acc = np.stack(token_list, axis=1) # (N, T)
+ spectrum = np.stack(spectra, axis=1) # (N, T, K)
+ features = {k: np.stack(v, axis=1) for k, v in feature_time.items()}
+ return files, idx0.astype(int), steps, epochs, exact, token_acc, spectrum, features
+
+
+def add_problem_features(rows: list[dict], idx: np.ndarray) -> None:
+ inputs = np.load(DATA_ROOT / "all__inputs.npy", mmap_mode="r")
+ labels = np.load(DATA_ROOT / "all__labels.npy", mmap_mode="r")
+ sample_inputs = inputs[idx]
+ sample_labels = labels[idx]
+ # In these Sudoku tensors, 0 is padding, 1 is the blank-cell token, and
+ # digits are represented by 2..10.
+ givens = (sample_inputs > 1).sum(axis=1)
+ blanks = (sample_inputs == 1).sum(axis=1)
+ label_nonzero = (sample_labels > 0).sum(axis=1)
+ for row, g, b, lnz in zip(rows, givens, blanks, label_nonzero):
+ row["givens"] = int(g)
+ row["blanks"] = int(b)
+ row["label_nonzero"] = int(lnz)
+
+
+def analyze_kind(kind: str, out_dir: Path) -> dict:
+ _, idx, steps, epochs, exact, token_acc, spectrum, features = load_panel(kind)
+ n, t = exact.shape
+
+ problem_rows: list[dict] = []
+ for i in range(n):
+ group, first_i, stable_i, n_success, transitions = classify(exact[i])
+ row = {
+ "kind": kind,
+ "panel_row": i,
+ "test_idx": int(idx[i]),
+ "group": group,
+ "n_success_checkpoints": n_success,
+ "transitions": transitions,
+ "first_success_step": int(steps[first_i]) if first_i is not None else "",
+ "first_success_epoch": int(epochs[first_i]) if first_i is not None else "",
+ "stable_success_step": int(steps[stable_i]) if stable_i is not None else "",
+ "stable_success_epoch": int(epochs[stable_i]) if stable_i is not None else "",
+ "final_success": bool(exact[i, -1]),
+ "final_token_acc": float(token_acc[i, -1]),
+ "final_lambda_max": float(features["lambda_max"][i, -1]),
+ "final_mean8": float(features["mean8"][i, -1]),
+ "final_tail_mean_5_8": float(features["tail_mean_5_8"][i, -1]),
+ "final_positive_count": float(features["positive_count"][i, -1]),
+ "final_positive_sum": float(features["positive_sum"][i, -1]),
+ }
+ for j, step in enumerate(steps):
+ row[f"success@{step}"] = int(exact[i, j])
+ row[f"lambda_max@{step}"] = float(features["lambda_max"][i, j])
+ row[f"mean8@{step}"] = float(features["mean8"][i, j])
+ row[f"positive_count@{step}"] = float(features["positive_count"][i, j])
+ problem_rows.append(row)
+ add_problem_features(problem_rows, idx)
+ write_csv(out_dir / f"{kind.lower()}_problem_tracks.csv", problem_rows)
+
+ event_rows: list[dict] = []
+ for j in range(1, t):
+ prev = exact[:, j - 1]
+ cur = exact[:, j]
+ for event_name, mask in [
+ ("learned_fail_to_success", (~prev) & cur),
+ ("lost_success_to_fail", prev & (~cur)),
+ ("stayed_failure", (~prev) & (~cur)),
+ ("stayed_success", prev & cur),
+ ]:
+ if not mask.any():
+ continue
+ row = {
+ "kind": kind,
+ "from_step": int(steps[j - 1]),
+ "to_step": int(steps[j]),
+ "from_epoch": int(epochs[j - 1]),
+ "to_epoch": int(epochs[j]),
+ "event": event_name,
+ "n": int(mask.sum()),
+ "from_token_acc_mean": float(token_acc[mask, j - 1].mean()),
+ "to_token_acc_mean": float(token_acc[mask, j].mean()),
+ }
+ for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]:
+ before = features[feat_name][mask, j - 1]
+ after = features[feat_name][mask, j]
+ row[f"{feat_name}_before"] = float(before.mean())
+ row[f"{feat_name}_after"] = float(after.mean())
+ row[f"{feat_name}_delta"] = float((after - before).mean())
+ event_rows.append(row)
+ write_csv(out_dir / f"{kind.lower()}_learning_events.csv", event_rows)
+
+ group_rows: list[dict] = []
+ groups = sorted(set(r["group"] for r in problem_rows))
+ for group in groups:
+ mask = np.array([r["group"] == group for r in problem_rows])
+ row = {
+ "kind": kind,
+ "group": group,
+ "n": int(mask.sum()),
+ "fraction": float(mask.mean()),
+ "givens_mean": float(np.mean([r["givens"] for r, m in zip(problem_rows, mask) if m])),
+ "final_success_rate": float(exact[mask, -1].mean()),
+ "final_token_acc_mean": float(token_acc[mask, -1].mean()),
+ }
+ for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]:
+ row[f"initial_{feat_name}"] = float(features[feat_name][mask, 0].mean())
+ row[f"final_{feat_name}"] = float(features[feat_name][mask, -1].mean())
+ group_rows.append(row)
+ write_csv(out_dir / f"{kind.lower()}_learning_groups.csv", group_rows)
+
+ plot_problem_heatmap(kind, steps, epochs, exact, features, problem_rows, out_dir)
+ plot_group_dynamics(kind, steps, epochs, exact, features, problem_rows, out_dir)
+ plot_event_deltas(kind, event_rows, out_dir)
+
+ return {
+ "kind": kind,
+ "steps": steps.tolist(),
+ "epochs": epochs.tolist(),
+ "n": n,
+ "group_rows": group_rows,
+ "event_rows": event_rows,
+ }
+
+
+def sort_order(problem_rows: list[dict], exact: np.ndarray) -> np.ndarray:
+ def key(i: int):
+ r = problem_rows[i]
+ first = r["first_success_epoch"]
+ stable = r["stable_success_epoch"]
+ first_val = int(first) if first != "" else 10**9
+ stable_val = int(stable) if stable != "" else 10**9
+ return (stable_val, first_val, -r["n_success_checkpoints"], r["transitions"], r["test_idx"])
+
+ return np.array(sorted(range(len(problem_rows)), key=key), dtype=int)
+
+
+def plot_problem_heatmap(
+ kind: str,
+ steps: np.ndarray,
+ epochs: np.ndarray,
+ exact: np.ndarray,
+ features: dict[str, np.ndarray],
+ problem_rows: list[dict],
+ out_dir: Path,
+) -> None:
+ order = sort_order(problem_rows, exact)
+ fig, axes = plt.subplots(1, 4, figsize=(14, 8), sharey=True)
+ matrices = [
+ ("success", exact.astype(float), "Greens", 0, 1),
+ ("λmax", features["lambda_max"], "coolwarm", -0.12 if kind == "TRM" else -0.25, 0.12),
+ ("mean8", features["mean8"], "coolwarm", -0.16 if kind == "TRM" else -0.28, 0.08),
+ ("# positive", features["positive_count"], "magma", 0, 8),
+ ]
+ for ax, (title, mat, cmap, vmin, vmax) in zip(axes, matrices):
+ im = ax.imshow(mat[order], aspect="auto", interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
+ ax.set_title(title)
+ ax.set_xticks(range(len(epochs)))
+ ax.set_xticklabels(epochs, rotation=45, ha="right")
+ ax.set_xlabel("epoch")
+ fig.colorbar(im, ax=ax, fraction=0.045, pad=0.02)
+ axes[0].set_ylabel("test problems sorted by stable learning time")
+ fig.suptitle(f"{kind}: individual problem learning tracks on matched 512-test panel")
+ fig.tight_layout(rect=[0, 0, 1, 0.96])
+ fig.savefig(out_dir / f"{kind.lower()}_problem_track_heatmap.png", dpi=150)
+ plt.close(fig)
+
+
+def plot_group_dynamics(
+ kind: str,
+ steps: np.ndarray,
+ epochs: np.ndarray,
+ exact: np.ndarray,
+ features: dict[str, np.ndarray],
+ problem_rows: list[dict],
+ out_dir: Path,
+) -> None:
+ groups = ["always", "stable_learned", "flaky_then_stable", "transient_or_regressed", "never"]
+ colors = {
+ "always": "C2",
+ "stable_learned": "C0",
+ "flaky_then_stable": "C4",
+ "transient_or_regressed": "C1",
+ "never": "C3",
+ }
+ fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
+ for group in groups:
+ mask = np.array([r["group"] == group for r in problem_rows])
+ if not mask.any():
+ continue
+ label = f"{group} n={mask.sum()}"
+ axes[0, 0].plot(epochs, exact[mask].mean(axis=0), "-o", color=colors[group], label=label)
+ axes[0, 1].plot(epochs, features["lambda_max"][mask].mean(axis=0), "-o", color=colors[group])
+ axes[1, 0].plot(epochs, features["mean8"][mask].mean(axis=0), "-o", color=colors[group])
+ axes[1, 1].plot(epochs, features["positive_count"][mask].mean(axis=0), "-o", color=colors[group])
+ axes[0, 0].set_title("success rate")
+ axes[0, 1].set_title("λmax")
+ axes[1, 0].set_title("mean top-8")
+ axes[1, 1].set_title("# positive among top 8")
+ for ax in axes.flat:
+ ax.axhline(0, color="k", lw=0.8, alpha=0.25)
+ ax.grid(alpha=0.3)
+ ax.set_xlabel("epoch")
+ axes[0, 0].legend(fontsize=8, loc="best")
+ fig.suptitle(f"{kind}: dynamics by per-problem learning class")
+ fig.tight_layout(rect=[0, 0, 1, 0.96])
+ fig.savefig(out_dir / f"{kind.lower()}_learning_group_dynamics.png", dpi=150)
+ plt.close(fig)
+
+
+def plot_event_deltas(kind: str, event_rows: list[dict], out_dir: Path) -> None:
+ learn = [r for r in event_rows if r["event"] == "learned_fail_to_success"]
+ if not learn:
+ return
+ epochs = [r["to_epoch"] for r in learn]
+ fig, ax = plt.subplots(figsize=(10, 4))
+ for feat in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count"]:
+ ax.plot(epochs, [r[f"{feat}_delta"] for r in learn], "-o", label=feat)
+ ax.axhline(0, color="k", lw=0.8, alpha=0.35)
+ ax.set_title(f"{kind}: mean spectrum change on fail→success transitions")
+ ax.set_xlabel("to epoch")
+ ax.set_ylabel("after - before")
+ ax.legend(ncol=2, fontsize=8)
+ ax.grid(alpha=0.3)
+ fig.tight_layout()
+ fig.savefig(out_dir / f"{kind.lower()}_learn_event_deltas.png", dpi=150)
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--out-dir", default=str(ROOT / "problem_tracks"))
+ args = parser.parse_args()
+ out_dir = Path(args.out_dir)
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ for kind in ["HRM", "TRM"]:
+ result = analyze_kind(kind, out_dir)
+ print(f"\n{kind}: N={result['n']} checkpoints={len(result['steps'])}")
+ for row in result["group_rows"]:
+ print(
+ f" {row['group']:<22s} n={row['n']:>3d} "
+ f"final_acc={row['final_success_rate']:.3f} "
+ f"final_mean8={row['final_mean8']:+.4f} "
+ f"final_poscnt={row['final_positive_count']:.2f}"
+ )
+
+ print(f"\nWrote {out_dir}")
+
+
+if __name__ == "__main__":
+ main()