summaryrefslogtreecommitdiff
path: root/research/flossing/flossing_suite/summarize_flossing.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/flossing_suite/summarize_flossing.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/flossing_suite/summarize_flossing.py')
-rw-r--r--research/flossing/flossing_suite/summarize_flossing.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/research/flossing/flossing_suite/summarize_flossing.py b/research/flossing/flossing_suite/summarize_flossing.py
new file mode 100644
index 0000000..e753b13
--- /dev/null
+++ b/research/flossing/flossing_suite/summarize_flossing.py
@@ -0,0 +1,311 @@
+from __future__ import annotations
+
+import csv
+import json
+from pathlib import Path
+from typing import Any
+
+import matplotlib.pyplot as plt
+
+
+ROOT = Path("/home/yurenh2/rrm")
+FLOSS = ROOT / "research/flossing"
+SUITE = FLOSS / "flossing_suite"
+OUT = SUITE / "results/summary"
+
+
+def load_json(path: Path) -> dict[str, Any] | None:
+ try:
+ return json.loads(path.read_text())
+ except Exception as exc:
+ print(f"[skip] {path}: {exc}")
+ return None
+
+
+def fnum(value: Any, default: float = float("nan")) -> float:
+ try:
+ if value is None:
+ return default
+ return float(value)
+ except Exception:
+ return default
+
+
+def run_name(path: Path) -> str:
+ return path.relative_to(FLOSS).with_suffix("").as_posix()
+
+
+def iter_json_paths() -> list[Path]:
+ patterns = [
+ "engelken_python/*.json",
+ "engelken_paper_faithful/*.json",
+ "step6_*.json",
+ "step7_*.json",
+ "flossing_suite/results/**/*.json",
+ ]
+ paths: list[Path] = []
+ for pattern in patterns:
+ paths.extend(FLOSS.glob(pattern))
+ return sorted(set(paths))
+
+
+def classify(path: Path, data: dict[str, Any]) -> str:
+ if "config" in data:
+ return "toy_rnn"
+ if "phase1_steps" in data:
+ return "step6_prefloss_hrm"
+ if "args" in data and "floss_episodes" in data:
+ args = data.get("args", {})
+ model = args.get("model", "unknown")
+ return f"step7_interfloss_{model}"
+ return "unknown"
+
+
+def selector_from_args(args: dict[str, Any]) -> str:
+ schedule = str(args.get("interfloss_at", ""))
+ steps = int(fnum(args.get("floss_steps"), 0))
+ mode = str(args.get("floss_mode", "none"))
+ kl = fnum(args.get("kl_beta"), 0.0)
+ every = int(fnum(args.get("interfloss_every"), 0))
+ if not schedule or steps <= 0:
+ return "baseline_no_floss"
+ label = f"{mode}@{schedule}_steps{steps}"
+ if every > 0:
+ label += f"_every{every}"
+ if kl > 0:
+ label += f"_kl{kl:g}"
+ return label
+
+
+def summarize_toy(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
+ cfg = data.get("config", {})
+ evals = data.get("evals", [])
+ first = evals[0] if evals else {}
+ last = evals[-1] if evals else {}
+ peak = max((fnum(e.get("eval_accuracy")) for e in evals), default=float("nan"))
+ pre_epochs = int(fnum(cfg.get("pre_epochs"), 0))
+ inter_epochs = int(fnum(cfg.get("inter_epochs"), 0))
+ max_inter = int(fnum(cfg.get("max_inter_episodes"), 0))
+ if pre_epochs <= 0 and max_inter <= 0:
+ selector = "baseline_no_floss"
+ elif pre_epochs > 0 and max_inter <= 0:
+ selector = "prefloss"
+ else:
+ selector = "pre_interfloss"
+
+ summary = {
+ "name": run_name(path),
+ "path": str(path),
+ "family": "toy_rnn",
+ "model": "vanilla_rnn",
+ "selector": selector,
+ "floss_mode": "engelken_l2",
+ "schedule": f"pre={pre_epochs},inter_period={cfg.get('inter_period')},max_inter={max_inter}",
+ "train_steps": cfg.get("train_epochs"),
+ "floss_steps": pre_epochs + inter_epochs * max_inter,
+ "k_lyap": cfg.get("n_lyap"),
+ "kl_beta": 0,
+ "initial_acc": fnum(first.get("eval_accuracy")),
+ "final_acc": fnum(last.get("eval_accuracy")),
+ "best_acc": peak,
+ "delta_final": fnum(last.get("eval_accuracy")) - fnum(first.get("eval_accuracy")),
+ "n_evals": len(evals),
+ "n_floss_episodes": (1 if pre_epochs > 0 else 0) + max_inter,
+ "status": "complete" if evals else "unknown",
+ }
+ rows = []
+ for e in evals:
+ rows.append(
+ {
+ "name": summary["name"],
+ "family": summary["family"],
+ "model": summary["model"],
+ "selector": summary["selector"],
+ "x": e.get("epoch"),
+ "acc": e.get("eval_accuracy"),
+ "loss": e.get("eval_loss"),
+ "kind": "eval",
+ }
+ )
+ return summary, rows
+
+
+def summarize_step7(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
+ args = data.get("args", {})
+ evals = data.get("evals", [])
+ first = evals[0] if evals else {}
+ final_acc = fnum(data.get("final_acc"), fnum(evals[-1].get("acc") if evals else None))
+ best = max((fnum(e.get("acc")) for e in evals), default=float("nan"))
+ model = str(args.get("model", "unknown"))
+ train_steps = int(fnum(args.get("train_steps"), 0))
+ task_batch = int(fnum(args.get("task_batch_size", args.get("batch_size")), 0))
+ floss_batch = int(fnum(args.get("floss_batch_size", args.get("batch_size")), 0))
+ every = int(fnum(args.get("interfloss_every"), 0))
+ start = int(fnum(args.get("interfloss_start"), 0))
+ stop = int(fnum(args.get("interfloss_stop"), -1))
+ schedule = str(args.get("interfloss_at", ""))
+ if every > 0:
+ schedule = f"{schedule}; every={every}, start={start}, stop={stop}"
+ summary = {
+ "name": run_name(path),
+ "path": str(path),
+ "family": "rrm_step7",
+ "model": model,
+ "selector": selector_from_args(args),
+ "floss_mode": args.get("floss_mode", "none"),
+ "schedule": schedule,
+ "train_steps": train_steps,
+ "task_batch_size": task_batch,
+ "floss_batch_size": floss_batch,
+ "effective_official_gbs768_steps": train_steps * task_batch / 768 if task_batch else "",
+ "floss_steps": args.get("floss_steps"),
+ "k_lyap": args.get("k_lyap"),
+ "kl_beta": args.get("kl_beta", 0),
+ "initial_acc": fnum(data.get("initial_acc"), fnum(first.get("acc"))),
+ "final_acc": final_acc,
+ "best_acc": best,
+ "delta_final": final_acc - fnum(data.get("initial_acc"), fnum(first.get("acc"))),
+ "n_evals": len(evals),
+ "n_floss_episodes": len(data.get("floss_episodes", [])),
+ "status": "complete" if "final_acc" in data else "running_or_incomplete",
+ }
+ rows = []
+ for e in evals:
+ rows.append(
+ {
+ "name": summary["name"],
+ "family": summary["family"],
+ "model": model,
+ "selector": summary["selector"],
+ "x": e.get("train_step"),
+ "acc": e.get("acc"),
+ "loss": "",
+ "kind": e.get("kind", "eval"),
+ }
+ )
+ return summary, rows
+
+
+def summarize_step6(path: Path, data: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
+ args = data.get("args", {})
+ phase1 = data.get("phase1_evals", [])
+ phase2 = data.get("phase2_evals", [])
+ evals = []
+ evals.extend({"kind": "phase1", **e} for e in phase1)
+ evals.extend({"kind": "phase2", **e} for e in phase2)
+ final_acc = fnum(data.get("final_acc"), fnum(phase2[-1].get("acc") if phase2 else None))
+ best = max((fnum(e.get("acc")) for e in evals), default=float("nan"))
+ pre_steps = int(fnum(args.get("prefloss_steps"), 0))
+ selector = "baseline_no_prefloss" if pre_steps <= 0 else f"prefloss_{args.get('floss_mode')}"
+ summary = {
+ "name": run_name(path),
+ "path": str(path),
+ "family": "hrm_step6",
+ "model": "hrm",
+ "selector": selector,
+ "floss_mode": args.get("floss_mode", "none"),
+ "schedule": f"pre={pre_steps}",
+ "train_steps": args.get("train_steps"),
+ "floss_steps": args.get("prefloss_steps"),
+ "k_lyap": args.get("k_lyap"),
+ "kl_beta": 0,
+ "initial_acc": fnum(data.get("initial_acc")),
+ "final_acc": final_acc,
+ "best_acc": best,
+ "delta_final": final_acc - fnum(data.get("initial_acc")),
+ "n_evals": len(evals),
+ "n_floss_episodes": 1 if pre_steps > 0 else 0,
+ "status": "complete" if "final_acc" in data else "running_or_incomplete",
+ }
+ rows = []
+ for e in evals:
+ rows.append(
+ {
+ "name": summary["name"],
+ "family": summary["family"],
+ "model": "hrm",
+ "selector": summary["selector"],
+ "x": e.get("step"),
+ "acc": e.get("acc"),
+ "loss": "",
+ "kind": e.get("kind", "eval"),
+ }
+ )
+ return summary, rows
+
+
+def collect() -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+ summaries: list[dict[str, Any]] = []
+ eval_rows: list[dict[str, Any]] = []
+ for path in iter_json_paths():
+ data = load_json(path)
+ if data is None:
+ continue
+ family = classify(path, data)
+ if family == "toy_rnn":
+ summary, rows = summarize_toy(path, data)
+ elif family.startswith("step7"):
+ summary, rows = summarize_step7(path, data)
+ elif family == "step6_prefloss_hrm":
+ summary, rows = summarize_step6(path, data)
+ else:
+ continue
+ summaries.append(summary)
+ eval_rows.extend(rows)
+ return summaries, eval_rows
+
+
+def write_csv(path: Path, rows: list[dict[str, Any]]) -> None:
+ if not rows:
+ path.write_text("")
+ return
+ keys = list(rows[0].keys())
+ for row in rows[1:]:
+ for key in row:
+ if key not in keys:
+ keys.append(key)
+ with path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=keys)
+ writer.writeheader()
+ writer.writerows(rows)
+
+
+def plot_family(eval_rows: list[dict[str, Any]], family: str, out: Path, title: str) -> None:
+ rows = [r for r in eval_rows if r["family"] == family and r.get("x") not in ("", None)]
+ if not rows:
+ return
+ fig, ax = plt.subplots(figsize=(8.5, 4.6))
+ names = sorted({r["name"] for r in rows})
+ for name in names:
+ nr = [r for r in rows if r["name"] == name]
+ nr.sort(key=lambda r: fnum(r["x"]))
+ xs = [fnum(r["x"]) for r in nr]
+ ys = [fnum(r["acc"]) for r in nr]
+ label = name.split("/")[-1]
+ ax.plot(xs, ys, marker="o", linewidth=1.8, label=label)
+ ax.set_title(title)
+ ax.set_xlabel("epoch / train step")
+ ax.set_ylabel("eval exact accuracy")
+ ax.grid(alpha=0.22)
+ ax.legend(frameon=False, fontsize=7)
+ fig.tight_layout()
+ fig.savefig(out, dpi=180)
+ plt.close(fig)
+
+
+def main() -> None:
+ OUT.mkdir(parents=True, exist_ok=True)
+ summaries, eval_rows = collect()
+ summaries.sort(key=lambda r: (str(r["family"]), str(r["model"]), str(r["name"])))
+ write_csv(OUT / "flossing_runs_summary.csv", summaries)
+ write_csv(OUT / "flossing_eval_curves.csv", eval_rows)
+ plot_family(eval_rows, "toy_rnn", OUT / "toy_rnn_eval_curves.png", "Toy RNN Engelken-style flossing")
+ plot_family(eval_rows, "rrm_step7", OUT / "rrm_step7_eval_curves.png", "HRM/TRM interfloss analogues")
+ plot_family(eval_rows, "hrm_step6", OUT / "hrm_step6_eval_curves.png", "HRM prefloss experiments")
+ print(f"wrote {OUT / 'flossing_runs_summary.csv'}")
+ print(f"wrote {OUT / 'flossing_eval_curves.csv'}")
+ print(f"runs: {len(summaries)} eval points: {len(eval_rows)}")
+
+
+if __name__ == "__main__":
+ main()