summaryrefslogtreecommitdiff
path: root/src/personalization/data/personamem_loader.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2025-12-17 04:29:37 -0600
commite43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5 /src/personalization/data/personamem_loader.py
Initial commit (clean history)HEADmain
Diffstat (limited to 'src/personalization/data/personamem_loader.py')
-rw-r--r--src/personalization/data/personamem_loader.py84
1 files changed, 84 insertions, 0 deletions
diff --git a/src/personalization/data/personamem_loader.py b/src/personalization/data/personamem_loader.py
new file mode 100644
index 0000000..3b516ad
--- /dev/null
+++ b/src/personalization/data/personamem_loader.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import csv
+import json
+from dataclasses import dataclass
+from typing import Dict, List
+
+@dataclass
+class PersonaMemQuestion:
+ persona_id: str
+ question_id: str
+ question_type: str
+ topic: str
+ user_question_or_message: str
+ all_options: List[str] # 4 options
+ correct_index: int # 0..3
+ shared_context_id: str
+ end_index_in_shared_context: int
+
+@dataclass
+class PersonaMemContext:
+ shared_context_id: str
+ messages: List[dict] # raw dicts with "role"/"content" etc
+
+def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
+ questions = []
+ with open(path_csv, "r", encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ for row in reader:
+ # Check fields
+ # The official csv usually has: question_id, persona_id, shared_context_id, question, correct_answer, options etc.
+ # Assuming standard PersonaMem format or similar to provided description.
+ # We might need to adjust based on actual file content.
+ # Based on user description:
+ try:
+ options_str = row.get("all_options", "[]") # Assuming json string
+ try:
+ options = json.loads(options_str)
+ except:
+ # Fallback if it's not JSON (e.g. string repr)
+ # For now assume JSON or simple list
+ options = []
+
+ # Handle raw answer format (e.g. "(c)" or "c")
+ raw_ans = row.get("correct_answer", "").strip()
+ # Remove parens if present
+ if raw_ans.startswith("(") and raw_ans.endswith(")"):
+ raw_ans = raw_ans[1:-1]
+
+ # Parse correct index
+ # If correct_answer is 'A','B','C','D' -> 0,1,2,3
+ ans_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'a': 0, 'b': 1, 'c': 2, 'd': 3}
+ correct_idx = ans_map.get(raw_ans, -1)
+
+ q = PersonaMemQuestion(
+ persona_id=row["persona_id"],
+ question_id=row["question_id"],
+ question_type=row.get("question_type", "unknown"),
+ topic=row.get("topic", "unknown"),
+ user_question_or_message=row.get("user_question_or_message", row.get("question", "")),
+ all_options=options,
+ correct_index=correct_idx,
+ shared_context_id=row["shared_context_id"],
+ end_index_in_shared_context=int(row.get("end_index_in_shared_context", -1))
+ )
+ questions.append(q)
+ except KeyError as e:
+ # print(f"Skipping row due to missing key: {e}")
+ continue
+ return questions
+
+def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
+ contexts = {}
+ with open(path_jsonl, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ # Format: {"hash_id": [messages...]}
+ for cid, msgs in data.items():
+ contexts[cid] = PersonaMemContext(
+ shared_context_id=cid,
+ messages=msgs
+ )
+ return contexts
+