summaryrefslogtreecommitdiff
path: root/resulets/scripts/summarize_dense_baselines.py
blob: 081667e4efce5d9b188514adf783f8abce42ef87 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Summarize dense retrieval baseline runs into a compact CSV table."""

import argparse
import csv
import json
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from baselines.dense_retrieval import DENSE_RETRIEVER_CONFIGS


# Column order of the summary CSV; these are also the keys of each row dict.
_FIELDNAMES = [
    "task", "setting", "K", "method", "model", "retrieval_text", "year",
    "rougeL", "meteor", "sfd_nolen", "avg_len",
]

# Per-method metrics copied from each summary's "aggregate" entry
# (missing metrics become empty cells).
_METRIC_KEYS = ("rougeL", "meteor", "sfd_nolen", "avg_len")


def _collect_rows(output_dir):
    """Build one row per (run, method) from every ``*_K*/summary.json``."""
    rows = []
    # sorted() keeps the row order deterministic across filesystems.
    for summary_path in sorted(output_dir.glob("*_K*/summary.json")):
        # JSON text is UTF-8; don't depend on the locale's default encoding.
        with summary_path.open(encoding="utf-8") as f:
            summary = json.load(f)

        task = summary.get("task")
        setting = summary.get("setting")
        K = summary.get("K")
        # `or {}` also covers an explicit `"aggregate": null`, which
        # `.get("aggregate", {})` would return as None and crash on.
        aggregate = summary.get("aggregate") or {}

        for method, metrics in aggregate.items():
            # Methods without a registered config still get a row, with
            # blank model / retrieval_text / year columns.
            config = DENSE_RETRIEVER_CONFIGS.get(method)
            row = {
                "task": task,
                "setting": setting,
                "K": K,
                "method": method,
                "model": config.model_name if config else "",
                "retrieval_text": config.text_mode if config else "",
                "year": config.citation_year if config else "",
            }
            row.update({key: metrics.get(key) for key in _METRIC_KEYS})
            rows.append(row)
    return rows


def _write_csv(csv_path, rows):
    """Write *rows* (dicts keyed by ``_FIELDNAMES``) to *csv_path* with a header."""
    # --csv_name may contain subdirectories, so make sure they exist.
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" is what the csv module requires; explicit UTF-8 makes the
    # output byte-identical across platforms.
    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_FIELDNAMES)
        writer.writeheader()
        writer.writerows(rows)


def main(argv=None):
    """Summarize every baseline run under --output_dir into one CSV table.

    Args:
        argv: Command-line arguments, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so running
            the script from the shell behaves as before.
    """
    parser = argparse.ArgumentParser(
        description="Summarize dense retrieval baseline runs into a compact CSV table.",
    )
    parser.add_argument(
        "--output_dir",
        default="outputs/dense_retrieval_baselines",
        help="Directory searched for '*_K*/summary.json' run summaries; "
             "the CSV is written here as well.",
    )
    parser.add_argument(
        "--csv_name",
        default="dense_summary.csv",
        help="Name of the CSV file, relative to --output_dir.",
    )
    args = parser.parse_args(argv)

    output_dir = Path(args.output_dir)
    rows = _collect_rows(output_dir)

    csv_path = output_dir / args.csv_name
    _write_csv(csv_path, rows)

    print(f"Wrote {csv_path} ({len(rows)} rows)")


if __name__ == "__main__":
    main()