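"""Offline construction of TurnSample training data from chat sessions.

For each user query q_t that is followed by another user query q_{t+1}
(optionally with an assistant answer a_t in between), retrieve the user's
memories for q_t and package the triple into a TurnSample.
"""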
from typing import Iterable, List, Optional

import numpy as np
from tqdm import tqdm

from personalization.feedback.schemas import TurnSample
from personalization.models.embedding.base import EmbeddingModel
from personalization.models.llm.qwen_instruct import QwenInstruct
from personalization.models.reranker.base import Reranker
from personalization.retrieval.pipeline import retrieve_with_rerank
from personalization.retrieval.preference_store.schemas import ChatTurn, MemoryCard
from personalization.user_model.tensor_store import UserTensorStore


def build_turn_samples_from_sessions(
    sessions: Iterable[List[ChatTurn]],
    embed_model: EmbeddingModel,
    llm: QwenInstruct,
    reranker: Reranker,
    memory_cards: List[MemoryCard],
    memory_embeddings: np.ndarray,
    user_store: UserTensorStore,
    item_vectors: np.ndarray,
    max_samples: Optional[int] = None,
    topk_dense: int = 64,
    topk_rerank: int = 3,
) -> List[TurnSample]:
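    """Build (q_t, a_t, q_{t+1}) TurnSamples from chat sessions.

    For every user turn q_t that is followed by another user turn q_{t+1},
    the first assistant reply between them (if any) becomes a_t. Memories
    are retrieved for q_t, and both queries are embedded.

    Note: `llm` is currently unused; a_t is taken from the session itself,
    but the handle is kept so a missing answer could later be generated
    instead of left empty.
    """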
    samples: List[TurnSample] = []
    for turns in tqdm(sessions, desc="Building TurnSamples"):
        if max_samples and len(samples) >= max_samples:
            break
        # Ensure turns are in chronological order by turn_id.
        sorted_turns = sorted(turns, key=lambda x: x.turn_id)
        # Scan for (q_t, a_t, q_{t+1}) triples.
        for i in range(len(sorted_turns)):
            if max_samples and len(samples) >= max_samples:
                break
            q_t = sorted_turns[i]
            if q_t.role != "user":
                continue
            # Look ahead for the next user turn q_{t+1}, capturing the first
            # assistant reply a_t encountered along the way.
            a_t_text = ""
            q_t1 = None
            for j in range(i + 1, len(sorted_turns)):
                next_turn = sorted_turns[j]
                if next_turn.role == "assistant" and not a_t_text:
                    a_t_text = next_turn.text
                elif next_turn.role == "user":
                    q_t1 = next_turn
                    break
            if not q_t1:
                # End of session or no subsequent user query.
                continue
            # a_t is optional but preferred: reward computation generally
            # needs an answer to evaluate, and datasets such as OASST1 do
            # include assistant turns. If a_t is missing, the sample is still
            # emitted with an empty answer; `llm` could be used here to
            # generate a stand-in answer, but for offline sampling from
            # existing chats, existing answers are preferred.
            # Retrieve this user's memories for q_t.
            memories_t = retrieve_with_rerank(
                user_id=q_t.user_id,
                query=q_t.text,
                embed_model=embed_model,
                reranker=reranker,
                memory_cards=memory_cards,
                memory_embeddings=memory_embeddings,
                user_store=user_store,
                item_vectors=item_vectors,
                topk_dense=topk_dense,
                topk_rerank=topk_rerank,
                beta_long=0.0,  # user-preference score weights disabled
                beta_short=0.0,  # for offline sample construction
                only_own_memories=True,  # restrict to user-specific memories
            )
            # Embed both queries in a single batched call.
            e_q_t, e_q_t1 = embed_model.encode(
                [q_t.text, q_t1.text], return_tensor=False
            )
            sample = TurnSample(
                user_id=q_t.user_id,
                session_id=q_t.session_id,
                turn_id=q_t.turn_id,
                query_t=q_t.text,
                answer_t=a_t_text,
                query_t1=q_t1.text,
                memories=memories_t,
                query_embedding_t=np.array(e_q_t),
                query_embedding_t1=np.array(e_q_t1),
            )
            samples.append(sample)
    return samples
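
# Usage sketch (illustrative only): assumes `sessions`, `embed_model`, `llm`,
# `reranker`, `memory_cards`, `memory_embeddings`, `user_store`, and
# `item_vectors` are constructed elsewhere; their constructors are
# project-specific and not shown here.
#
#     samples = build_turn_samples_from_sessions(
#         sessions=sessions,
#         embed_model=embed_model,
#         llm=llm,
#         reranker=reranker,
#         memory_cards=memory_cards,
#         memory_embeddings=memory_embeddings,
#         user_store=user_store,
#         item_vectors=item_vectors,
#         max_samples=10_000,
#     )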