diff options
Diffstat (limited to 'research/flossing/flossing_suite/summarize_flossing.py')
| -rw-r--r-- | research/flossing/flossing_suite/summarize_flossing.py | 311 |
1 files changed, 311 insertions, 0 deletions
diff --git a/research/flossing/flossing_suite/summarize_flossing.py b/research/flossing/flossing_suite/summarize_flossing.py new file mode 100644 index 0000000..e753b13 --- /dev/null +++ b/research/flossing/flossing_suite/summarize_flossing.py @@ -0,0 +1,311 @@ +from __future__ import annotations + +import csv +import json +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt + + +ROOT = Path("/home/yurenh2/rrm") +FLOSS = ROOT / "research/flossing" +SUITE = FLOSS / "flossing_suite" +OUT = SUITE / "results/summary" + + +def load_json(path: Path) -> dict[str, Any] | None: + try: + return json.loads(path.read_text()) + except Exception as exc: + print(f"[skip] {path}: {exc}") + return None + + +def fnum(value: Any, default: float = float("nan")) -> float: + try: + if value is None: + return default + return float(value) + except Exception: + return default + + +def run_name(path: Path) -> str: + return path.relative_to(FLOSS).with_suffix("").as_posix() + + +def iter_json_paths() -> list[Path]: + patterns = [ + "engelken_python/*.json", + "engelken_paper_faithful/*.json", + "step6_*.json", + "step7_*.json", + "flossing_suite/results/**/*.json", + ] + paths: list[Path] = [] + for pattern in patterns: + paths.extend(FLOSS.glob(pattern)) + return sorted(set(paths)) + + +def classify(path: Path, data: dict[str, Any]) -> str: + if "config" in data: + return "toy_rnn" + if "phase1_steps" in data: + return "step6_prefloss_hrm" + if "args" in data and "floss_episodes" in data: + args = data.get("args", {}) + model = args.get("model", "unknown") + return f"step7_interfloss_{model}" + return "unknown" + + +def selector_from_args(args: dict[str, Any]) -> str: + schedule = str(args.get("interfloss_at", "")) + steps = int(fnum(args.get("floss_steps"), 0)) + mode = str(args.get("floss_mode", "none")) + kl = fnum(args.get("kl_beta"), 0.0) + every = int(fnum(args.get("interfloss_every"), 0)) + if not schedule or steps <= 0: + return "baseline_no_floss" + label = f"{mode}@{schedule}_steps{steps}" + if every > 0: + label += f"_every{every}" + if kl > 0: + label += f"_kl{kl:g}" + return label + + +def summarize_toy(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]: + cfg = data.get("config", {}) + evals = data.get("evals", []) + first = evals[0] if evals else {} + last = evals[-1] if evals else {} + peak = max((fnum(e.get("eval_accuracy")) for e in evals), default=float("nan")) + pre_epochs = int(fnum(cfg.get("pre_epochs"), 0)) + inter_epochs = int(fnum(cfg.get("inter_epochs"), 0)) + max_inter = int(fnum(cfg.get("max_inter_episodes"), 0)) + if pre_epochs <= 0 and max_inter <= 0: + selector = "baseline_no_floss" + elif pre_epochs > 0 and max_inter <= 0: + selector = "prefloss" + else: + selector = "pre_interfloss" + + summary = { + "name": run_name(path), + "path": str(path), + "family": "toy_rnn", + "model": "vanilla_rnn", + "selector": selector, + "floss_mode": "engelken_l2", + "schedule": f"pre={pre_epochs},inter_period={cfg.get('inter_period')},max_inter={max_inter}", + "train_steps": cfg.get("train_epochs"), + "floss_steps": pre_epochs + inter_epochs * max_inter, + "k_lyap": cfg.get("n_lyap"), + "kl_beta": 0, + "initial_acc": fnum(first.get("eval_accuracy")), + "final_acc": fnum(last.get("eval_accuracy")), + "best_acc": peak, + "delta_final": fnum(last.get("eval_accuracy")) - fnum(first.get("eval_accuracy")), + "n_evals": len(evals), + "n_floss_episodes": (1 if pre_epochs > 0 else 0) + max_inter, + "status": "complete" if evals else "unknown", + } + rows = [] + for e in evals: + rows.append( + { + "name": summary["name"], + "family": summary["family"], + "model": summary["model"], + "selector": summary["selector"], + "x": e.get("epoch"), + "acc": e.get("eval_accuracy"), + "loss": e.get("eval_loss"), + "kind": "eval", + } + ) + return summary, rows + + +def summarize_step7(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]: + args = data.get("args", {}) + evals = data.get("evals", []) + first = evals[0] if evals else {} + final_acc = fnum(data.get("final_acc"), fnum(evals[-1].get("acc") if evals else None)) + best = max((fnum(e.get("acc")) for e in evals), default=float("nan")) + model = str(args.get("model", "unknown")) + train_steps = int(fnum(args.get("train_steps"), 0)) + task_batch = int(fnum(args.get("task_batch_size", args.get("batch_size")), 0)) + floss_batch = int(fnum(args.get("floss_batch_size", args.get("batch_size")), 0)) + every = int(fnum(args.get("interfloss_every"), 0)) + start = int(fnum(args.get("interfloss_start"), 0)) + stop = int(fnum(args.get("interfloss_stop"), -1)) + schedule = str(args.get("interfloss_at", "")) + if every > 0: + schedule = f"{schedule}; every={every}, start={start}, stop={stop}" + summary = { + "name": run_name(path), + "path": str(path), + "family": "rrm_step7", + "model": model, + "selector": selector_from_args(args), + "floss_mode": args.get("floss_mode", "none"), + "schedule": schedule, + "train_steps": train_steps, + "task_batch_size": task_batch, + "floss_batch_size": floss_batch, + "effective_official_gbs768_steps": train_steps * task_batch / 768 if task_batch else "", + "floss_steps": args.get("floss_steps"), + "k_lyap": args.get("k_lyap"), + "kl_beta": args.get("kl_beta", 0), + "initial_acc": fnum(data.get("initial_acc"), fnum(first.get("acc"))), + "final_acc": final_acc, + "best_acc": best, + "delta_final": final_acc - fnum(data.get("initial_acc"), fnum(first.get("acc"))), + "n_evals": len(evals), + "n_floss_episodes": len(data.get("floss_episodes", [])), + "status": "complete" if "final_acc" in data else "running_or_incomplete", + } + rows = [] + for e in evals: + rows.append( + { + "name": summary["name"], + "family": summary["family"], + "model": model, + "selector": summary["selector"], + "x": e.get("train_step"), + "acc": e.get("acc"), + "loss": "", + "kind": e.get("kind", "eval"), + } + ) + return summary, rows + + +def summarize_step6(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]: + args = data.get("args", {}) + phase1 = data.get("phase1_evals", []) + phase2 = data.get("phase2_evals", []) + evals = [] + evals.extend({"kind": "phase1", **e} for e in phase1) + evals.extend({"kind": "phase2", **e} for e in phase2) + final_acc = fnum(data.get("final_acc"), fnum(phase2[-1].get("acc") if phase2 else None)) + best = max((fnum(e.get("acc")) for e in evals), default=float("nan")) + pre_steps = int(fnum(args.get("prefloss_steps"), 0)) + selector = "baseline_no_prefloss" if pre_steps <= 0 else f"prefloss_{args.get('floss_mode')}" + summary = { + "name": run_name(path), + "path": str(path), + "family": "hrm_step6", + "model": "hrm", + "selector": selector, + "floss_mode": args.get("floss_mode", "none"), + "schedule": f"pre={pre_steps}", + "train_steps": args.get("train_steps"), + "floss_steps": args.get("prefloss_steps"), + "k_lyap": args.get("k_lyap"), + "kl_beta": 0, + "initial_acc": fnum(data.get("initial_acc")), + "final_acc": final_acc, + "best_acc": best, + "delta_final": final_acc - fnum(data.get("initial_acc")), + "n_evals": len(evals), + "n_floss_episodes": 1 if pre_steps > 0 else 0, + "status": "complete" if "final_acc" in data else "running_or_incomplete", + } + rows = [] + for e in evals: + rows.append( + { + "name": summary["name"], + "family": summary["family"], + "model": "hrm", + "selector": summary["selector"], + "x": e.get("step"), + "acc": e.get("acc"), + "loss": "", + "kind": e.get("kind", "eval"), + } + ) + return summary, rows + + +def collect() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + summaries: list[dict[str, Any]] = [] + eval_rows: list[dict[str, Any]] = [] + for path in iter_json_paths(): + data = load_json(path) + if data is None: + continue + family = classify(path, data) + if family == "toy_rnn": + summary, rows = summarize_toy(path, data) + elif family.startswith("step7"): + summary, rows = summarize_step7(path, data) + elif family == "step6_prefloss_hrm": + summary, rows = summarize_step6(path, data) + else: + continue + summaries.append(summary) + eval_rows.extend(rows) + return summaries, eval_rows + + +def write_csv(path: Path, rows: list[dict[str, Any]]) -> None: + if not rows: + path.write_text("") + return + keys = list(rows[0].keys()) + for row in rows[1:]: + for key in row: + if key not in keys: + keys.append(key) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=keys) + writer.writeheader() + writer.writerows(rows) + + +def plot_family(eval_rows: list[dict[str, Any]], family: str, out: Path, title: str) -> None: + rows = [r for r in eval_rows if r["family"] == family and r.get("x") not in ("", None)] + if not rows: + return + fig, ax = plt.subplots(figsize=(8.5, 4.6)) + names = sorted({r["name"] for r in rows}) + for name in names: + nr = [r for r in rows if r["name"] == name] + nr.sort(key=lambda r: fnum(r["x"])) + xs = [fnum(r["x"]) for r in nr] + ys = [fnum(r["acc"]) for r in nr] + label = name.split("/")[-1] + ax.plot(xs, ys, marker="o", linewidth=1.8, label=label) + ax.set_title(title) + ax.set_xlabel("epoch / train step") + ax.set_ylabel("eval exact accuracy") + ax.grid(alpha=0.22) + ax.legend(frameon=False, fontsize=7) + fig.tight_layout() + fig.savefig(out, dpi=180) + plt.close(fig) + + +def main() -> None: + OUT.mkdir(parents=True, exist_ok=True) + summaries, eval_rows = collect() + summaries.sort(key=lambda r: (str(r["family"]), str(r["model"]), str(r["name"]))) + write_csv(OUT / "flossing_runs_summary.csv", summaries) + write_csv(OUT / "flossing_eval_curves.csv", eval_rows) + plot_family(eval_rows, "toy_rnn", OUT / "toy_rnn_eval_curves.png", "Toy RNN Engelken-style flossing") + plot_family(eval_rows, "rrm_step7", OUT / "rrm_step7_eval_curves.png", "HRM/TRM interfloss analogues") + plot_family(eval_rows, "hrm_step6", OUT / "hrm_step6_eval_curves.png", "HRM prefloss experiments") + print(f"wrote {OUT / 'flossing_runs_summary.csv'}") + print(f"wrote {OUT / 'flossing_eval_curves.csv'}") + print(f"runs: {len(summaries)} eval points: {len(eval_rows)}") + + +if __name__ == "__main__": + main() |
