summaryrefslogtreecommitdiff
path: root/baselines/prompt_all_k.py
blob: 5b132d811fef036bd24151d0921391a3ab1edd9b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
"""Prompt-All-K baseline: put all K support items into the prompt as demonstrations."""

from data.templates import build_prompt_with_examples


def generate_prompt_all_k(wrapper, query_input: str, support_items: list,
                          task: str, max_new_tokens: int = 512,
                          temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Run the Prompt-All-K baseline for a single query.

    The full ``support_items`` list is handed to
    ``build_prompt_with_examples`` together with the query and task; this
    function does no selection or truncation of its own. The resulting prompt
    is then passed to ``wrapper.generate_base``.

    Args:
        wrapper: Object exposing ``generate_base(prompt, max_new_tokens=...,
            temperature=..., top_p=...)``.
        query_input: Query text, forwarded to the prompt builder.
        support_items: Support items, forwarded to the prompt builder as-is.
        task: Task identifier, forwarded to the prompt builder.
        max_new_tokens: Forwarded unchanged to ``wrapper.generate_base``.
        temperature: Forwarded unchanged to ``wrapper.generate_base``.
        top_p: Forwarded unchanged to ``wrapper.generate_base``.

    Returns:
        Whatever ``wrapper.generate_base`` returns (annotated as ``str``).
    """
    # Collect the sampling settings once so the final call stays compact.
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    few_shot_prompt = build_prompt_with_examples(query_input, support_items, task)
    return wrapper.generate_base(few_shot_prompt, **sampling_kwargs)