summaryrefslogtreecommitdiff
path: root/baselines/bm25_top1.py
blob: 9cc92ac05c90dd2b7cd24f0dc44521968be260d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""BM25-Top1 baseline: retrieve the most relevant support item and use it as in-context example."""

from rank_bm25 import BM25Okapi
from data.templates import build_prompt_with_examples


def bm25_select_top1(query_input: str, support_items: list) -> list:
    """Pick the single best-matching support item via BM25 ranking.

    Args:
        query_input: The query text
        support_items: Available support items

    Returns:
        List containing the single most relevant support item
    """
    # Nothing to rank when there are zero or one candidates.
    if len(support_items) <= 1:
        return support_items

    # Tokenize each candidate document (support input concatenated with
    # its output) the same way the query is tokenized: lowercase + split.
    tokenized_docs = [
        (item['support_input'] + " " + item['support_output']).lower().split()
        for item in support_items
    ]

    ranker = BM25Okapi(tokenized_docs)
    relevance = ranker.get_scores(query_input.lower().split())

    # First index achieving the maximum score (same tie-break as argmax).
    top = max(range(len(support_items)), key=lambda i: relevance[i])
    return [support_items[top]]


def generate_bm25_top1(wrapper, query_input: str, support_items: list,
                        task: str, max_new_tokens: int = 512,
                        temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Generate with BM25-Top1 baseline.

    Retrieves the single most BM25-relevant support item, builds a
    one-shot prompt with it, and delegates generation to the wrapper.
    """
    top1 = bm25_select_top1(query_input, support_items)
    prompt = build_prompt_with_examples(query_input, top1, task)
    return wrapper.generate_base(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )