from __future__ import annotations

import ast
import csv
import json
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

# Letter labels accepted in the ``correct_answer`` column, mapped to option indices.
_ANSWER_LETTER_TO_INDEX = {"a": 0, "b": 1, "c": 2, "d": 3}


@dataclass
class PersonaMemQuestion:
    """One multiple-choice question row loaded from a PersonaMem questions CSV."""

    persona_id: str
    question_id: str
    question_type: str  # "unknown" when the column is absent
    topic: str  # "unknown" when the column is absent
    user_question_or_message: str
    all_options: List[str]  # answer options (4 expected; not validated)
    correct_index: int  # 0..3, or -1 when the answer label is unrecognized
    shared_context_id: str  # presumably a key of load_personamem_contexts_32k's result
    end_index_in_shared_context: int  # -1 when the column is missing/blank/unparsable


@dataclass
class PersonaMemContext:
    """A shared conversation context loaded from the contexts JSONL file."""

    shared_context_id: str
    messages: List[dict]  # raw message dicts as stored in the JSONL (e.g. "role"/"content")


def _parse_options(raw: Optional[str]) -> List[str]:
    """Parse an ``all_options`` cell into a list; return [] if it is not a list.

    JSON is tried first; if that fails, a Python literal is tried, because option
    lists are sometimes stored as a Python repr (``"['(a) x', '(b) y']"``), which
    JSON rejects due to the single quotes.
    """
    text = (raw or "").strip() or "[]"
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # literal_eval only evaluates Python literals, so it is safe on file input.
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return []
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    return []


def _parse_correct_index(raw: Optional[str]) -> int:
    """Map an answer label such as ``"c"``, ``"C"`` or ``"(c)"`` to 0..3; -1 if unrecognized."""
    label = (raw or "").strip()
    if label.startswith("(") and label.endswith(")"):
        label = label[1:-1].strip()
    return _ANSWER_LETTER_TO_INDEX.get(label.lower(), -1)


def _parse_int(raw: Optional[str], default: int = -1) -> int:
    """Parse an integer cell; accept float-formatted values ("7.0"); else return *default*."""
    text = (raw or "").strip()
    if not text:
        return default
    try:
        return int(text)
    except ValueError:
        try:
            return int(float(text))
        except (ValueError, OverflowError):
            return default


def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
    """Load PersonaMem questions from a CSV file.

    Required columns: ``persona_id``, ``question_id``, ``shared_context_id``.
    Optional columns: ``question_type``, ``topic``, ``user_question_or_message``
    (falls back to ``question``), ``correct_answer``, ``all_options``,
    ``end_index_in_shared_context``.

    Args:
        path_csv: Path to the UTF-8 CSV file.

    Returns:
        One PersonaMemQuestion per row, in file order. Rows lacking a required
        column are skipped and logged at DEBUG level.
    """
    questions: List[PersonaMemQuestion] = []
    # newline="" is required by the csv module so quoted fields with embedded
    # newlines are parsed correctly.
    with open(path_csv, "r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            try:
                persona_id = row["persona_id"]
                question_id = row["question_id"]
                shared_context_id = row["shared_context_id"]
            except KeyError as err:
                logger.debug(
                    "Skipping row ending at line %d of %s: missing column %s",
                    reader.line_num,
                    path_csv,
                    err,
                )
                continue

            questions.append(
                PersonaMemQuestion(
                    persona_id=persona_id,
                    question_id=question_id,
                    question_type=row.get("question_type", "unknown"),
                    topic=row.get("topic", "unknown"),
                    user_question_or_message=row.get(
                        "user_question_or_message", row.get("question", "")
                    ),
                    all_options=_parse_options(row.get("all_options")),
                    correct_index=_parse_correct_index(row.get("correct_answer")),
                    shared_context_id=shared_context_id,
                    end_index_in_shared_context=_parse_int(
                        row.get("end_index_in_shared_context"), -1
                    ),
                )
            )
    return questions


def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
    """Load shared contexts from a JSONL file.

    Each non-blank line is a JSON object mapping context id -> list of message
    dicts, e.g. ``{"<hash_id>": [{"role": ..., "content": ...}, ...]}``.
    If an id appears more than once, the last occurrence wins.

    Args:
        path_jsonl: Path to the UTF-8 JSONL file.

    Returns:
        Dict from context id to PersonaMemContext.
    """
    contexts: Dict[str, PersonaMemContext] = {}
    with open(path_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # json.loads rejects blank/trailing lines
            for cid, msgs in json.loads(line).items():
                contexts[cid] = PersonaMemContext(shared_context_id=cid, messages=msgs)
    return contexts