summaryrefslogtreecommitdiff
path: root/adapt/cache_hidden.py
blob: 421c3b756a9934340e91e644fcd36d765bc0d828 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Cache hidden states from support set using frozen model."""

import torch
from data.templates import build_support_prompt


def cache_support_hidden_states(
    wrapper,
    support_items: list,
    task: str,
) -> list:
    """Cache hidden states from support set items.

    Each support item is rendered into a prompt with ``build_support_prompt``
    and passed to ``wrapper.get_hidden_states_teacher_forced`` together with
    the item's output as the teacher-forced target. The returned tensors are
    detached and moved to CPU before being stored.

    The forward passes run under ``torch.no_grad()``: the model is frozen and
    every cached tensor is detached anyway, so recording an autograd graph
    during the forward would only cost memory.

    Args:
        wrapper: QwenWrapper instance exposing
            ``get_hidden_states_teacher_forced(input_text, target_text)``,
            which returns an ``(h_states, label_ids)`` pair.
        support_items: List of dicts with 'support_input' and 'support_output'
        task: 'review' or 'topic'

    Returns:
        List of (h_states, label_ids) tuples of CPU tensors. Items for which
        the wrapper returned ``None`` hidden states, or hidden states with a
        zero-length first dimension, are skipped, so the list may be shorter
        than ``support_items``.
    """
    cached = []

    with torch.no_grad():
        for item in support_items:
            input_text = build_support_prompt(item['support_input'], task)
            # Space prefix for clean tokenization of the target continuation.
            target_text = " " + item['support_output']

            h_states, label_ids = wrapper.get_hidden_states_teacher_forced(
                input_text, target_text
            )

            # NOTE(review): presumably None / empty states mean the target
            # yielded no usable tokens — confirm in the wrapper. Skip them.
            if h_states is None or h_states.shape[0] == 0:
                continue

            cached.append((h_states.detach().cpu(), label_ids.detach().cpu()))

    return cached