Diffstat (limited to 'src/personalization/models/preference_extractor')
| -rw-r--r-- | src/personalization/models/preference_extractor/rule_extractor.py | 53 |
1 file changed, 53 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
index 0f743d9..42f43ed 100644
--- a/src/personalization/models/preference_extractor/rule_extractor.py
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -119,6 +119,59 @@ class QwenRuleExtractor(PreferenceExtractor):
             return text[start : end + 1]
         return None
 
+    @torch.inference_mode()
+    def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]:
+        """
+        Batch extract preferences from multiple queries using left-padded batching.
+        """
+        if not queries:
+            return []
+
+        # Save and set padding side for decoder-only batched generation
+        orig_padding_side = self.tokenizer.padding_side
+        self.tokenizer.padding_side = "left"
+
+        all_results = []
+        prompts = [self.build_preference_prompt(q) for q in queries]
+
+        for start in range(0, len(prompts), batch_size):
+            batch_prompts = prompts[start:start + batch_size]
+            inputs = self.tokenizer(
+                batch_prompts, return_tensors="pt", padding=True, truncation=True
+            ).to(self.model.device)
+
+            outputs = self.model.generate(
+                **inputs,
+                do_sample=False,
+                max_new_tokens=512,
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+
+            for i in range(len(batch_prompts)):
+                input_len = inputs["input_ids"].shape[1]  # with left padding, every prompt ends at the padded width
+                gen_ids = outputs[i][input_len:]
+                text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+                try:
+                    data = json.loads(text)
+                    validated = PreferenceList.model_validate(data)
+                    all_results.append(validated.model_dump())
+                except Exception:
+                    extracted_json = self._extract_json_substring(text)
+                    if extracted_json:
+                        try:
+                            data = json.loads(extracted_json)
+                            validated = PreferenceList.model_validate(data)
+                            all_results.append(validated.model_dump())
+                            continue
+                        except Exception:
+                            pass
+                    all_results.append({"preferences": []})
+
+        self.tokenizer.padding_side = orig_padding_side
+        return all_results
+
     def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
         """
         Extract preferences from the LAST user turn in the history.
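A minimal usage sketch of the new batch path follows. The QwenRuleExtractor constructor argument and the example queries are illustrative assumptions, not part of this change; each returned entry is a dict with a "preferences" key, as produced by the method above.

    # Hypothetical usage sketch; the constructor argument is an assumption.
    extractor = QwenRuleExtractor("Qwen/Qwen2.5-7B-Instruct")
    queries = [
        "I prefer concise answers formatted as bullet points.",
        "Always reply in French and avoid technical jargon.",
    ]
    results = extractor.batch_extract_preferences(queries, batch_size=2)
    for query, result in zip(queries, results):
        print(query, "->", result["preferences"])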
