1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
#!/usr/bin/env python3
"""
Day 3 Demo: Feedback Loop Simulation (Reward + Gating).
"""
import sys
import os
import json
import numpy as np
import random
# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
from personalization.config.settings import load_local_models_config
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn
from personalization.user_model.tensor_store import UserTensorStore
from personalization.retrieval.pipeline import retrieve_with_rerank
from personalization.feedback.handlers import eval_step
def _run_scenario(label, q_t, a_t, q_t1, hits, embedder, memories_note=""):
    """Print one synthetic turn and the reward/gating that eval_step assigns it.

    Args:
        label: Heading printed in square brackets, e.g. ``"Scenario 1"``.
        q_t: User query at turn t.
        a_t: Assistant answer at turn t.
        q_t1: User follow-up query at turn t+1.
        hits: Retrieved memory cards; each must expose ``note_text``.
        embedder: Model whose ``encode`` is used to embed both queries.
        memories_note: Optional text appended verbatim after the memory list.

    Returns:
        The ``(r_hat, g_hat)`` pair returned by ``eval_step``.
    """
    print(f"\n[{label}]")
    print(f"Q_t: {q_t}")
    print(f"A_t: {a_t}")
    print(f"Q_t+1: {q_t1}")
    print(f"Memories: {[m.note_text for m in hits]}{memories_note}")

    # Each query is encoded in its own call and wrapped in an ndarray before
    # being handed to eval_step.
    e_q = np.array(embedder.encode([q_t], return_tensor=False)[0])
    e_q1 = np.array(embedder.encode([q_t1], return_tensor=False)[0])

    r_hat, g_hat = eval_step(q_t, a_t, q_t1, hits, e_q, e_q1)
    print(f"-> Reward: {r_hat:.2f}")
    print(f"-> Gating: {g_hat:.2f}")
    return r_hat, g_hat


def main():
    """Run the feedback-loop demo on two hand-written conversation turns.

    Loads the memory cards, their embeddings and the item projection from
    ``data/corpora``, builds the embedding and reranker models plus the user
    tensor store, runs one real retrieval for a Python question, and then
    prints the reward and gating values for:

    1. a "success" turn (the user asks a natural follow-up), and
    2. a "failure" turn (the user complains), which deliberately reuses the
       memories retrieved for scenario 1 as stand-in irrelevant memories.

    Exits with status 1 if any required data file is missing.
    """
    # Paths (relative to the current working directory)
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    user_store_path = "data/users/user_store.npz"

    # 1. Load Data
    print("Loading data stores...")
    # The projection file is checked too, so a missing file produces the same
    # clean exit as the other inputs rather than a traceback from np.load.
    missing = [
        p for p in (cards_path, embs_path, item_proj_path)
        if not os.path.exists(p)
    ]
    if missing:
        print("Memory data missing.")
        print(f"Not found: {', '.join(missing)}")
        sys.exit(1)

    # JSONL: one card per line; blank lines (e.g. a trailing newline at the
    # end of the file) are skipped instead of being passed to the validator.
    with open(cards_path, "r", encoding="utf-8") as f:
        cards = [
            MemoryCard.model_validate_json(line) for line in f if line.strip()
        ]

    memory_embeddings = np.load(embs_path)
    # np.load on an .npz returns a lazy archive that keeps the file open;
    # read the array out and close the archive right away.
    with np.load(item_proj_path) as proj_data:
        item_vectors = proj_data["V"]

    # 2. Load Models
    print("Loading models...")
    cfg = load_local_models_config()
    embedder = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)
    user_store = UserTensorStore(k=256, path=user_store_path)

    # 3. Simulate a Session
    # 'oasst1_queries.jsonl' only holds flat queries, not full sessions, so
    # the demo uses a hand-written synthetic scenario instead.
    print("\n--- Synthetic Session Evaluation ---")

    # Scenario 1: Success
    # User asks about Python, the system gives a good answer, and the user
    # asks a follow-up.
    user_id = "test_user_ok"
    q_t = "How do I list files in a directory in Python?"
    # Retrieval is actually run here rather than mocked.
    hits = retrieve_with_rerank(
        user_id=user_id,
        query=q_t,
        embed_model=embedder,
        reranker=reranker,
        memory_cards=cards,
        memory_embeddings=memory_embeddings,
        user_store=user_store,
        item_vectors=item_vectors,
        topk_dense=64,
        topk_rerank=3,
    )
    _run_scenario(
        "Scenario 1",
        q_t,
        "You can use os.listdir() or pathlib.Path.iterdir(). Here is an example...",
        "Great, can you show me the pathlib one?",
        hits,
        embedder,
    )

    # Scenario 2: Failure (Complaint)
    # No new retrieval: the Python-related hits from scenario 1 are reused on
    # purpose, to simulate irrelevant memories being retrieved by mistake.
    _run_scenario(
        "Scenario 2",
        "Explain quantum entanglement.",
        "Quantum entanglement is a phenomenon where particles...",
        "No, that's not what I meant. Explain it simply like I'm five.",
        hits,
        embedder,
        memories_note=" (Irrelevant)",
    )
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
|