"""Cache hidden states from support set using frozen model."""
import torch

from data.templates import build_support_prompt


def cache_support_hidden_states(
    wrapper,
    support_items: list,
    task: str,
) -> list:
    """Cache hidden states from support set items.

    Each item is turned into a support prompt for ``task`` via
    ``build_support_prompt``. The prompt is then passed, together with the
    item's target text, to ``wrapper.get_hidden_states_teacher_forced``.
    An item is skipped when its hidden states come back as ``None`` or have
    an empty first dimension.

    Args:
        wrapper: QwenWrapper instance (frozen). Must provide
            ``get_hidden_states_teacher_forced(input_text, target_text)``
            returning a ``(h_states, label_ids)`` pair of tensors.
        support_items: List of dicts with 'support_input' and 'support_output'.
        task: 'review' or 'topic'; forwarded to ``build_support_prompt``.

    Returns:
        List of ``(h_states, label_ids)`` tuples. Both tensors are detached
        and moved to CPU. The list keeps the order of the items in
        ``support_items`` that were not skipped.
    """
    cached = []
    # The model is frozen and every output is detached below, so recording an
    # autograd graph during the forward pass only wastes memory on saved
    # activations. Disable it for the whole caching loop.
    with torch.no_grad():
        for item in support_items:
            input_text = build_support_prompt(item['support_input'], task)
            # Space prefix for clean tokenization of the target continuation.
            target_text = " " + item['support_output']
            h_states, label_ids = wrapper.get_hidden_states_teacher_forced(
                input_text, target_text
            )
            # Skip items that produced no hidden states (None or zero-length).
            if h_states is None or h_states.shape[0] == 0:
                continue
            # Keep the cache on CPU so it does not hold on to device memory.
            cached.append((h_states.detach().cpu(), label_ids.detach().cpu()))
    return cached