| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
| commit | 8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch) | |
| tree | c92a92184fb2f46f265ab84c1f754c3d5d6597bc /baselines | |
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs,
and experiment results for the User Prior Head personalization method.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'baselines')
| -rw-r--r-- | baselines/__init__.py | 0 |
| -rw-r--r-- | baselines/bm25_top1.py | 41 |
| -rw-r--r-- | baselines/prompt_all_k.py | 12 |
3 files changed, 53 insertions, 0 deletions
diff --git a/baselines/__init__.py b/baselines/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/baselines/__init__.py
diff --git a/baselines/bm25_top1.py b/baselines/bm25_top1.py
new file mode 100644
index 0000000..9cc92ac
--- /dev/null
+++ b/baselines/bm25_top1.py
@@ -0,0 +1,41 @@
+"""BM25-Top1 baseline: retrieve the most relevant support item and use it as in-context example."""
+
+from rank_bm25 import BM25Okapi
+from data.templates import build_prompt_with_examples
+
+
+def bm25_select_top1(query_input: str, support_items: list) -> list:
+    """Select the most relevant support item using BM25.
+
+    Args:
+        query_input: The query text
+        support_items: Available support items
+
+    Returns:
+        List containing the single most relevant support item
+    """
+    if len(support_items) <= 1:
+        return support_items
+
+    # Build corpus from support inputs + outputs
+    corpus = []
+    for item in support_items:
+        doc = item['support_input'] + " " + item['support_output']
+        corpus.append(doc.lower().split())
+
+    bm25 = BM25Okapi(corpus)
+    query_tokens = query_input.lower().split()
+    scores = bm25.get_scores(query_tokens)
+
+    best_idx = scores.argmax()
+    return [support_items[best_idx]]
+
+
+def generate_bm25_top1(wrapper, query_input: str, support_items: list,
+                       task: str, max_new_tokens: int = 512,
+                       temperature: float = 0.7, top_p: float = 0.9) -> str:
+    """Generate with BM25-Top1 baseline."""
+    selected = bm25_select_top1(query_input, support_items)
+    prompt = build_prompt_with_examples(query_input, selected, task)
+    return wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
+                                 temperature=temperature, top_p=top_p)
diff --git a/baselines/prompt_all_k.py b/baselines/prompt_all_k.py
new file mode 100644
index 0000000..5b132d8
--- /dev/null
+++ b/baselines/prompt_all_k.py
@@ -0,0 +1,12 @@
+"""Prompt-All-K baseline: put all K support items into the prompt as demonstrations."""
+
+from data.templates import build_prompt_with_examples
+
+
+def generate_prompt_all_k(wrapper, query_input: str, support_items: list,
+                          task: str, max_new_tokens: int = 512,
+                          temperature: float = 0.7, top_p: float = 0.9) -> str:
+    """Generate with all K support items in the prompt."""
+    prompt = build_prompt_with_examples(query_input, support_items, task)
+    return wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
+                                 temperature=temperature, top_p=top_p)
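For context, a minimal usage sketch of the two baselines added in this commit is below. Only `bm25_select_top1`, `generate_bm25_top1`, `generate_prompt_all_k`, the `generate_base` call, and the `support_input`/`support_output` keys come from the diff; the support items, the `task="review"` value, and `DummyWrapper` (standing in for the project's real model wrapper, which is not part of this diff) are assumptions for illustration.

```python
# Usage sketch for the BM25-Top1 and Prompt-All-K baselines.
# The support items below are made up, and DummyWrapper is a hypothetical
# stand-in for the project's model wrapper (not included in this diff).
from baselines.bm25_top1 import bm25_select_top1, generate_bm25_top1
from baselines.prompt_all_k import generate_prompt_all_k


class DummyWrapper:
    """Stub exposing the generate_base() interface the baselines expect."""

    def generate_base(self, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
        return f"[would decode up to {max_new_tokens} tokens for a {len(prompt)}-char prompt]"


support_items = [
    {"support_input": "Review a sci-fi novel", "support_output": "4/5 - loved the pacing"},
    {"support_input": "Review a cookbook", "support_output": "2/5 - recipes were vague"},
]
query = "Review a fantasy novel"

# Retrieval step on its own: returns the single most lexically similar item.
print(bm25_select_top1(query, support_items))

wrapper = DummyWrapper()
# BM25-Top1: one retrieved demonstration in the prompt.
print(generate_bm25_top1(wrapper, query, support_items, task="review"))
# Prompt-All-K: every support item goes into the prompt.
print(generate_prompt_all_k(wrapper, query, support_items, task="review"))
```

The generate helpers also depend on `data.templates.build_prompt_with_examples` from the repository, so the exact `task` string accepted there is not verifiable from this commit alone.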
