1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
|
from __future__ import annotations

import ast
import csv
import json
import logging
from dataclasses import dataclass
from typing import Dict, List
@dataclass()
class PersonaMemQuestion:
    """A single multiple-choice PersonaMem question.

    Each question belongs to a persona and points into a shared conversation
    context. Field order is significant because instances may be constructed
    positionally.
    """

    # Identifiers.
    persona_id: str
    question_id: str

    # Question metadata and the text shown to the model.
    question_type: str
    topic: str
    user_question_or_message: str

    # Answer choices. The list is expected to hold 4 options.
    all_options: List[str]
    # Index of the right option, 0..3. load_personamem_questions_32k stores
    # -1 when it cannot parse the answer label.
    correct_index: int

    # Link into the shared conversation context (same id used as the key by
    # load_personamem_contexts_32k).
    shared_context_id: str
    # -1 when the CSV omits the column. Presumably the index of the last
    # context message visible to this question -- TODO confirm.
    end_index_in_shared_context: int
@dataclass()
class PersonaMemContext:
    """A shared conversation history that PersonaMem questions refer to."""

    # Id that questions carry in their ``shared_context_id`` field.
    shared_context_id: str
    # Message dicts kept exactly as read from the JSONL file (per the original
    # author: dicts with "role"/"content" keys, possibly more).
    messages: List[dict]
_LOGGER = logging.getLogger(__name__)

# Columns every question row must carry; rows lacking any of them are skipped.
_REQUIRED_ID_COLUMNS = ("persona_id", "question_id", "shared_context_id")

# Answer labels (compared case-insensitively) -> option index.
_ANSWER_LETTER_TO_INDEX = {"a": 0, "b": 1, "c": 2, "d": 3}


def _parse_options(raw: object) -> List[str]:
    """Decode an ``all_options`` cell into a list of options.

    The cell is parsed as JSON first. If that fails, it is parsed as a Python
    literal, which covers a single-quoted ``str(list)`` repr.
    ``ast.literal_eval`` only evaluates literals, never arbitrary code.
    Returns ``[]`` for a blank or missing cell, or for anything that does not
    decode to a list or tuple.
    """
    if not isinstance(raw, str) or not raw.strip():
        return []
    try:
        value = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # strip(): literal_eval rejects leading whitespace.
            value = ast.literal_eval(raw.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return []
    return list(value) if isinstance(value, (list, tuple)) else []


def _parse_correct_index(raw: object) -> int:
    """Map an answer label such as ``"c"``, ``"C"`` or ``"(c)"`` to 0..3.

    Returns -1 when the label is missing or is not one of a-d.
    """
    label = raw.strip() if isinstance(raw, str) else ""
    if label.startswith("(") and label.endswith(")"):
        label = label[1:-1].strip()
    return _ANSWER_LETTER_TO_INDEX.get(label.lower(), -1)


def _parse_int(raw: object, default: int) -> int:
    """Parse an integer cell, returning *default* for blank or invalid text.

    Whole-number float spellings such as ``"12.0"`` are accepted.
    """
    text = raw.strip() if isinstance(raw, str) else ""
    if not text:
        return default
    try:
        return int(text)
    except ValueError:
        pass
    try:
        number = float(text)
    except ValueError:
        return default
    return int(number) if number.is_integer() else default


def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
    """Load PersonaMem multiple-choice questions from a CSV file.

    Columns read:

    * Required: ``persona_id``, ``question_id`` and ``shared_context_id``.
    * Optional: ``question_type``, ``topic``, ``user_question_or_message``
      (falling back to ``question``), ``all_options``, ``correct_answer``
      and ``end_index_in_shared_context``.

    Args:
        path_csv: Path to the questions CSV.

    Returns:
        One ``PersonaMemQuestion`` per usable row, in file order. Rows missing
        a required id are skipped and logged at DEBUG level. Optional fields
        that are blank or unparseable fall back to these values:

        * ``all_options``: ``[]``.
        * ``correct_index`` and ``end_index_in_shared_context``: ``-1``.
        * ``question_type`` and ``topic``: ``"unknown"``.
    """
    questions = []
    # newline="" is required by the csv module so that line breaks inside
    # quoted fields are parsed correctly. "utf-8-sig" drops a leading BOM
    # that would otherwise be glued onto the first header name and make that
    # column look absent for every row.
    with open(path_csv, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # DictReader fills cells missing from short rows with None, so
            # None is treated the same as an absent column.
            ids = [row.get(column) for column in _REQUIRED_ID_COLUMNS]
            if any(value is None for value in ids):
                _LOGGER.debug(
                    "Skipping row ending at line %d of %s: missing one of %s",
                    reader.line_num,
                    path_csv,
                    _REQUIRED_ID_COLUMNS,
                )
                continue
            persona_id, question_id, shared_context_id = ids

            question_text = row.get("user_question_or_message")
            if question_text is None:
                question_text = row.get("question") or ""

            questions.append(
                PersonaMemQuestion(
                    persona_id=persona_id,
                    question_id=question_id,
                    question_type=row.get("question_type") or "unknown",
                    topic=row.get("topic") or "unknown",
                    user_question_or_message=question_text,
                    all_options=_parse_options(row.get("all_options")),
                    correct_index=_parse_correct_index(row.get("correct_answer")),
                    shared_context_id=shared_context_id,
                    end_index_in_shared_context=_parse_int(
                        row.get("end_index_in_shared_context"), -1
                    ),
                )
            )
    return questions
def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
    """Load shared conversation contexts from a JSONL file.

    Each non-blank line must be a JSON object mapping context ids to message
    lists, e.g. ``{"<hash_id>": [...messages...]}``. A single line may hold
    several ids. If an id repeats, the later entry wins.

    Args:
        path_jsonl: Path to the contexts JSONL file.

    Returns:
        Mapping from context id to ``PersonaMemContext``.

    Raises:
        ValueError: if a non-blank line is not valid JSON
            (``json.JSONDecodeError``) or is not a JSON object.
    """
    contexts: Dict[str, PersonaMemContext] = {}
    with open(path_jsonl, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # Blank lines (e.g. a trailing "\n\n") would make json.loads fail.
            if not line.strip():
                continue
            data = json.loads(line)
            if not isinstance(data, dict):
                raise ValueError(
                    f"{path_jsonl}:{line_no}: expected a JSON object, "
                    f"got {type(data).__name__}"
                )
            for cid, msgs in data.items():
                contexts[cid] = PersonaMemContext(
                    shared_context_id=cid,
                    messages=msgs,
                )
    return contexts
|