author    YurenHao0426 <blackhao0426@gmail.com>    2026-02-10 20:16:36 +0000
committer YurenHao0426 <blackhao0426@gmail.com>    2026-02-10 20:16:36 +0000
commit    5626080ca4c4219aec4888d6b9406d0d3349fb55 (patch)
tree      86287d9fd5833e11ccd78566992540f2664fd195 /src/personalization/models/preference_extractor/rule_extractor.py
parent    a2036838807428424bbbaff507a6563749a83145 (diff)
Add RAG rewrite, 60-session experiment scripts, and analysis tools
- RAG rewrite adapter and vector preference pipeline in personalized_llm
- 60-session experiment queue scripts (reflection, rag, rag_vector, rag_rewrite)
- Vector-preference correlation analysis and visualization scripts
- Local reward model batch processing improvements
- Updated CLAUDE.md with full experiment documentation and notes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/preference_extractor/rule_extractor.py')
-rw-r--r--  src/personalization/models/preference_extractor/rule_extractor.py  53
1 file changed, 53 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
index 0f743d9..42f43ed 100644
--- a/src/personalization/models/preference_extractor/rule_extractor.py
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -119,6 +119,59 @@ class QwenRuleExtractor(PreferenceExtractor):
             return text[start : end + 1]
         return None
 
+    @torch.inference_mode()
+    def batch_extract_preferences(self, queries: List[str], batch_size: int = 64) -> List[Dict[str, Any]]:
+        """
+        Batch extract preferences from multiple queries using left-padded batching.
+        """
+        if not queries:
+            return []
+
+        # Save and set padding side for decoder-only batched generation
+        orig_padding_side = self.tokenizer.padding_side
+        self.tokenizer.padding_side = "left"
+
+        all_results = []
+        prompts = [self.build_preference_prompt(q) for q in queries]
+
+        for start in range(0, len(prompts), batch_size):
+            batch_prompts = prompts[start:start + batch_size]
+            inputs = self.tokenizer(
+                batch_prompts, return_tensors="pt", padding=True, truncation=True
+            ).to(self.model.device)
+
+            outputs = self.model.generate(
+                **inputs,
+                do_sample=False,
+                max_new_tokens=512,
+                pad_token_id=self.tokenizer.pad_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+
+            prompt_len = inputs["input_ids"].shape[1]  # left padding: generation starts at the padded length
+            for i in range(len(batch_prompts)):
+                gen_ids = outputs[i][prompt_len:]
+                text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+                try:
+                    data = json.loads(text)
+                    validated = PreferenceList.model_validate(data)
+                    all_results.append(validated.model_dump())
+                except Exception:
+                    extracted_json = self._extract_json_substring(text)
+                    if extracted_json:
+                        try:
+                            data = json.loads(extracted_json)
+                            validated = PreferenceList.model_validate(data)
+                            all_results.append(validated.model_dump())
+                            continue
+                        except Exception:
+                            pass
+                    all_results.append({"preferences": []})
+
+        self.tokenizer.padding_side = orig_padding_side
+        return all_results
+
     def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
         """
         Extract preferences from the LAST user turn in the history.
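
For orientation, a minimal usage sketch of the new batched path. This sketch is not part of the commit: the constructor call is hypothetical, since the diff does not show how QwenRuleExtractor is instantiated.

    # Hypothetical usage; QwenRuleExtractor's constructor arguments are assumed.
    from src.personalization.models.preference_extractor.rule_extractor import (
        QwenRuleExtractor,
    )

    extractor = QwenRuleExtractor()  # assumption: real setup may need model/tokenizer args

    queries = [
        "I prefer concise answers with bullet points.",
        "Explain things as if I am new to programming.",
    ]

    # One left-padded generate() call per batch of 64, instead of a call per query.
    results = extractor.batch_extract_preferences(queries, batch_size=64)
    for query, result in zip(queries, results):
        # Each result is a validated PreferenceList dump, or {"preferences": []}
        # when the model output cannot be parsed as JSON.
        print(query, "->", result["preferences"])

Each output element aligns positionally with its query, so callers can zip the two lists back together even when some extractions fall back to the empty-preference sentinel.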