diff options
Diffstat (limited to 'src/personalization/data')
| -rw-r--r-- | src/personalization/data/personamem_loader.py | 84 |
1 files changed, 84 insertions, 0 deletions
"""Loaders for the PersonaMem question CSV and shared-context JSONL files.

Module path: src/personalization/data/personamem_loader.py
"""
from __future__ import annotations

import ast
import csv
import json
import logging
from dataclasses import dataclass
from typing import Dict, List

logger = logging.getLogger(__name__)

# Columns without which a question cannot be tied to a persona / context.
_REQUIRED_COLUMNS = ("persona_id", "question_id", "shared_context_id")

# Answer letter -> zero-based option index (lookup is done on the lowercased letter).
_ANSWER_INDEX = {"a": 0, "b": 1, "c": 2, "d": 3}


@dataclass
class PersonaMemQuestion:
    """One multiple-choice question row read from the questions CSV."""

    persona_id: str
    question_id: str
    question_type: str  # "unknown" when the column is absent
    topic: str  # "unknown" when the column is absent
    user_question_or_message: str
    all_options: List[str]  # 4 options expected; [] when unparsable
    correct_index: int  # 0..3, or -1 when the answer letter is unrecognized
    shared_context_id: str
    end_index_in_shared_context: int  # -1 when absent or unparsable


@dataclass
class PersonaMemContext:
    """The message list stored under one shared-context id."""

    shared_context_id: str
    messages: List[dict]  # raw dicts with "role"/"content" etc


def _parse_options(raw) -> List[str]:
    """Decode the ``all_options`` cell into a list.

    The cell is tried as JSON first and then as a Python literal (a list
    written with single-quoted strings is not valid JSON). Anything that
    cannot be decoded into a list/tuple yields ``[]``.
    """
    if not raw:
        return []
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        try:
            # literal_eval only evaluates literals, so it is safe on file input.
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    return list(parsed) if isinstance(parsed, (list, tuple)) else []


def _parse_correct_index(raw) -> int:
    """Map an answer such as ``"(c)"``, ``"c"`` or ``"C"`` to 0..3, else -1."""
    answer = (raw or "").strip()
    if answer.startswith("(") and answer.endswith(")"):
        answer = answer[1:-1].strip()
    return _ANSWER_INDEX.get(answer.lower(), -1)


def _parse_int(value, default: int = -1) -> int:
    """Convert *value* to ``int``, accepting float-formatted text like ``"7.0"``.

    Returns *default* for ``None``, empty or otherwise non-numeric input
    instead of raising, so one bad cell does not abort the whole load.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def _text(row: dict, key: str, default: str) -> str:
    """Return ``row[key]``, or *default* when the key is absent or ``None``."""
    value = row.get(key)
    return default if value is None else value


def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
    """Read the questions CSV at *path_csv* into ``PersonaMemQuestion`` objects.

    Rows lacking any of ``persona_id``, ``question_id`` or
    ``shared_context_id`` are skipped with a logged warning. Malformed
    optional cells degrade to defaults (``[]`` options, ``-1`` indices)
    rather than raising. The question text is taken from
    ``user_question_or_message``, falling back to ``question``.

    Raises:
        OSError: if the file cannot be opened.
    """
    questions: List[PersonaMemQuestion] = []
    # newline="" lets the csv module handle newlines embedded in quoted fields.
    with open(path_csv, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # DictReader fills short rows with None, so check values, not keys.
            missing = [key for key in _REQUIRED_COLUMNS if row.get(key) is None]
            if missing:
                logger.warning(
                    "Skipping CSV record ending at line %d: missing %s",
                    reader.line_num,
                    ", ".join(missing),
                )
                continue

            question_text = row.get("user_question_or_message")
            if question_text is None:
                question_text = _text(row, "question", "")

            questions.append(
                PersonaMemQuestion(
                    persona_id=row["persona_id"],
                    question_id=row["question_id"],
                    question_type=_text(row, "question_type", "unknown"),
                    topic=_text(row, "topic", "unknown"),
                    user_question_or_message=question_text,
                    all_options=_parse_options(row.get("all_options", "[]")),
                    correct_index=_parse_correct_index(row.get("correct_answer", "")),
                    shared_context_id=row["shared_context_id"],
                    end_index_in_shared_context=_parse_int(
                        row.get("end_index_in_shared_context", -1)
                    ),
                )
            )
    return questions


def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
    """Read the shared-context JSONL file at *path_jsonl*.

    Each non-blank line is expected to be a JSON object mapping a context id
    to its list of messages (``{"hash_id": [messages...]}``). Blank lines are
    ignored. If an id appears more than once, the last occurrence wins.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    contexts: Dict[str, PersonaMemContext] = {}
    with open(path_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            data = json.loads(line)
            # Format: {"hash_id": [messages...]}
            for cid, msgs in data.items():
                contexts[cid] = PersonaMemContext(
                    shared_context_id=cid,
                    messages=msgs,
                )
    return contexts
