summaryrefslogtreecommitdiff
path: root/adapt/cache_hidden.py
diff options
context:
space:
mode:
Diffstat (limited to 'adapt/cache_hidden.py')
-rw-r--r--adapt/cache_hidden.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/adapt/cache_hidden.py b/adapt/cache_hidden.py
new file mode 100644
index 0000000..421c3b7
--- /dev/null
+++ b/adapt/cache_hidden.py
@@ -0,0 +1,33 @@
+"""Cache hidden states from support set using frozen model."""
+
+import torch
+from data.templates import build_support_prompt
+
+
+def cache_support_hidden_states(
+ wrapper,
+ support_items: list,
+ task: str,
+) -> list:
+ """Cache hidden states from support set items.
+
+ Args:
+ wrapper: QwenWrapper instance
+ support_items: List of dicts with 'support_input' and 'support_output'
+ task: 'review' or 'topic'
+
+ Returns:
+ List of (h_states, label_ids) tuples
+ """
+ cached = []
+
+ for item in support_items:
+ input_text = build_support_prompt(item['support_input'], task)
+ target_text = " " + item['support_output'] # Space prefix for clean tokenization
+
+ h_states, label_ids = wrapper.get_hidden_states_teacher_forced(input_text, target_text)
+
+ if h_states is not None and h_states.shape[0] > 0:
+ cached.append((h_states.detach().cpu(), label_ids.detach().cpu()))
+
+ return cached