From 896df7f11b441a9b8dfa50820024a82884da58d0 Mon Sep 17 00:00:00 2001 From: BLUESKY477 Date: Fri, 22 May 2026 19:23:44 -0500 Subject: Add files via upload --- resulets/scripts/summarize_dense_baselines.py | 63 +++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 resulets/scripts/summarize_dense_baselines.py (limited to 'resulets/scripts/summarize_dense_baselines.py') diff --git a/resulets/scripts/summarize_dense_baselines.py b/resulets/scripts/summarize_dense_baselines.py new file mode 100644 index 0000000..081667e --- /dev/null +++ b/resulets/scripts/summarize_dense_baselines.py @@ -0,0 +1,63 @@ +"""Summarize dense retrieval baseline runs into a compact CSV table.""" + +import argparse +import csv +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from baselines.dense_retrieval import DENSE_RETRIEVER_CONFIGS + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", default="outputs/dense_retrieval_baselines") + parser.add_argument("--csv_name", default="dense_summary.csv") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + rows = [] + + for summary_path in sorted(output_dir.glob("*_K*/summary.json")): + with summary_path.open() as f: + summary = json.load(f) + + task = summary.get("task") + setting = summary.get("setting") + K = summary.get("K") + aggregate = summary.get("aggregate", {}) + + for method, metrics in aggregate.items(): + config = DENSE_RETRIEVER_CONFIGS.get(method) + rows.append({ + "task": task, + "setting": setting, + "K": K, + "method": method, + "model": config.model_name if config else "", + "retrieval_text": config.text_mode if config else "", + "year": config.citation_year if config else "", + "rougeL": metrics.get("rougeL"), + "meteor": metrics.get("meteor"), + "sfd_nolen": metrics.get("sfd_nolen"), + "avg_len": metrics.get("avg_len"), + }) + + csv_path = output_dir / args.csv_name + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "task", "setting", "K", "method", "model", "retrieval_text", "year", + "rougeL", "meteor", "sfd_nolen", "avg_len", + ] + with csv_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + print(f"Wrote {csv_path} ({len(rows)} rows)") + + +if __name__ == "__main__": + main() -- cgit v1.2.3