blob: 421c3b756a9934340e91e644fcd36d765bc0d828 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
"""Cache hidden states from support set using frozen model."""
import torch
from data.templates import build_support_prompt
def cache_support_hidden_states(
    wrapper,
    support_items: list,
    task: str,
) -> list:
    """Collect teacher-forced hidden states for every support-set example.

    Args:
        wrapper: QwenWrapper instance used to run the frozen model.
        support_items: List of dicts, each with 'support_input' and
            'support_output' keys.
        task: Task name, 'review' or 'topic'.

    Returns:
        List of (h_states, label_ids) tuples, detached and moved to CPU.
        Examples whose hidden states come back empty or None are skipped.
    """
    results = []
    for example in support_items:
        prompt = build_support_prompt(example['support_input'], task)
        # Leading space keeps the target tokenizing cleanly after the prompt.
        target = " " + example['support_output']
        hidden, labels = wrapper.get_hidden_states_teacher_forced(prompt, target)
        # Skip examples that produced no usable hidden states.
        if hidden is None or hidden.shape[0] == 0:
            continue
        results.append((hidden.detach().cpu(), labels.detach().cpu()))
    return results
|