summaryrefslogtreecommitdiff
path: root/baselines/prompt_all_k.py
diff options
context:
space:
mode:
Diffstat (limited to 'baselines/prompt_all_k.py')
-rw-r--r--baselines/prompt_all_k.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/baselines/prompt_all_k.py b/baselines/prompt_all_k.py
new file mode 100644
index 0000000..5b132d8
--- /dev/null
+++ b/baselines/prompt_all_k.py
@@ -0,0 +1,12 @@
+"""Prompt-All-K baseline: put all K support items into the prompt as demonstrations."""
+
+from data.templates import build_prompt_with_examples
+
+
+def generate_prompt_all_k(wrapper, query_input: str, support_items: list,
+ task: str, max_new_tokens: int = 512,
+ temperature: float = 0.7, top_p: float = 0.9) -> str:
+ """Generate with all K support items in the prompt."""
+ prompt = build_prompt_with_examples(query_input, support_items, task)
+ return wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
+ temperature=temperature, top_p=top_p)