summaryrefslogtreecommitdiff
path: root/src/personalization/models/preference_extractor/gpt4o_extractor.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/preference_extractor/gpt4o_extractor.py')
-rw-r--r--src/personalization/models/preference_extractor/gpt4o_extractor.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py
new file mode 100644
index 0000000..212bb13
--- /dev/null
+++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py
@@ -0,0 +1,97 @@
from __future__ import annotations

import json
import logging
import os
from typing import Any, Dict, List

from openai import OpenAI

from personalization.config.settings import LocalModelsConfig
from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor
from personalization.retrieval.preference_store.schemas import (
    ChatTurn,
    PreferenceList,
    preference_list_json_schema,
)
+
+
class GPT4OExtractor(PreferenceExtractor):
    """Preference extractor backed by the OpenAI chat-completions API (GPT-4o).

    Sends each user query to the model in JSON mode with a system prompt
    loaded from disk (or a built-in fallback) and parses the response into
    either a raw dict or a validated ``PreferenceList``.
    """

    # NOTE(review): resolved relative to the current working directory, not the
    # module location — confirm callers always run from the expected directory.
    _TEMPLATE_PATH = "fine_tuning_prompt_template.txt"

    def __init__(self, api_key: str, model: str = "gpt-4o") -> None:
        """Create a client and load the extraction system prompt.

        Args:
            api_key: OpenAI API key.
            model: Chat-completions model name (default ``"gpt-4o"``).
        """
        self.client = OpenAI(api_key=api_key)
        self.model = model

        # Prefer the fine-tuning prompt template on disk; fall back to a
        # minimal built-in prompt so the extractor still works without it.
        if os.path.exists(self._TEMPLATE_PATH):
            with open(self._TEMPLATE_PATH, "r", encoding="utf-8") as f:
                self.system_prompt = f.read()
        else:
            self.system_prompt = (
                "You are a preference extraction assistant. "
                "Extract user preferences from the query into a JSON object."
            )

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor":
        """Build an extractor from config; the API key comes from the environment.

        Raises:
            ValueError: if ``OPENAI_API_KEY`` is not set.
        """
        # The API key deliberately comes from the environment, not from `cfg`,
        # so secrets stay out of config files.
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        return cls(api_key=api_key)

    def build_preference_prompt(self, query: str) -> str:
        """Return the system prompt used for extraction.

        The prompt is static for this extractor; ``query`` is accepted for
        interface compatibility but not interpolated.
        """
        return self.system_prompt

    def _complete_json(self, user_text: str) -> Dict[str, Any] | None:
        """Run one JSON-mode chat completion and parse the result.

        Shared request path for :meth:`extract_preferences` and
        :meth:`extract_turn`. Best-effort: any API, network, or JSON error is
        logged and reported as ``None`` rather than raised.

        Args:
            user_text: Text sent as the user message.

        Returns:
            The parsed JSON object, or ``None`` on empty/failed responses.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": user_text},
                ],
                # JSON mode guarantees a parseable object in `content`.
                response_format={"type": "json_object"},
                temperature=0.0,  # deterministic extraction
            )
            content = response.choices[0].message.content
            if content:
                return json.loads(content)
        except Exception as e:
            logging.getLogger(__name__).warning("Error calling GPT-4o: %s", e)
        return None

    def extract_preferences(self, query: str) -> Dict[str, Any]:
        """Extract preferences from ``query`` as a raw dict.

        Returns:
            The model's JSON object, or ``{"preferences": []}`` on any failure.
        """
        data = self._complete_json(query)
        return data if data is not None else {"preferences": []}

    def extract_turn(self, turn: ChatTurn) -> PreferenceList:
        """Extract a validated ``PreferenceList`` from a single chat turn.

        Non-user turns and all failures (API, parse, validation) yield an
        empty ``PreferenceList``.
        """
        # Preferences are only ever stated by the user.
        if turn.role != "user":
            return PreferenceList(preferences=[])

        data = self._complete_json(turn.text)
        if data is None:
            return PreferenceList(preferences=[])

        try:
            # The prompt is expected to return {"preferences": [...]};
            # validation enforces that shape.
            return PreferenceList.model_validate(data)
        except Exception as e:
            logging.getLogger(__name__).warning("Error calling GPT-4o: %s", e)
            return PreferenceList(preferences=[])

    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
        """Extract preferences turn-by-turn for a whole session, in order."""
        return [self.extract_turn(turn) for turn in turns]
+