diff options
Diffstat (limited to 'baselines/bm25_top1.py')
| -rw-r--r-- | baselines/bm25_top1.py | 41 |
1 file changed, 41 insertions, 0 deletions
"""BM25-Top1 baseline: retrieve the most relevant support item and use it as in-context example."""

from rank_bm25 import BM25Okapi
from data.templates import build_prompt_with_examples


def bm25_select_top1(query_input: str, support_items: list) -> list:
    """Select the most relevant support item using BM25.

    Args:
        query_input: The query text
        support_items: Available support items

    Returns:
        List containing the single most relevant support item
    """
    # With zero or one candidate there is nothing to rank.
    if len(support_items) <= 1:
        return support_items

    # Each BM25 document is the concatenated support input/output,
    # lower-cased and whitespace-tokenized.
    tokenized_docs = [
        (item['support_input'] + " " + item['support_output']).lower().split()
        for item in support_items
    ]

    ranker = BM25Okapi(tokenized_docs)
    relevance = ranker.get_scores(query_input.lower().split())

    # First index with the maximal score (same tie-breaking as argmax).
    winner = max(range(len(support_items)), key=relevance.__getitem__)
    return [support_items[winner]]


def generate_bm25_top1(wrapper, query_input: str, support_items: list,
                       task: str, max_new_tokens: int = 512,
                       temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Generate with BM25-Top1 baseline."""
    # Retrieve the single best in-context example, then delegate generation.
    example = bm25_select_top1(query_input, support_items)
    prompt = build_prompt_with_examples(query_input, example, task)
    return wrapper.generate_base(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
