1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
|
#!/usr/bin/env python3
"""
Online Personalization REPL Demo.
Interactive CLI for chatting with the Personalized Memory RAG system.
Includes:
- Extractor-0.6B for online preference extraction
- Reranker + Policy Retrieval
- Online RL updates (REINFORCE)
"""
import sys
import os
import uuid
import numpy as np
import torch
import readline # For better input handling
import yaml
# Add src to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), "../src"))
from personalization.config.settings import load_local_models_config
from personalization.config.registry import (
get_preference_extractor,
get_chat_model,
)
from personalization.models.embedding.qwen3_8b import Qwen3Embedding8B
from personalization.models.reranker.qwen3_reranker import Qwen3Reranker
# from personalization.models.llm.qwen_instruct import QwenInstruct # Deprecated direct import
from personalization.user_model.tensor_store import UserTensorStore
from personalization.user_model.session_state import OnlineSessionState
from personalization.retrieval.preference_store.schemas import MemoryCard, ChatTurn, PreferenceList
from personalization.retrieval.pipeline import retrieve_with_policy
from personalization.feedback.handlers import eval_step
from personalization.user_model.policy.reinforce import reinforce_update_user_state
from personalization.user_model.features import ItemProjection
def load_memory_store():
    """Load the base memory corpus and its item-space projection from ``data/corpora``.

    Paths are relative to the current working directory:

    * ``memory_cards.jsonl``   -- one ``MemoryCard`` JSON object per line,
    * ``memory_embeddings.npy`` -- the embedding matrix loaded as-is,
    * ``item_projection.npz``  -- arrays ``P`` and ``mean`` (used to rebuild an
      ``ItemProjection``) plus ``V`` (returned as ``item_vectors``).

    Returns:
        ``(cards, memory_embeddings, item_vectors, projection)``.

    Exits the process with status 1 if any of the three files is missing,
    since the demo needs base data to define the projection space.
    """
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    # Check all three inputs up front so a missing projection file gets the
    # same clean exit instead of a FileNotFoundError from np.load below.
    missing = [p for p in (cards_path, embs_path, item_proj_path) if not os.path.exists(p)]
    if missing:
        print("Memory data missing. Starting with empty memory store is possible but item space requires base data.")
        print(f"Missing: {', '.join(missing)}")
        # For this demo, we assume base data exists to define PCA space.
        sys.exit(1)
    print(f"Loading memory cards from {cards_path}...")
    cards = []
    # Read as UTF-8 to match how main() rewrites this file on exit.
    with open(cards_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. a stray trailing newline); "" is not valid JSON.
            if line.strip():
                cards.append(MemoryCard.model_validate_json(line))
    memory_embeddings = np.load(embs_path)
    # np.load on an .npz returns an NpzFile that holds the archive open;
    # materialize the arrays and close it.
    with np.load(item_proj_path) as proj_data:
        # Rebuild the ItemProjection so newly added memories can be transformed.
        projection = ItemProjection(P=proj_data["P"], mean=proj_data["mean"])
        item_vectors = proj_data["V"]
    return cards, memory_embeddings, item_vectors, projection
def build_user_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
    """Wrap a line typed by the user into a ``ChatTurn`` with role ``"user"``.

    Every turn is attached to the fixed session id ``"online_debug_session"``
    and tagged with ``meta={"source": "repl"}``.
    """
    turn_fields = {
        "user_id": user_id,
        "session_id": "online_debug_session",
        "turn_id": turn_id,
        "role": "user",
        "text": text,
        "meta": {"source": "repl"},
    }
    return ChatTurn(**turn_fields)
def build_assistant_turn(user_id: str, text: str, turn_id: int) -> ChatTurn:
    """Wrap a chat-model reply into a ``ChatTurn`` with role ``"assistant"``.

    Uses the same fixed session id (``"online_debug_session"``) and
    ``{"source": "repl"}`` metadata as the REPL's user turns.
    """
    session = "online_debug_session"
    provenance = {"source": "repl"}
    return ChatTurn(
        role="assistant",
        text=text,
        user_id=user_id,
        turn_id=turn_id,
        session_id=session,
        meta=provenance,
    )
def add_preferences_as_memory_cards(
    prefs: PreferenceList,
    query: str,
    user_id: str,
    turn_id: int,
    embed_model: Qwen3Embedding8B,
    projection: ItemProjection,
    memory_cards: list,
    memory_embeddings: np.ndarray,
    item_vectors: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Adds extracted preferences as new memory cards.

    Each preference becomes a ``MemoryCard`` with note text
    ``"When {condition}, {action}."``. A preference is skipped if a card with
    the same note text already exists for ``user_id``, including a card added
    earlier in this same call. All cards created from one call share the
    embedding of ``query`` and its projection into item space.

    Side effect: new cards are appended to ``memory_cards`` in place.

    Returns updated memory_embeddings and item_vectors, or the inputs
    unchanged when nothing was added. New rows are cast to the existing
    arrays' dtypes so appending does not silently change the stored precision.
    """
    if not prefs.preferences:
        print(" [Extractor] No preferences found in this turn.")
        return memory_embeddings, item_vectors
    print(f" [Extractor] Extracted {len(prefs.preferences)} preferences:")

    # Simple deduplication on exact note_text for this user (a real system
    # would use vector similarity or a hash). Build the set once, not per pref.
    existing_notes = {card.note_text for card in memory_cards if card.user_id == user_id}

    # The current design uses the original query embedding e_m for every pref
    # from this turn. It is computed lazily so an all-duplicate turn skips the
    # embedding-model call entirely.
    e_q = None
    v_q = None
    n_added = 0
    for pref in prefs.preferences:
        note_text = f"When {pref.condition}, {pref.action}."
        print(f" - {note_text}")
        if note_text in existing_notes:
            print(" (Duplicate, skipping add)")
            continue
        if e_q is None:
            e_q = embed_model.encode([query], return_tensor=False)[0]
            v_q = projection.transform_vector(np.array(e_q))
        card = MemoryCard(
            card_id=str(uuid.uuid4()),
            user_id=user_id,
            source_session_id="online_debug_session",
            source_turn_ids=[turn_id],
            raw_queries=[query],
            preference_list=PreferenceList(preferences=[pref]),
            note_text=note_text,
            embedding_e=e_q,  # value as returned by encode(return_tensor=False); presumably list[float]
            kind="pref",
        )
        memory_cards.append(card)
        existing_notes.add(note_text)  # also dedupe repeats within this batch
        n_added += 1

    if n_added:
        # One identical row per added card, matching the existing dtype
        # (np.vstack would otherwise upcast e.g. float32 -> float64).
        new_embs = np.asarray([e_q] * n_added, dtype=memory_embeddings.dtype)
        new_vecs = np.asarray([v_q] * n_added, dtype=item_vectors.dtype)
        memory_embeddings = np.vstack([memory_embeddings, new_embs])
        item_vectors = np.vstack([item_vectors, new_vecs])
        print(f" [Debug] Added {n_added} new cards. Total cards: {len(memory_cards)}")
    return memory_embeddings, item_vectors
def _load_llm_name(default: str) -> str:
    """Return ``llm_name`` from ``../configs/user_model.yaml`` (relative to this script).

    Best-effort: if the file is missing, lacks ``llm_name``, or cannot be
    read or parsed, the problem is printed and *default* is returned so the
    demo still starts.
    """
    llm_name = default
    try:
        config_path = os.path.join(os.path.dirname(__file__), "../configs/user_model.yaml")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                user_cfg = yaml.safe_load(f)
            if user_cfg and "llm_name" in user_cfg:
                llm_name = user_cfg["llm_name"]
                print(f"Loaded llm_name from config: {llm_name}")
        else:
            print(f"Warning: Config file not found at {config_path}")
    except Exception as e:  # deliberate best-effort: keep the default
        print(f"Failed to load user_model.yaml: {e}")
    return llm_name


def _replace_file(path: str, writer, binary: bool = False) -> None:
    """Call ``writer(f)`` on a temp file next to *path*, then move it onto *path*.

    ``os.replace`` swaps the file in a single step, so an interrupted save
    leaves the previous contents intact rather than a truncated file. On any
    error the temp file is removed and the exception is re-raised.
    """
    tmp_path = path + ".tmp"
    open_kwargs = {"mode": "wb"} if binary else {"mode": "w", "encoding": "utf-8"}
    try:
        with open(tmp_path, **open_kwargs) as f:
            writer(f)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def _save_memory_store(memory_cards: list, memory_embeddings: np.ndarray, item_vectors: np.ndarray) -> None:
    """Write cards, embeddings and cached item vectors back to ``data/corpora``.

    ``item_projection.npz`` keeps its existing ``P`` and ``mean``; only ``V``
    (the item-vector cache) is replaced so the next load sees new memories.
    """
    print(f"Saving {len(memory_cards)} memory cards to disk...")
    cards_path = "data/corpora/memory_cards.jsonl"
    embs_path = "data/corpora/memory_embeddings.npy"
    item_proj_path = "data/corpora/item_projection.npz"
    # Read P and mean before touching anything, and close the NpzFile; the
    # original kept it open while overwriting the same path.
    with np.load(item_proj_path) as proj_data:
        P = proj_data["P"]
        mean = proj_data["mean"]
    _replace_file(
        cards_path,
        lambda f: f.writelines(card.model_dump_json() + "\n" for card in memory_cards),
    )
    # Passing an open file object stops np.save/np.savez from appending an
    # extension to the ".tmp" name.
    _replace_file(embs_path, lambda f: np.save(f, memory_embeddings), binary=True)
    _replace_file(
        item_proj_path,
        lambda f: np.savez(f, P=P, mean=mean, V=item_vectors),
        binary=True,
    )
    print("Memory store updated.")


def main():
    """Run the interactive online-personalization REPL.

    For each user line it does the following:
    (3) if there was a previous turn, scores it with ``eval_step`` and applies
        a REINFORCE update to the user state;
    (5) extracts preferences into new memory cards;
    (6) retrieves memories with the policy;
    (7) answers with the chat model.

    On ``exit``/``quit``, EOF, or Ctrl-C at the prompt, the user state and
    the updated memory store are written back to disk.
    """
    # Used only to collapse repeated memory texts in the retrieval display.
    from collections import Counter

    # 1. Load Config & Models
    print("Loading configuration...")
    cfg = load_local_models_config()
    # RL Config (Should load from user_model.yaml, hardcoded for safety/demo)
    rl_cfg = {
        "item_dim": 256,
        "beta_long": 0.1,
        "beta_short": 0.3,
        "tau": 1.0,
        "eta_long": 1e-3,
        "eta_short": 5e-3,
        "ema_alpha": 0.05,
        "short_decay": 0.1,
        "dense_topk": 64,
        "rerank_topk": 3,
        "max_new_tokens": 512,
    }
    print("Loading models and stores...")
    # Using explicit classes for clarity, but registry can also be used
    embed_model = Qwen3Embedding8B.from_config(cfg)
    reranker = Qwen3Reranker.from_config(cfg)
    # Chat model comes from the registry so the backend can be switched via
    # llm_name in user_model.yaml; "qwen_1_5b" is the default.
    llm_name = _load_llm_name("qwen_1_5b")
    print(f"Loading ChatModel: {llm_name}...")
    chat_model = get_chat_model(llm_name)
    # Use registry for extractor to support switching; fall back to the
    # "rule" extractor if the default one cannot be loaded.
    extractor_name = "qwen3_0_6b_sft"  # Default per design doc
    print(f"Loading extractor: {extractor_name}...")
    try:
        extractor = get_preference_extractor(extractor_name)
    except Exception as e:
        print(f"Failed to load {extractor_name}: {e}. Fallback to rule.")
        extractor = get_preference_extractor("rule")
    user_store = UserTensorStore(
        k=rl_cfg["item_dim"],
        path="data/users/user_store_online.npz",
    )
    # Load Memory
    memory_cards, memory_embeddings, item_vectors, projection = load_memory_store()

    # 2. Init Session
    user_id = "debug_user"
    user_state = user_store.get_state(user_id)
    session_state = OnlineSessionState(user_id=user_id)
    print(f"\n--- Online Personalization REPL (User: {user_id}) ---")
    print(f"Initial State: ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
    print("Type 'exit' or 'quit' to stop.\n")

    while True:
        try:
            q_t = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting...")
            break
        if q_t.lower() in ("exit", "quit"):
            break
        if not q_t:
            continue

        # 3. RL Update (from previous turn): the new query q_t is the
        # feedback signal for the previous answer.
        e_q_t = np.array(embed_model.encode([q_t], return_tensor=False)[0])
        if session_state.last_query is not None:
            r_hat, g_hat = eval_step(
                q_t=session_state.last_query,
                answer_t=session_state.last_answer,
                q_t1=q_t,
                memories_t=session_state.last_memories,
                query_embedding_t=session_state.last_query_embedding,
                query_embedding_t1=e_q_t,
            )
            print(f" [Feedback] Reward: {r_hat:.2f}, Gating: {g_hat:.2f}")
            if (session_state.last_candidate_item_vectors is not None and
                    session_state.last_policy_probs is not None and
                    len(session_state.last_chosen_indices) > 0):
                # last_policy_probs holds one probability per *chosen* item, so
                # pass only the chosen rows of the candidate vectors and
                # re-index them as 0..len(chosen)-1.
                chosen_vectors = session_state.last_candidate_item_vectors[session_state.last_chosen_indices]
                updated = reinforce_update_user_state(
                    user_state=user_state,
                    item_vectors=chosen_vectors,
                    chosen_indices=np.arange(len(session_state.last_chosen_indices)),
                    policy_probs=session_state.last_policy_probs,
                    reward_hat=r_hat,
                    gating=g_hat,
                    tau=rl_cfg["tau"],
                    eta_long=rl_cfg["eta_long"],
                    eta_short=rl_cfg["eta_short"],
                    ema_alpha=rl_cfg["ema_alpha"],
                    short_decay=rl_cfg["short_decay"],
                )
                if updated:
                    print(" [RL] User state updated.")
                    user_store.save_state(user_state)  # Save immediately for safety

        # 4. Update History
        user_turn = build_user_turn(user_id, q_t, len(session_state.history))
        session_state.history.append(user_turn)

        # 5. Extract Preferences -> New Memory (from recent history)
        prefs = extractor.extract_turn(session_state.history)
        memory_embeddings, item_vectors = add_preferences_as_memory_cards(
            prefs, q_t, user_id, user_turn.turn_id,
            embed_model, projection, memory_cards, memory_embeddings, item_vectors
        )

        # 6. Retrieve + Policy
        # pipeline returns (candidates, vecs, scores, indices, probs);
        # only_own_memories=True restricts retrieval to this user's cards.
        candidates, cand_item_vecs, base_scores, chosen_indices, probs = retrieve_with_policy(
            user_id=user_id,
            query=q_t,
            embed_model=embed_model,
            reranker=reranker,
            memory_cards=memory_cards,
            memory_embeddings=memory_embeddings,
            user_store=user_store,
            item_vectors=item_vectors,
            topk_dense=rl_cfg["dense_topk"],
            topk_rerank=rl_cfg["rerank_topk"],
            beta_long=rl_cfg["beta_long"],
            beta_short=rl_cfg["beta_short"],
            tau=rl_cfg["tau"],
            only_own_memories=True,  # User requested strict privacy for demo
        )
        # chosen_indices index into the candidates list (0..K-1)
        print(f"DEBUG: candidates len={len(candidates)}, type={type(candidates)}")
        print(f"DEBUG: chosen_indices={chosen_indices}, type={type(chosen_indices)}")
        if len(chosen_indices) > 0:
            print(f"DEBUG: first idx type={type(chosen_indices[0])}, val={chosen_indices[0]}")
        memories_t = [candidates[int(i)] for i in chosen_indices]
        if memories_t:
            print(f" [Retrieval] Found {len(memories_t)} memories:")
            # Display deduplication: print each distinct note once with its count
            content_counts = Counter(m.note_text for m in memories_t)
            for text, count in content_counts.most_common():
                user_info = f" ({count} users)" if count > 1 else ""
                print(f" - {text}{user_info}")

        # 7. LLM Answer (history is the list of ChatTurn objects)
        memory_notes = [m.note_text for m in memories_t]
        answer_t = chat_model.answer(
            history=session_state.history,
            memory_notes=memory_notes,
            max_new_tokens=rl_cfg["max_new_tokens"]
        )
        print(f"Assistant: {answer_t}")

        # 8. Update State for Next Turn
        assist_turn = build_assistant_turn(user_id, answer_t, len(session_state.history))
        session_state.history.append(assist_turn)
        session_state.last_query = q_t
        session_state.last_answer = answer_t
        session_state.last_memories = memories_t
        session_state.last_query_embedding = e_q_t
        session_state.last_candidate_item_vectors = cand_item_vecs
        session_state.last_policy_probs = probs
        session_state.last_chosen_indices = chosen_indices
        print(f" [State] ||z_long||={np.linalg.norm(user_state.z_long):.16f}, ||z_short||={np.linalg.norm(user_state.z_short):.16f}")
        print("-" * 40)

    print("Saving final user state...")
    user_store.save_state(user_state)
    user_store.persist()
    # Save updated memories (each file replaced via a temp file + os.replace)
    _save_memory_store(memory_cards, memory_embeddings, item_vectors)


if __name__ == "__main__":
    main()
|