summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor/rule_extractor.py
blob: 42f43ed09f16eba2ebabc028a8ecc4ede2e77cd5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from __future__ import annotations

import json
import re
import os
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.config.registry import choose_dtype, choose_device_map
from personalization.config.settings import LocalModelsConfig
from .base import PreferenceExtractor
from personalization.retrieval.preference_store.schemas import (
    PreferenceList,
    preference_list_json_schema,
    ChatTurn,
)

# System prompt fed to the fine-tuned extractor on every call.
# It was used verbatim during SFT (scripts/split_train_test.py); any change to the
# wording shifts the input distribution the model was trained on, so keep it
# character-for-character identical to the training prompt.
SFT_SYSTEM_PROMPT = (
    "Extract user preferences from the query into JSON format "
    "based on the PreferenceList schema. "
    'If no preferences are found, return {"preferences": []}.'
)

class QwenRuleExtractor(PreferenceExtractor):
    """
    Extractor using a Fine-Tuned (SFT) Qwen model.
    Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction.

    Each query is wrapped in a chat prompt built from ``SFT_SYSTEM_PROMPT`` and the
    model's greedy continuation is parsed as a ``PreferenceList``. Output that cannot
    be parsed/validated degrades to ``{"preferences": []}`` instead of raising.
    """

    # Default generation budget (tokens) per query; leaves room for a JSON list.
    max_new_tokens: int = 512

    def __init__(
        self,
        model_path: str,
        dtype: torch.dtype,
        device_map: str = "auto",
        max_new_tokens: int = 512,
    ) -> None:
        """
        Load the tokenizer and causal-LM weights.

        Args:
            model_path: Path/identifier passed to ``from_pretrained``.
            dtype: Torch dtype for the model weights.
            device_map: Device placement forwarded to ``from_pretrained``.
            max_new_tokens: Maximum number of tokens generated per query.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        self.max_new_tokens = max_new_tokens
        # Batched generation needs a pad token; fall back to EOS if none is defined.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
        """Build an extractor from the ``preference_extractor`` entry of ``cfg``."""
        spec = cfg.preference_extractor
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(spec.local_path, dtype=dtype, device_map=device_map)

    def build_preference_prompt(self, query: str) -> str:
        """
        Construct the prompt string using the tokenizer's chat template.
        Matches the format seen during SFT training.
        """
        messages = [
            {"role": "system", "content": SFT_SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def _generate(self, inputs: Any) -> torch.Tensor:
        """Run deterministic greedy decoding; returns prompt ids followed by new ids."""
        return self.model.generate(
            **inputs,
            do_sample=False,  # deterministic greedy decoding
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    def _parse_preferences(self, text: str) -> Dict[str, Any]:
        """
        Convert raw model output into a ``PreferenceList``-shaped dict.

        Tries the whole text as JSON first, then the span from the first ``{`` to
        the last ``}`` (in case the model added surrounding text). Returns
        ``{"preferences": []}`` when neither candidate parses and validates.
        """
        candidates = [text]
        extracted = self._extract_json_substring(text)
        if extracted is not None and extracted != text:
            candidates.append(extracted)

        for candidate in candidates:
            try:
                data = json.loads(candidate)
                return PreferenceList.model_validate(data).model_dump()
            except Exception:
                # Best-effort: malformed JSON or schema mismatch -> try next candidate.
                continue
        return {"preferences": []}

    @torch.inference_mode()
    def extract_preferences(self, query: str) -> Dict[str, Any]:
        """
        Directly extract preferences from query using the SFT model.
        Returns a dict compatible with PreferenceList model (key: 'preferences').
        """
        prompt = self.build_preference_prompt(query)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self._generate(inputs)

        # generate() echoes the prompt; keep only the newly generated tokens.
        prompt_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        if os.getenv("PREF_DEBUG") == "1":
            print(f"[debug][extractor] Raw output: {text}")

        return self._parse_preferences(text)

    def _extract_json_substring(self, text: str) -> str | None:
        """Return the span from the first '{' to the last '}' (inclusive), or None."""
        start = text.find('{')
        end = text.rfind('}')
        if start != -1 and end != -1 and end > start:
            return text[start : end + 1]
        return None

    @torch.inference_mode()
    def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]:
        """
        Batch extract preferences from multiple queries using left-padded batching.

        Args:
            queries: User queries; results are returned in the same order.
            batch_size: Number of prompts per ``generate`` call (must be >= 1).

        Returns:
            One ``PreferenceList``-shaped dict per query.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if not queries:
            return []
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        prompts = [self.build_preference_prompt(q) for q in queries]
        all_results: List[Dict[str, Any]] = []

        # Decoder-only batched generation needs left padding so every prompt ends
        # at the same column. Restore the caller's setting even if generation fails.
        orig_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            for start in range(0, len(prompts), batch_size):
                batch_prompts = prompts[start:start + batch_size]
                inputs = self.tokenizer(
                    batch_prompts, return_tensors="pt", padding=True, truncation=True
                ).to(self.model.device)
                outputs = self._generate(inputs)

                # With left padding, every row's (padded) prompt occupies exactly the
                # first input_ids.shape[1] positions, so generation starts there for
                # ALL rows. Using each row's non-pad count instead would leak the tail
                # of shorter prompts into the decoded text.
                prompt_len = inputs["input_ids"].shape[1]
                texts = self.tokenizer.batch_decode(
                    outputs[:, prompt_len:], skip_special_tokens=True
                )
                all_results.extend(self._parse_preferences(t) for t in texts)
        finally:
            self.tokenizer.padding_side = orig_padding_side

        return all_results

    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
        """
        Extract preferences from the LAST user turn in the history.
        We don't concat history because our SFT model was trained on single-turn extraction.
        Using context might confuse it unless we trained it that way.
        """
        last_user_msg = next((t.text for t in reversed(turns) if t.role == "user"), None)

        # No user turn (or an empty one): nothing to extract.
        if not last_user_msg:
            return PreferenceList(preferences=[])

        result_dict = self.extract_preferences(last_user_msg)
        return PreferenceList.model_validate(result_dict)

    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
        """
        Extract preferences from ALL user turns individually.

        Returns one ``PreferenceList`` per input turn (same order); non-user turns
        map to an empty list so indices stay aligned with ``turns``.
        """
        results = []
        for turn in turns:
            if turn.role == "user":
                res = self.extract_preferences(turn.text)
                results.append(PreferenceList.model_validate(res))
            else:
                results.append(PreferenceList(preferences=[]))
        return results