summaryrefslogtreecommitdiff
path: root/src/personalization/data/personamem_loader.py
blob: 3b516ad4999956cb1462b8f17b2ae0723cb63064 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import annotations

import ast
import csv
import json
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class PersonaMemQuestion:
    """One multiple-choice question row from the PersonaMem questions CSV.

    Instances are built by ``load_personamem_questions_32k``; string fields
    are taken from the CSV column of the same name unless noted below.
    """

    persona_id: str
    question_id: str
    question_type: str  # "unknown" when the CSV has no such column
    topic: str  # "unknown" when the CSV has no such column
    user_question_or_message: str  # "user_question_or_message" column, else "question", else ""
    all_options: List[str]   # answer choices; presumably 4 (a-d) but length is not enforced; [] if unparseable
    correct_index: int       # 0..3 for answer a-d; -1 if the answer could not be parsed
    shared_context_id: str  # presumably a key into the contexts loaded from the JSONL file — verify
    end_index_in_shared_context: int  # -1 when not provided; presumably where this question's context ends — TODO confirm

@dataclass
class PersonaMemContext:
    """A shared conversation context referenced by PersonaMem questions.

    Built by ``load_personamem_contexts_32k`` from one key/value pair of a
    JSONL line.
    """

    shared_context_id: str  # the JSONL object key this context was stored under
    messages: List[dict]  # message dicts stored exactly as loaded (not validated); presumably "role"/"content" etc — verify

# Answer letters in option order: "a" -> 0, ..., "d" -> 3.
_ANSWER_LETTERS = "abcd"


def _parse_options(raw: Optional[str]) -> List[str]:
    """Parse an ``all_options`` cell into a list; return [] if unparseable.

    Accepts a JSON list (``'["x", "y"]'``) or a Python list/tuple literal
    (``"['x', 'y']"``). Anything else, including a non-list JSON value,
    yields an empty list.
    """
    if not raw:
        return []
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # Not JSON, e.g. a Python repr with single quotes. literal_eval only
        # evaluates literals, so it is safe on untrusted cell contents.
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return []
    return list(parsed) if isinstance(parsed, (list, tuple)) else []


def _parse_answer_index(raw: Optional[str]) -> int:
    """Map an answer such as "c", "C", "(c)" or "( c )" to 0..3; else -1."""
    ans = (raw or "").strip()
    if ans.startswith("(") and ans.endswith(")"):
        ans = ans[1:-1].strip()
    ans = ans.lower()
    # The length check matters because "" is a substring of any string.
    if len(ans) == 1 and ans in _ANSWER_LETTERS:
        return _ANSWER_LETTERS.index(ans)
    return -1


def _parse_int(raw: Optional[str], default: int = -1) -> int:
    """Parse an integer cell; return *default* if it is missing or malformed.

    Integral float text such as "12.0" (as written by some CSV exporters) is
    accepted.
    """
    if raw is None:
        return default
    text = str(raw).strip()
    if not text:
        return default
    try:
        return int(text)
    except ValueError:
        pass
    try:
        value = float(text)
    except ValueError:
        return default
    # Rejects non-integral values as well as nan/inf.
    return int(value) if value.is_integer() else default


def load_personamem_questions_32k(path_csv: str) -> List[PersonaMemQuestion]:
    """Load PersonaMem multiple-choice questions from a CSV file.

    Each row becomes a ``PersonaMemQuestion``. Columns read:

    * Required: ``persona_id``, ``question_id``, ``shared_context_id``.
      If the header lacks any of these, the rows are skipped.
    * Optional:
      - ``question_type`` and ``topic`` default to ``"unknown"``.
      - ``user_question_or_message`` falls back to ``question``, then ``""``.
      - ``all_options`` is a JSON list or a Python list literal; otherwise ``[]``.
      - ``correct_answer`` such as ``"c"`` or ``"(C)"`` maps to 0..3; otherwise -1.
      - ``end_index_in_shared_context`` is -1 when missing, empty or non-integer.

    Args:
        path_csv: Path to the CSV file (UTF-8, with or without a BOM).

    Returns:
        The parsed questions, in file order.
    """
    questions: List[PersonaMemQuestion] = []
    # newline="" is required by the csv module so that quoted fields
    # containing line breaks are parsed correctly. utf-8-sig strips a leading
    # BOM, which would otherwise corrupt the first header name.
    with open(path_csv, "r", encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            try:
                persona_id = row["persona_id"]
                question_id = row["question_id"]
                shared_context_id = row["shared_context_id"]
            except KeyError:
                # A required column is absent from the header: skip the row.
                continue
            questions.append(
                PersonaMemQuestion(
                    persona_id=persona_id,
                    question_id=question_id,
                    question_type=row.get("question_type", "unknown"),
                    topic=row.get("topic", "unknown"),
                    user_question_or_message=row.get(
                        "user_question_or_message", row.get("question", "")
                    ),
                    all_options=_parse_options(row.get("all_options")),
                    correct_index=_parse_answer_index(row.get("correct_answer")),
                    shared_context_id=shared_context_id,
                    end_index_in_shared_context=_parse_int(
                        row.get("end_index_in_shared_context")
                    ),
                )
            )
    return questions

def load_personamem_contexts_32k(path_jsonl: str) -> Dict[str, PersonaMemContext]:
    """Load shared conversation contexts from a JSONL file.

    Each non-blank line must be a JSON object mapping one or more context ids
    to their message lists, e.g. ``{"<hash_id>": [msg, ...]}``. Message lists
    are stored as-is (not validated). If an id occurs more than once, the
    last occurrence wins.

    Args:
        path_jsonl: Path to the JSONL file (UTF-8, with or without a BOM).

    Returns:
        Mapping from context id to ``PersonaMemContext``.

    Raises:
        ValueError: A non-blank line is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass) or is
            not a JSON object.
    """
    contexts: Dict[str, PersonaMemContext] = {}
    # utf-8-sig strips a leading BOM, which json.loads would otherwise reject.
    with open(path_jsonl, "r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) are not JSON.
                continue
            data = json.loads(line)
            if not isinstance(data, dict):
                raise ValueError(
                    f"{path_jsonl}:{line_no}: expected a JSON object, "
                    f"got {type(data).__name__}"
                )
            for cid, msgs in data.items():
                contexts[cid] = PersonaMemContext(
                    shared_context_id=cid,
                    messages=msgs,
                )
    return contexts