summaryrefslogtreecommitdiff
path: root/resulets/scripts/summarize_dense_baselines.py
diff options
context:
space:
mode:
Diffstat (limited to 'resulets/scripts/summarize_dense_baselines.py')
-rw-r--r--resulets/scripts/summarize_dense_baselines.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/resulets/scripts/summarize_dense_baselines.py b/resulets/scripts/summarize_dense_baselines.py
new file mode 100644
index 0000000..081667e
--- /dev/null
+++ b/resulets/scripts/summarize_dense_baselines.py
@@ -0,0 +1,63 @@
+"""Summarize dense retrieval baseline runs into a compact CSV table."""
+
+import argparse
+import csv
+import json
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
+
+from baselines.dense_retrieval import DENSE_RETRIEVER_CONFIGS
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--output_dir", default="outputs/dense_retrieval_baselines")
+ parser.add_argument("--csv_name", default="dense_summary.csv")
+ args = parser.parse_args()
+
+ output_dir = Path(args.output_dir)
+ rows = []
+
+ for summary_path in sorted(output_dir.glob("*_K*/summary.json")):
+ with summary_path.open() as f:
+ summary = json.load(f)
+
+ task = summary.get("task")
+ setting = summary.get("setting")
+ K = summary.get("K")
+ aggregate = summary.get("aggregate", {})
+
+ for method, metrics in aggregate.items():
+ config = DENSE_RETRIEVER_CONFIGS.get(method)
+ rows.append({
+ "task": task,
+ "setting": setting,
+ "K": K,
+ "method": method,
+ "model": config.model_name if config else "",
+ "retrieval_text": config.text_mode if config else "",
+ "year": config.citation_year if config else "",
+ "rougeL": metrics.get("rougeL"),
+ "meteor": metrics.get("meteor"),
+ "sfd_nolen": metrics.get("sfd_nolen"),
+ "avg_len": metrics.get("avg_len"),
+ })
+
+ csv_path = output_dir / args.csv_name
+ csv_path.parent.mkdir(parents=True, exist_ok=True)
+ fieldnames = [
+ "task", "setting", "K", "method", "model", "retrieval_text", "year",
+ "rougeL", "meteor", "sfd_nolen", "avg_len",
+ ]
+ with csv_path.open("w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(rows)
+
+ print(f"Wrote {csv_path} ({len(rows)} rows)")
+
+
+if __name__ == "__main__":
+ main()