Diffstat (limited to 'baselines')
-rw-r--r--  baselines/__init__.py      0
-rw-r--r--  baselines/bm25_top1.py    41
-rw-r--r--  baselines/prompt_all_k.py 12
3 files changed, 53 insertions(+), 0 deletions(-)
diff --git a/baselines/__init__.py b/baselines/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/baselines/__init__.py
diff --git a/baselines/bm25_top1.py b/baselines/bm25_top1.py
new file mode 100644
index 0000000..9cc92ac
--- /dev/null
+++ b/baselines/bm25_top1.py
@@ -0,0 +1,41 @@
+"""BM25-Top1 baseline: retrieve the most relevant support item and use it as in-context example."""
+
+from rank_bm25 import BM25Okapi
+from data.templates import build_prompt_with_examples
+
+
+def bm25_select_top1(query_input: str, support_items: list) -> list:
+ """Select the most relevant support item using BM25.
+
+ Args:
+ query_input: The query text
+ support_items: Available support items
+
+ Returns:
+ List containing the single most relevant support item
+ """
+ if len(support_items) <= 1:
+ return support_items
+
+ # Build corpus from support inputs + outputs
+ corpus = []
+ for item in support_items:
+ doc = item['support_input'] + " " + item['support_output']
+ corpus.append(doc.lower().split())
+
+ bm25 = BM25Okapi(corpus)
+ query_tokens = query_input.lower().split()
+ scores = bm25.get_scores(query_tokens)
+
+ best_idx = scores.argmax()
+ return [support_items[best_idx]]
+
+
+def generate_bm25_top1(wrapper, query_input: str, support_items: list,
+ task: str, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9) -> str:
+ """Generate with BM25-Top1 baseline."""
+ selected = bm25_select_top1(query_input, support_items)
+ prompt = build_prompt_with_examples(query_input, selected, task)
+ return wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
+ temperature=temperature, top_p=top_p)
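For context, the selection step can be exercised on its own. Below is a minimal standalone sketch, assuming only that rank_bm25 is installed; the support items are hypothetical stand-ins for the task's real support pool:

from rank_bm25 import BM25Okapi

# Hypothetical support set; real items come from the evaluation data.
support_items = [
    {"support_input": "Translate 'bonjour' to English.", "support_output": "hello"},
    {"support_input": "What does BM25 rank documents by?",
     "support_output": "Term frequency and inverse document frequency."},
]

# Same corpus construction as bm25_select_top1: concatenate input and
# output, lowercase, split on whitespace.
corpus = [(it["support_input"] + " " + it["support_output"]).lower().split()
          for it in support_items]
bm25 = BM25Okapi(corpus)

query = "How does BM25 rank documents?"
scores = bm25.get_scores(query.lower().split())
print(support_items[scores.argmax()]["support_input"])
# -> "What does BM25 rank documents by?" (highest lexical overlap with the query)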
diff --git a/baselines/prompt_all_k.py b/baselines/prompt_all_k.py
new file mode 100644
index 0000000..5b132d8
--- /dev/null
+++ b/baselines/prompt_all_k.py
@@ -0,0 +1,12 @@
+"""Prompt-All-K baseline: put all K support items into the prompt as demonstrations."""
+
+from data.templates import build_prompt_with_examples
+
+
+def generate_prompt_all_k(wrapper, query_input: str, support_items: list,
+ task: str, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9) -> str:
+ """Generate with all K support items in the prompt."""
+ prompt = build_prompt_with_examples(query_input, support_items, task)
+ return wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
+ temperature=temperature, top_p=top_p)
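Note that build_prompt_with_examples is defined in data/templates.py and is not part of this diff. A plausible sketch of the interface both baselines depend on, assuming a simple few-shot template (the actual template text may differ):

def build_prompt_with_examples(query_input: str, support_items: list, task: str) -> str:
    """Hypothetical few-shot template: K demonstrations, then the query."""
    parts = [f"Task: {task}"]
    for item in support_items:
        parts.append(f"Input: {item['support_input']}\nOutput: {item['support_output']}")
    parts.append(f"Input: {query_input}\nOutput:")
    return "\n\n".join(parts)

Under any such interface, the two baselines differ only in how many support items reach the prompt: BM25-Top1 passes the single most lexically relevant item, while Prompt-All-K passes all K.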