diff options
Diffstat (limited to 'research/flossing/track_problem_learning.py')
| -rw-r--r-- | research/flossing/track_problem_learning.py | 375 |
1 files changed, 375 insertions, 0 deletions
diff --git a/research/flossing/track_problem_learning.py b/research/flossing/track_problem_learning.py new file mode 100644 index 0000000..d811bfd --- /dev/null +++ b/research/flossing/track_problem_learning.py @@ -0,0 +1,375 @@ +"""Track individual test problems across HRM/TRM checkpoint diagnostics. + +The HRM and TRM diagnostic runs in this workspace reuse the same 512 test +indices across checkpoints. This script turns those panel diagnostics into a +per-problem learning-order view: first success, stable success, regressions, +and spectrum changes around fail->success / success->fail transitions. +""" +from __future__ import annotations + +import argparse +import csv +import math +import re +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +ROOT = Path("/home/yurenh2/rrm/research/flossing") +DATA_ROOT = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test") + + +def parse_step(path: Path, kind: str) -> int: + if kind == "HRM": + return int(re.search(r"step_(\d+)_512", path.name).group(1)) + return int(re.search(r"step(\d+)_512", path.name).group(1)) + + +def step_to_epoch(step: int, kind: str) -> int: + if kind == "HRM": + return int(round(step * 20000 / 26040)) + # TRM file names correspond to 5k, 10k, ... 50k epochs. + return int(round(step * 5000 / 26041)) + + +def discover(kind: str) -> list[Path]: + if kind == "HRM": + files = sorted(ROOT.glob("diag_hrm_step_*_512.npz"), key=lambda p: parse_step(p, kind)) + else: + files = sorted(ROOT.glob("diag_trm_singleGPU_step*_512.npz"), key=lambda p: parse_step(p, kind)) + if not files: + raise FileNotFoundError(f"No {kind} diagnostics found") + return files + + +def sorted_features(lyap: np.ndarray) -> dict[str, np.ndarray]: + s = np.sort(lyap, axis=1)[:, ::-1] + pos = np.clip(s, 0, None) + return { + "lambda_max": s[:, 0], + "lambda_2": s[:, 1], + "lambda_min8": s[:, -1], + "mean8": s.mean(axis=1), + "tail_mean_5_8": s[:, 4:].mean(axis=1), + "positive_sum": pos.sum(axis=1), + "positive_count": (s > 0).sum(axis=1).astype(float), + "spread": s[:, 0] - s[:, -1], + } + + +def write_csv(path: Path, rows: list[dict]) -> None: + if not rows: + return + keys: list[str] = [] + for row in rows: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def classify(success: np.ndarray) -> tuple[str, int | None, int | None, int, int]: + """Return group, first_success_idx, stable_success_idx, n_success, transitions.""" + success = np.asarray(success, dtype=bool) + n_success = int(success.sum()) + transitions = int(np.abs(np.diff(success.astype(int))).sum()) + first_success_idx = int(np.argmax(success)) if success.any() else None + + stable_success_idx = None + for i in range(len(success)): + if success[i:].all(): + stable_success_idx = i + break + + if n_success == 0: + group = "never" + elif n_success == len(success): + group = "always" + elif stable_success_idx is not None: + if transitions <= 1: + group = "stable_learned" + else: + group = "flaky_then_stable" + else: + group = "transient_or_regressed" + return group, first_success_idx, stable_success_idx, n_success, transitions + + +def load_panel(kind: str): + files = discover(kind) + steps = np.array([parse_step(p, kind) for p in files], dtype=int) + epochs = np.array([step_to_epoch(s, kind) for s in steps], dtype=int) + + idx0 = np.load(files[0])["idx"] + exact_list = [] + token_list = [] + feature_time: dict[str, list[np.ndarray]] = {} + spectra = [] + + for path in files: + data = np.load(path) + idx = data["idx"] + if not np.array_equal(idx0, idx): + raise ValueError(f"{kind} diagnostics do not share the same idx array: {path}") + exact_list.append(data["exact_correct"].astype(bool)) + token_list.append(data["token_acc"].astype(float)) + lyap = data["lyap_spec"].astype(float) + spectra.append(np.sort(lyap, axis=1)[:, ::-1]) + for name, arr in sorted_features(lyap).items(): + feature_time.setdefault(name, []).append(arr) + + exact = np.stack(exact_list, axis=1) # (N, T) + token_acc = np.stack(token_list, axis=1) # (N, T) + spectrum = np.stack(spectra, axis=1) # (N, T, K) + features = {k: np.stack(v, axis=1) for k, v in feature_time.items()} + return files, idx0.astype(int), steps, epochs, exact, token_acc, spectrum, features + + +def add_problem_features(rows: list[dict], idx: np.ndarray) -> None: + inputs = np.load(DATA_ROOT / "all__inputs.npy", mmap_mode="r") + labels = np.load(DATA_ROOT / "all__labels.npy", mmap_mode="r") + sample_inputs = inputs[idx] + sample_labels = labels[idx] + # In these Sudoku tensors, 0 is padding, 1 is the blank-cell token, and + # digits are represented by 2..10. + givens = (sample_inputs > 1).sum(axis=1) + blanks = (sample_inputs == 1).sum(axis=1) + label_nonzero = (sample_labels > 0).sum(axis=1) + for row, g, b, lnz in zip(rows, givens, blanks, label_nonzero): + row["givens"] = int(g) + row["blanks"] = int(b) + row["label_nonzero"] = int(lnz) + + +def analyze_kind(kind: str, out_dir: Path) -> dict: + _, idx, steps, epochs, exact, token_acc, spectrum, features = load_panel(kind) + n, t = exact.shape + + problem_rows: list[dict] = [] + for i in range(n): + group, first_i, stable_i, n_success, transitions = classify(exact[i]) + row = { + "kind": kind, + "panel_row": i, + "test_idx": int(idx[i]), + "group": group, + "n_success_checkpoints": n_success, + "transitions": transitions, + "first_success_step": int(steps[first_i]) if first_i is not None else "", + "first_success_epoch": int(epochs[first_i]) if first_i is not None else "", + "stable_success_step": int(steps[stable_i]) if stable_i is not None else "", + "stable_success_epoch": int(epochs[stable_i]) if stable_i is not None else "", + "final_success": bool(exact[i, -1]), + "final_token_acc": float(token_acc[i, -1]), + "final_lambda_max": float(features["lambda_max"][i, -1]), + "final_mean8": float(features["mean8"][i, -1]), + "final_tail_mean_5_8": float(features["tail_mean_5_8"][i, -1]), + "final_positive_count": float(features["positive_count"][i, -1]), + "final_positive_sum": float(features["positive_sum"][i, -1]), + } + for j, step in enumerate(steps): + row[f"success@{step}"] = int(exact[i, j]) + row[f"lambda_max@{step}"] = float(features["lambda_max"][i, j]) + row[f"mean8@{step}"] = float(features["mean8"][i, j]) + row[f"positive_count@{step}"] = float(features["positive_count"][i, j]) + problem_rows.append(row) + add_problem_features(problem_rows, idx) + write_csv(out_dir / f"{kind.lower()}_problem_tracks.csv", problem_rows) + + event_rows: list[dict] = [] + for j in range(1, t): + prev = exact[:, j - 1] + cur = exact[:, j] + for event_name, mask in [ + ("learned_fail_to_success", (~prev) & cur), + ("lost_success_to_fail", prev & (~cur)), + ("stayed_failure", (~prev) & (~cur)), + ("stayed_success", prev & cur), + ]: + if not mask.any(): + continue + row = { + "kind": kind, + "from_step": int(steps[j - 1]), + "to_step": int(steps[j]), + "from_epoch": int(epochs[j - 1]), + "to_epoch": int(epochs[j]), + "event": event_name, + "n": int(mask.sum()), + "from_token_acc_mean": float(token_acc[mask, j - 1].mean()), + "to_token_acc_mean": float(token_acc[mask, j].mean()), + } + for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]: + before = features[feat_name][mask, j - 1] + after = features[feat_name][mask, j] + row[f"{feat_name}_before"] = float(before.mean()) + row[f"{feat_name}_after"] = float(after.mean()) + row[f"{feat_name}_delta"] = float((after - before).mean()) + event_rows.append(row) + write_csv(out_dir / f"{kind.lower()}_learning_events.csv", event_rows) + + group_rows: list[dict] = [] + groups = sorted(set(r["group"] for r in problem_rows)) + for group in groups: + mask = np.array([r["group"] == group for r in problem_rows]) + row = { + "kind": kind, + "group": group, + "n": int(mask.sum()), + "fraction": float(mask.mean()), + "givens_mean": float(np.mean([r["givens"] for r, m in zip(problem_rows, mask) if m])), + "final_success_rate": float(exact[mask, -1].mean()), + "final_token_acc_mean": float(token_acc[mask, -1].mean()), + } + for feat_name in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count", "positive_sum"]: + row[f"initial_{feat_name}"] = float(features[feat_name][mask, 0].mean()) + row[f"final_{feat_name}"] = float(features[feat_name][mask, -1].mean()) + group_rows.append(row) + write_csv(out_dir / f"{kind.lower()}_learning_groups.csv", group_rows) + + plot_problem_heatmap(kind, steps, epochs, exact, features, problem_rows, out_dir) + plot_group_dynamics(kind, steps, epochs, exact, features, problem_rows, out_dir) + plot_event_deltas(kind, event_rows, out_dir) + + return { + "kind": kind, + "steps": steps.tolist(), + "epochs": epochs.tolist(), + "n": n, + "group_rows": group_rows, + "event_rows": event_rows, + } + + +def sort_order(problem_rows: list[dict], exact: np.ndarray) -> np.ndarray: + def key(i: int): + r = problem_rows[i] + first = r["first_success_epoch"] + stable = r["stable_success_epoch"] + first_val = int(first) if first != "" else 10**9 + stable_val = int(stable) if stable != "" else 10**9 + return (stable_val, first_val, -r["n_success_checkpoints"], r["transitions"], r["test_idx"]) + + return np.array(sorted(range(len(problem_rows)), key=key), dtype=int) + + +def plot_problem_heatmap( + kind: str, + steps: np.ndarray, + epochs: np.ndarray, + exact: np.ndarray, + features: dict[str, np.ndarray], + problem_rows: list[dict], + out_dir: Path, +) -> None: + order = sort_order(problem_rows, exact) + fig, axes = plt.subplots(1, 4, figsize=(14, 8), sharey=True) + matrices = [ + ("success", exact.astype(float), "Greens", 0, 1), + ("λmax", features["lambda_max"], "coolwarm", -0.12 if kind == "TRM" else -0.25, 0.12), + ("mean8", features["mean8"], "coolwarm", -0.16 if kind == "TRM" else -0.28, 0.08), + ("# positive", features["positive_count"], "magma", 0, 8), + ] + for ax, (title, mat, cmap, vmin, vmax) in zip(axes, matrices): + im = ax.imshow(mat[order], aspect="auto", interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) + ax.set_title(title) + ax.set_xticks(range(len(epochs))) + ax.set_xticklabels(epochs, rotation=45, ha="right") + ax.set_xlabel("epoch") + fig.colorbar(im, ax=ax, fraction=0.045, pad=0.02) + axes[0].set_ylabel("test problems sorted by stable learning time") + fig.suptitle(f"{kind}: individual problem learning tracks on matched 512-test panel") + fig.tight_layout(rect=[0, 0, 1, 0.96]) + fig.savefig(out_dir / f"{kind.lower()}_problem_track_heatmap.png", dpi=150) + plt.close(fig) + + +def plot_group_dynamics( + kind: str, + steps: np.ndarray, + epochs: np.ndarray, + exact: np.ndarray, + features: dict[str, np.ndarray], + problem_rows: list[dict], + out_dir: Path, +) -> None: + groups = ["always", "stable_learned", "flaky_then_stable", "transient_or_regressed", "never"] + colors = { + "always": "C2", + "stable_learned": "C0", + "flaky_then_stable": "C4", + "transient_or_regressed": "C1", + "never": "C3", + } + fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True) + for group in groups: + mask = np.array([r["group"] == group for r in problem_rows]) + if not mask.any(): + continue + label = f"{group} n={mask.sum()}" + axes[0, 0].plot(epochs, exact[mask].mean(axis=0), "-o", color=colors[group], label=label) + axes[0, 1].plot(epochs, features["lambda_max"][mask].mean(axis=0), "-o", color=colors[group]) + axes[1, 0].plot(epochs, features["mean8"][mask].mean(axis=0), "-o", color=colors[group]) + axes[1, 1].plot(epochs, features["positive_count"][mask].mean(axis=0), "-o", color=colors[group]) + axes[0, 0].set_title("success rate") + axes[0, 1].set_title("λmax") + axes[1, 0].set_title("mean top-8") + axes[1, 1].set_title("# positive among top 8") + for ax in axes.flat: + ax.axhline(0, color="k", lw=0.8, alpha=0.25) + ax.grid(alpha=0.3) + ax.set_xlabel("epoch") + axes[0, 0].legend(fontsize=8, loc="best") + fig.suptitle(f"{kind}: dynamics by per-problem learning class") + fig.tight_layout(rect=[0, 0, 1, 0.96]) + fig.savefig(out_dir / f"{kind.lower()}_learning_group_dynamics.png", dpi=150) + plt.close(fig) + + +def plot_event_deltas(kind: str, event_rows: list[dict], out_dir: Path) -> None: + learn = [r for r in event_rows if r["event"] == "learned_fail_to_success"] + if not learn: + return + epochs = [r["to_epoch"] for r in learn] + fig, ax = plt.subplots(figsize=(10, 4)) + for feat in ["lambda_max", "mean8", "tail_mean_5_8", "positive_count"]: + ax.plot(epochs, [r[f"{feat}_delta"] for r in learn], "-o", label=feat) + ax.axhline(0, color="k", lw=0.8, alpha=0.35) + ax.set_title(f"{kind}: mean spectrum change on fail→success transitions") + ax.set_xlabel("to epoch") + ax.set_ylabel("after - before") + ax.legend(ncol=2, fontsize=8) + ax.grid(alpha=0.3) + fig.tight_layout() + fig.savefig(out_dir / f"{kind.lower()}_learn_event_deltas.png", dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", default=str(ROOT / "problem_tracks")) + args = parser.parse_args() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + for kind in ["HRM", "TRM"]: + result = analyze_kind(kind, out_dir) + print(f"\n{kind}: N={result['n']} checkpoints={len(result['steps'])}") + for row in result["group_rows"]: + print( + f" {row['group']:<22s} n={row['n']:>3d} " + f"final_acc={row['final_success_rate']:.3f} " + f"final_mean8={row['final_mean8']:+.4f} " + f"final_poscnt={row['final_positive_count']:.2f}" + ) + + print(f"\nWrote {out_dir}") + + +if __name__ == "__main__": + main() |
