summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_dynamics_experiments.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_dynamics_experiments.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_dynamics_experiments.py')
-rw-r--r--research/flossing/analyze_dynamics_experiments.py289
1 files changed, 289 insertions, 0 deletions
diff --git a/research/flossing/analyze_dynamics_experiments.py b/research/flossing/analyze_dynamics_experiments.py
new file mode 100644
index 0000000..22491d0
--- /dev/null
+++ b/research/flossing/analyze_dynamics_experiments.py
@@ -0,0 +1,289 @@
+"""Summarize dynamics-control experiments into a markdown report."""
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+
+
+ROOT = Path(__file__).resolve().parent
+
+
+@dataclass(frozen=True)
+class RunSpec:
+ name: str
+ model: str
+ family: str
+ json_name: str
+ log_name: str | None = None
+
+
+RUNS = [
+ RunSpec("HRM baseline 10k", "HRM", "baseline", "step3_L_baseline_26040_fast_10k.json"),
+ RunSpec("HRM mixed volume-CF", "HRM", "mixed_loss", "step3_M_volume_cf_26040_lstar_neg015_k8_a10_10k.json"),
+ RunSpec("HRM Engelken interfloss", "HRM", "interfloss", "step7_A_hrm_engelken_interfloss_26040_k8_10k.json"),
+ RunSpec("HRM Engelken+KL interfloss", "HRM", "interfloss_kl", "step7_C_hrm_engelken_interfloss_kl10_26040_k8_10k.json"),
+ RunSpec("HRM conservative Engelken+KL", "HRM", "interfloss_kl_conservative", "step7_I_hrm_engelken_interfloss_kl100_short_26040_k8_10k.json"),
+ RunSpec("HRM late Engelken+KL", "HRM", "late_interfloss_kl", "step7_E_hrm_late_engelken_interfloss_kl10_start12_26040_k8_10k.json"),
+ RunSpec("HRM volume-envelope+KL", "HRM", "volume_interfloss_kl", "step7_G_hrm_volume_envelope_interfloss_kl10_lstar_neg015_26040_k8_10k.json"),
+ RunSpec("HRM basin consistency", "HRM", "basin", "step8_A_hrm_basin_consistency_beta1_noise002_after8_26040_10k.json"),
+ RunSpec("HRM single perturbed CE", "HRM", "trajectory_augment_single", "step9_A_hrm_single_perturb_sigma1e-3_26040_10k.json"),
+ RunSpec("HRM clean+multi perturbed CE", "HRM", "trajectory_augment_multi", "step9_B_hrm_multi4_perturb_sigma1e-3_26040_10k.json"),
+ RunSpec("HRM fixed-unroll baseline 50k", "HRM", "trajectory_fixed_baseline", "step9_E_hrm_baseline_parallel_fixed_26040_50k.json"),
+ RunSpec("HRM multi4 loguniform 50k", "HRM", "trajectory_augment_loguniform", "step9_F_hrm_multi4_loguniform_ramp_26040_50k.json"),
+ RunSpec("TRM baseline 10k", "TRM", "baseline", "step5_L_trm_baseline_26041_batch4_fast_10k.json"),
+ RunSpec("TRM mixed volume-CF", "TRM", "mixed_loss", "step5_M_trm_volume_cf_26041_lstar002_batch4_k4_a10_10k.json"),
+ RunSpec("TRM Engelken interfloss", "TRM", "interfloss", "step7_B_trm_engelken_interfloss_26041_k4_batch4_10k.json"),
+ RunSpec("TRM Engelken+KL interfloss", "TRM", "interfloss_kl", "step7_D_trm_engelken_interfloss_kl10_26041_k4_batch4_10k.json"),
+ RunSpec("TRM late Engelken+KL", "TRM", "late_interfloss_kl", "step7_F_trm_late_engelken_interfloss_kl10_start12_26041_k4_batch4_10k.json"),
+ RunSpec("TRM volume-envelope+KL", "TRM", "volume_interfloss_kl", "step7_H_trm_volume_envelope_interfloss_kl10_lstar002_26041_k4_batch4_10k.json"),
+ RunSpec("TRM basin consistency", "TRM", "basin", "step8_B_trm_basin_consistency_beta1_noise002_after8_26041_batch4_10k.json"),
+ RunSpec("TRM single perturbed CE", "TRM", "trajectory_augment_single", "step9_C_trm_single_perturb_sigma1e-3_26041_batch4_10k.json"),
+ RunSpec("TRM clean+multi perturbed CE", "TRM", "trajectory_augment_multi", "step9_D_trm_multi4_perturb_sigma1e-3_26041_batch4_10k.json"),
+ RunSpec("TRM fixed-unroll baseline 50k", "TRM", "trajectory_fixed_baseline", "step9_G_trm_baseline_parallel_fixed_26041_batch4_50k.json"),
+ RunSpec("TRM multi4 loguniform 50k", "TRM", "trajectory_augment_loguniform", "step9_H_trm_multi4_loguniform_ramp_26041_batch4_50k.json"),
+]
+
+
+def load_json(path: Path):
+ if not path.exists():
+ return None
+ try:
+ return json.loads(path.read_text())
+ except Exception as exc: # noqa: BLE001
+ return {"_bad_json": str(exc)}
+
+
+def step_of(ev: dict):
+ return ev.get("step", ev.get("train_step"))
+
+
+def acc_of(ev: dict):
+ acc = ev.get("acc")
+ return None if acc is None else float(acc)
+
+
+def evals_of(data: dict):
+ return [ev for ev in data.get("evals", []) if acc_of(ev) is not None]
+
+
+def best_eval(evals: list[dict]):
+ if not evals:
+ return None
+ return max(evals, key=lambda ev: acc_of(ev))
+
+
+def last_eval(evals: list[dict]):
+ if not evals:
+ return None
+ return evals[-1]
+
+
+def fmt(x, digits=4):
+ if x is None:
+ return "NA"
+ return f"{float(x):.{digits}f}"
+
+
+def completion_status(data):
+ if data is None:
+ return "missing"
+ if "_bad_json" in data:
+ return "bad_json"
+ if data.get("final_acc") is not None:
+ return "complete"
+ evs = evals_of(data)
+ if evs:
+ return "partial"
+ return "started"
+
+
+def run_summary(spec: RunSpec):
+ path = ROOT / spec.json_name
+ data = load_json(path)
+ status = completion_status(data)
+ if data is None or "_bad_json" in data:
+ return {
+ "spec": spec,
+ "status": status,
+ "initial": None,
+ "final": None,
+ "last": None,
+ "best": None,
+ "n_evals": 0,
+ "data": data,
+ }
+ evs = evals_of(data)
+ last = last_eval(evs)
+ best = best_eval(evs)
+ final = data.get("final_acc")
+ if final is None and last is not None:
+ final = acc_of(last)
+ return {
+ "spec": spec,
+ "status": status,
+ "initial": data.get("initial_acc"),
+ "final": final,
+ "last": last,
+ "best": best,
+ "n_evals": len(evs),
+ "data": data,
+ }
+
+
+def collect_process_snapshot():
+ # Avoid importing psutil. This is a best-effort snapshot from /proc.
+ patterns = [
+ "step3_M_volume_cf_26040_lstar_neg015",
+ "step7_C_hrm_engelken_interfloss_kl10",
+ "step7_D_trm_engelken_interfloss_kl10",
+ "launch_dynamics_variants_queue",
+ "step7_E_hrm_late",
+ "step7_F_trm_late",
+ "step7_G_hrm_volume",
+ "step7_H_trm_volume",
+ "step8_A_hrm_basin",
+ "step8_B_trm_basin",
+ "step9_A_hrm_single",
+ "step9_B_hrm_multi4",
+ "step9_C_trm_single",
+ "step9_D_trm_multi4",
+ "step9_E_hrm_baseline",
+ "step9_F_hrm_multi4",
+ "step9_G_trm_baseline",
+ "step9_H_trm_multi4",
+ "launch_trajectory_perturb_queue",
+ "launch_trajectory_sampling_long",
+ ]
+ rows = []
+ for proc in Path("/proc").iterdir():
+ if not proc.name.isdigit():
+ continue
+ try:
+ cmd = (proc / "cmdline").read_bytes().replace(b"\x00", b" ").decode("utf-8", "ignore")
+ except Exception: # noqa: BLE001
+ continue
+ if any(p in cmd for p in patterns):
+ rows.append((int(proc.name), cmd.strip()))
+ return sorted(rows)
+
+
+def log_tail_metrics(log_name: str):
+ path = ROOT / log_name
+ if not path.exists():
+ return {}
+ text = path.read_text(errors="ignore")
+ evals = []
+ for step, acc in re.findall(r"EVAL @ step\s+(\d+): exact_acc=([0-9.]+)", text):
+ evals.append((int(step), float(acc)))
+ progress = re.findall(r"\[\s*(\d+)/10000\]", text)
+ return {
+ "log_step": int(progress[-1]) if progress else None,
+ "log_last_eval": evals[-1] if evals else None,
+ "log_best_eval": max(evals, key=lambda x: x[1]) if evals else None,
+ }
+
+
+def floss_episode_summary(data: dict):
+ lines = []
+ episodes = data.get("floss_episodes", []) if data else []
+ for ep in episodes:
+ steps = ep.get("steps", [])
+ if not steps:
+ continue
+ first = steps[0]
+ last = steps[-1]
+ max_kl = max((s.get("kl_loss", 0.0) for s in steps), default=0.0)
+ mean_kl = sum((s.get("kl_loss", 0.0) for s in steps)) / max(len(steps), 1)
+ lines.append(
+ f"episode {ep.get('episode')} @ train_step {ep.get('train_step')}: "
+ f"floss {fmt(first.get('floss_loss', first.get('loss')), 6)} -> "
+ f"{fmt(last.get('floss_loss', last.get('loss')), 6)}, "
+ f"lyap1 {fmt(first.get('lyap1_mean'))} -> {fmt(last.get('lyap1_mean'))}, "
+ f"volume {fmt(first.get('volume_mean'))} -> {fmt(last.get('volume_mean'))}, "
+ f"KL mean/max {fmt(mean_kl, 6)}/{fmt(max_kl, 6)}"
+ )
+ return lines
+
+
+def markdown_report():
+ rows = [run_summary(spec) for spec in RUNS]
+ baseline_final = {
+ row["spec"].model: row["final"]
+ for row in rows
+ if row["spec"].family == "baseline" and row["final"] is not None
+ }
+
+ out = []
+ out.append("# Dynamics Control Experiment Report")
+ out.append("")
+ out.append(f"Generated: {datetime.now().isoformat(timespec='seconds')}")
+ out.append("")
+ out.append("## Summary Table")
+ out.append("")
+ out.append("| Model | Run | Status | Init | Final/Last | Delta | Best | Best Step | Vs Baseline | Evals |")
+ out.append("|---|---|---:|---:|---:|---:|---:|---:|---:|---:|")
+ for row in rows:
+ spec = row["spec"]
+ init = row["initial"]
+ final = row["final"]
+ best = row["best"]
+ best_acc = acc_of(best) if best else None
+ best_step = step_of(best) if best else None
+ delta = None if init is None or final is None else final - init
+ base = baseline_final.get(spec.model)
+ vs_base = None if base is None or final is None or spec.family == "baseline" else final - base
+ out.append(
+ f"| {spec.model} | {spec.name} | {row['status']} | {fmt(init)} | {fmt(final)} | "
+ f"{fmt(delta)} | {fmt(best_acc)} | {best_step if best_step is not None else 'NA'} | "
+ f"{fmt(vs_base)} | {row['n_evals']} |"
+ )
+
+ out.append("")
+ out.append("## Floss Episode Diagnostics")
+ out.append("")
+ any_floss = False
+ for row in rows:
+ lines = floss_episode_summary(row.get("data") or {})
+ if not lines:
+ continue
+ any_floss = True
+ out.append(f"### {row['spec'].name}")
+ for line in lines:
+ out.append(f"- {line}")
+ out.append("")
+ if not any_floss:
+ out.append("No completed floss episode diagnostics found yet.")
+ out.append("")
+
+ out.append("## Incomplete Runs / Process Snapshot")
+ out.append("")
+ active = collect_process_snapshot()
+ if active:
+ for pid, cmd in active:
+ out.append(f"- PID {pid}: `{cmd[:220]}`")
+ else:
+ out.append("- No monitored experiment processes are active.")
+ out.append("")
+
+ out.append("## Notes")
+ out.append("")
+ out.append("- `Final/Last` is `final_acc` when present, otherwise the latest eval accuracy.")
+ out.append("- `Vs Baseline` compares against the matching HRM/TRM 10k no-floss baseline.")
+ out.append("- A complete report may still show partial rows if an experiment crashed or was interrupted.")
+ out.append("")
+ return "\n".join(out)
+
+
+def main():
+ report = markdown_report()
+ out_path = ROOT / "dynamics_experiment_report.md"
+ out_path.write_text(report)
+ print(report)
+ print(f"\nWrote {out_path}")
+
+
+if __name__ == "__main__":
+ main()