summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor/rule_extractor.py
blob: 0f743d97e12fe107baee7857c12f138b0e541c7c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

import json
import re
import os
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from personalization.config.registry import choose_dtype, choose_device_map
from personalization.config.settings import LocalModelsConfig
from .base import PreferenceExtractor
from personalization.retrieval.preference_store.schemas import (
    PreferenceList,
    preference_list_json_schema,
    ChatTurn,
)

# The system prompt is hardcoded on purpose: it has to stay identical to the
# prompt used when the SFT training data was built (scripts/split_train_test.py).
SFT_SYSTEM_PROMPT = (
    "Extract user preferences from the query into JSON format "
    "based on the PreferenceList schema. "
    'If no preferences are found, return {"preferences": []}.'
)

class QwenRuleExtractor(PreferenceExtractor):
    """
    Preference extractor backed by a fine-tuned (SFT) Qwen causal LM.

    Despite the legacy 'RuleExtractor' name, no rules are involved: the model
    is prompted with a single user query and is expected to answer with JSON
    matching the PreferenceList schema. Any output that cannot be parsed and
    validated degrades to an empty preference list instead of raising.
    """

    def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
        """
        Load the tokenizer and the causal LM from `model_path`.

        Args:
            model_path: Path or identifier handed to `from_pretrained`.
            dtype: Torch dtype forwarded to the model loader.
            device_map: Device placement forwarded to the model loader.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        # generate() is given pad_token_id explicitly, so make sure one exists.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
        """
        Build an extractor from the `preference_extractor` entry of `cfg`.

        The entry's `dtype` and `device_map` values are resolved through
        `choose_dtype` / `choose_device_map`; `local_path` is the model path.
        """
        spec = cfg.preference_extractor
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(spec.local_path, dtype=dtype, device_map=device_map)

    def build_preference_prompt(self, query: str) -> str:
        """
        Construct the prompt string using the tokenizer's chat template.

        The conversation is the fixed SFT system prompt followed by `query`
        as the user message, with the generation prompt appended so the model
        continues as the assistant.
        """
        messages = [
            {"role": "system", "content": SFT_SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    @torch.inference_mode()
    def extract_preferences(self, query: str) -> Dict[str, Any]:
        """
        Directly extract preferences from `query` using the SFT model.

        Args:
            query: A single user message.

        Returns:
            A dict produced by `PreferenceList.model_dump()` (key:
            'preferences'). When the model output cannot be parsed or fails
            validation, `{"preferences": []}` is returned.
        """
        prompt = self.build_preference_prompt(query)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            do_sample=False,        # Deterministic greedy decoding
            max_new_tokens=512,     # Allow enough space for JSON
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation; keep only the new tokens.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)

        if os.getenv("PREF_DEBUG") == "1":
            print(f"[debug][extractor] Raw output: {text}")

        # 1. Direct parse of the whole output.
        parsed = self._parse_preferences(text)
        if parsed is not None:
            return parsed

        # 2. Fallback: the model wrapped the JSON in extra text (rare for SFT
        #    but possible), so try the embedded JSON object instead.
        extracted_json = self._extract_json_substring(text)
        if extracted_json:
            parsed = self._parse_preferences(extracted_json)
            if parsed is not None:
                return parsed

        # 3. Nothing usable: best-effort contract is an empty list.
        return {"preferences": []}

    @staticmethod
    def _parse_preferences(payload: str) -> Dict[str, Any] | None:
        """
        Parse `payload` as JSON and validate it as a PreferenceList.

        Returns the validated model dump, or None when parsing or validation
        fails. Extraction is best-effort, so failures are not propagated;
        `except Exception` (rather than a bare `except`) keeps
        KeyboardInterrupt / SystemExit from being swallowed.
        """
        try:
            data = json.loads(payload)
            return PreferenceList.model_validate(data).model_dump()
        except Exception as exc:
            if os.getenv("PREF_DEBUG") == "1":
                print(f"[debug][extractor] Parse failed: {exc!r}")
            return None

    def _extract_json_substring(self, text: str) -> str | None:
        """
        Helper to find a { ... } block in text.

        Scans every '{' from left to right and returns the first span that
        decodes as a complete JSON value, so braces in surrounding chatter do
        not spoil an otherwise valid object. If no span decodes, falls back to
        the slice from the first '{' to the last '}' (the caller's parse then
        decides). Returns None when there is no such brace pair.
        """
        decoder = json.JSONDecoder()
        first = text.find("{")

        start = first
        while start != -1:
            try:
                _, stop = decoder.raw_decode(text, start)
            except ValueError:
                # Not a complete JSON value at this brace; try the next one.
                start = text.find("{", start + 1)
            else:
                return text[start:stop]

        last = text.rfind("}")
        if first != -1 and last > first:
            return text[first : last + 1]
        return None

    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
        """
        Extract preferences from the LAST user turn in the history.

        History is not concatenated because the SFT model was trained on
        single-turn extraction; extra context might confuse it. Returns an
        empty PreferenceList when there is no user turn or its text is empty.
        """
        last_user_msg = next(
            (t.text for t in reversed(turns) if t.role == "user"), None
        )
        if not last_user_msg:
            return PreferenceList(preferences=[])

        result_dict = self.extract_preferences(last_user_msg)
        return PreferenceList.model_validate(result_dict)

    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
        """
        Extract preferences from ALL user turns individually.

        The result has exactly one entry per input turn, in order: user turns
        get their extracted preferences, every other role gets an empty
        PreferenceList.
        """
        results = []
        for turn in turns:
            if turn.role == "user":
                res = self.extract_preferences(turn.text)
                results.append(PreferenceList.model_validate(res))
            else:
                results.append(PreferenceList(preferences=[]))
        return results