blob: 9cc92ac05c90dd2b7cd24f0dc44521968be260d3 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
"""BM25-Top1 baseline: retrieve the most relevant support item and use it as in-context example."""
from rank_bm25 import BM25Okapi
from data.templates import build_prompt_with_examples
def bm25_select_top1(query_input: str, support_items: list) -> list:
    """Pick the single best-matching support item via BM25 ranking.

    Args:
        query_input: The query text.
        support_items: Available support items (dicts carrying
            'support_input' and 'support_output' keys).

    Returns:
        A one-element list holding the highest-scoring support item,
        or the input unchanged when it has at most one element.
    """
    if len(support_items) <= 1:
        return support_items

    # Each BM25 document is a support item's input concatenated with its
    # output, lower-cased and whitespace-tokenized.
    tokenized_docs = [
        (entry['support_input'] + " " + entry['support_output']).lower().split()
        for entry in support_items
    ]

    ranker = BM25Okapi(tokenized_docs)
    doc_scores = ranker.get_scores(query_input.lower().split())

    # Index of the top-scoring document selects the winning support item.
    top = doc_scores.argmax()
    return [support_items[top]]
def generate_bm25_top1(wrapper, query_input: str, support_items: list,
                       task: str, max_new_tokens: int = 512,
                       temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Generate a response with the BM25-Top1 baseline.

    Retrieves the single most relevant support item for the query, builds
    a prompt that uses it as an in-context example, and delegates text
    generation to the model wrapper.
    """
    top1_example = bm25_select_top1(query_input, support_items)
    prompt = build_prompt_with_examples(query_input, top1_example, task)
    return wrapper.generate_base(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
|