diff options
| author | BLUESKY477 <abcd15803148000@163.com> | 2026-05-22 19:23:44 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-05-22 19:23:44 -0500 |
| commit | 896df7f11b441a9b8dfa50820024a82884da58d0 (patch) | |
| tree | 0182ae4a7a0bb16ee6a764393838a580e1ba1c31 /resulets/scripts/summarize_dense_baselines.py | |
| parent | 6f48c4fae3243e280b27a977c6a8cb731becf446 (diff) | |
Diffstat (limited to 'resulets/scripts/summarize_dense_baselines.py')
| -rw-r--r-- | resulets/scripts/summarize_dense_baselines.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/resulets/scripts/summarize_dense_baselines.py b/resulets/scripts/summarize_dense_baselines.py new file mode 100644 index 0000000..081667e --- /dev/null +++ b/resulets/scripts/summarize_dense_baselines.py @@ -0,0 +1,63 @@ +"""Summarize dense retrieval baseline runs into a compact CSV table.""" + +import argparse +import csv +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from baselines.dense_retrieval import DENSE_RETRIEVER_CONFIGS + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--output_dir", default="outputs/dense_retrieval_baselines") + parser.add_argument("--csv_name", default="dense_summary.csv") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + rows = [] + + for summary_path in sorted(output_dir.glob("*_K*/summary.json")): + with summary_path.open() as f: + summary = json.load(f) + + task = summary.get("task") + setting = summary.get("setting") + K = summary.get("K") + aggregate = summary.get("aggregate", {}) + + for method, metrics in aggregate.items(): + config = DENSE_RETRIEVER_CONFIGS.get(method) + rows.append({ + "task": task, + "setting": setting, + "K": K, + "method": method, + "model": config.model_name if config else "", + "retrieval_text": config.text_mode if config else "", + "year": config.citation_year if config else "", + "rougeL": metrics.get("rougeL"), + "meteor": metrics.get("meteor"), + "sfd_nolen": metrics.get("sfd_nolen"), + "avg_len": metrics.get("avg_len"), + }) + + csv_path = output_dir / args.csv_name + csv_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = [ + "task", "setting", "K", "method", "model", "retrieval_text", "year", + "rougeL", "meteor", "sfd_nolen", "avg_len", + ] + with csv_path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + print(f"Wrote {csv_path} ({len(rows)} rows)") + + +if __name__ == "__main__": + main() |
