diff options
Diffstat (limited to 'adapt/cache_hidden.py')
| -rw-r--r-- | adapt/cache_hidden.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/adapt/cache_hidden.py b/adapt/cache_hidden.py new file mode 100644 index 0000000..421c3b7 --- /dev/null +++ b/adapt/cache_hidden.py @@ -0,0 +1,33 @@ +"""Cache hidden states from support set using frozen model.""" + +import torch +from data.templates import build_support_prompt + + +def cache_support_hidden_states( + wrapper, + support_items: list, + task: str, +) -> list: + """Cache hidden states from support set items. + + Args: + wrapper: QwenWrapper instance + support_items: List of dicts with 'support_input' and 'support_output' + task: 'review' or 'topic' + + Returns: + List of (h_states, label_ids) tuples + """ + cached = [] + + for item in support_items: + input_text = build_support_prompt(item['support_input'], task) + target_text = " " + item['support_output'] # Space prefix for clean tokenization + + h_states, label_ids = wrapper.get_hidden_states_teacher_forced(input_text, target_text) + + if h_states is not None and h_states.shape[0] > 0: + cached.append((h_states.detach().cpu(), label_ids.detach().cpu())) + + return cached |
