author     chensyCN <chensy9605@gmail.com>  2026-01-10 13:32:19 +0800
committer  chensyCN <chensy9605@gmail.com>  2026-01-10 13:32:19 +0800
commit     85c330e0c0d5f560a20d950a203b368bb35ed45b (patch)
tree       7991a5197fdba019487d061af24554f043742ee8
parent     c3816630d1c145c33c6928cb3a4f248381aca96d (diff)
update vectorized retrieval
-rw-r--r--  src/LinearRAG.py  105
-rw-r--r--  src/ner.py          2
2 files changed, 66 insertions, 41 deletions
diff --git a/src/LinearRAG.py b/src/LinearRAG.py
index 76b65ad..d4fed46 100644
--- a/src/LinearRAG.py
+++ b/src/LinearRAG.py
@@ -353,30 +353,67 @@ class LinearRAG:
sentence_activation
)
- # Step 2: Apply sentence similarities (element-wise on dense vector)
- weighted_sentence_scores = sentence_activation * sentence_similarities
+ # Step 2: Per-entity top-k sentence selection
+ # This matches BFS behavior: each entity independently selects its top-k sentences
+ selected_sentence_indices_list = []
- # Implement per-entity top-k sentence selection (more aggressive pruning)
- # For vectorized efficiency, we use a tighter global approximation
- num_active = len(nonzero_indices)
- if num_active > 0 and self.config.top_k_sentence > 0:
- # Calculate adaptive k based on number of active entities
- # Use per-entity top-k approximation: num_active * top_k_sentence
- k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
- if k > 0:
- # Get top-k sentences
- top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
- # Zero out all non-top-k sentences
- mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
- mask[top_k_indices] = True
+ if len(nonzero_indices) > 0 and self.config.top_k_sentence > 0:
+ # Iterate through each active entity
+ for i, entity_idx in enumerate(nonzero_indices):
+ entity_score = nonzero_values[i]
+
+ # Get sentences connected to this entity from the sparse matrix
+ # entity_to_sentence_sparse shape: (num_entities, num_sentences)
+ entity_row = self.entity_to_sentence_sparse[entity_idx].coalesce()
+ entity_sentence_indices = entity_row.indices()[0] # Get column indices
+
+ if len(entity_sentence_indices) == 0:
+ continue
+
+ # Filter out already used sentences
+ sentence_mask = ~used_sentence_mask[entity_sentence_indices]
+ available_sentence_indices = entity_sentence_indices[sentence_mask]
+
+ if len(available_sentence_indices) == 0:
+ continue
+
+ # Get sentence similarities (for ranking)
+ sentence_sims = sentence_similarities[available_sentence_indices]
+
+ # Select top-k sentences based ONLY on sentence similarity (matches BFS line 240)
+ # NOT weighted by entity_score at selection time
+ k = min(self.config.top_k_sentence, len(sentence_sims))
+ if k > 0:
+ top_k_values, top_k_local_indices = torch.topk(sentence_sims, k)
+ top_k_sentence_indices = available_sentence_indices[top_k_local_indices]
+ selected_sentence_indices_list.append(top_k_sentence_indices)
+
+ # Merge all selected sentences (with deduplication via unique)
+ if len(selected_sentence_indices_list) > 0:
+ all_selected_sentences = torch.cat(selected_sentence_indices_list)
+ unique_selected_sentences = torch.unique(all_selected_sentences)
+
+ # Mark selected sentences as used
+ used_sentence_mask[unique_selected_sentences] = True
+
+ # Compute weighted sentence scores for propagation
+ # weighted_score = sentence_activation * sentence_similarity
+ weighted_sentence_scores = sentence_activation * sentence_similarities
+
+ # Zero out non-selected sentences
+ mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
+ mask[unique_selected_sentences] = True
weighted_sentence_scores = torch.where(
mask,
weighted_sentence_scores,
torch.zeros_like(weighted_sentence_scores)
)
-
- # Mark these sentences as used for deduplication
- used_sentence_mask[top_k_indices] = True
+ else:
+ # No sentences selected, create zero vector
+ weighted_sentence_scores = torch.zeros(num_sentences, dtype=torch.float32, device=self.device)
+ else:
+ # No active entities or top_k_sentence is 0
+ weighted_sentence_scores = torch.zeros(num_sentences, dtype=torch.float32, device=self.device)
# Step 3: Weighted sentences @ S2E -> propagate to next entities
# Convert to sparse for more efficient computation
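A hypothetical standalone sketch of the per-entity top-k selection this hunk introduces, assuming the data layout from the diff (a sparse entity-to-sentence COO matrix, a dense sentence-similarity vector, a boolean used-sentence mask, and a 1-D LongTensor of active entity indices); the helper name select_sentences is illustrative only:

    import torch

    def select_sentences(active_entity_indices, entity_to_sentence_sparse,
                         sentence_similarities, used_sentence_mask, top_k_sentence):
        selected = []
        for entity_idx in active_entity_indices.tolist():
            # Row of the sparse E2S matrix: sentences linked to this entity
            row = entity_to_sentence_sparse[entity_idx].coalesce()
            sent_idx = row.indices()[0]
            # Drop sentences consumed in earlier iterations
            sent_idx = sent_idx[~used_sentence_mask[sent_idx]]
            if len(sent_idx) == 0:
                continue
            # Rank by sentence similarity only and keep the k best per entity
            k = min(top_k_sentence, len(sent_idx))
            _, local = torch.topk(sentence_similarities[sent_idx], k)
            selected.append(sent_idx[local])
        if not selected:
            return torch.empty(0, dtype=torch.long)
        # Deduplicate sentences shared by several entities
        return torch.unique(torch.cat(selected))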
@@ -406,14 +443,15 @@ class LinearRAG:
# Update entity scores (accumulate in dense format)
entity_scores_dense += next_entity_scores_dense
- # Update actived_entities dictionary (only for entities above threshold)
+ # Update actived_entities dictionary (record last trigger like BFS)
+ # This matches BFS behavior: unconditionally update for entities above threshold
next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
for entity_idx in active_indices:
score = next_entity_scores_np[entity_idx]
entity_hash_id = self.entity_hash_ids[entity_idx]
- if entity_hash_id not in actived_entities or actived_entities[entity_hash_id][1] < score:
- actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
+ # Unconditionally update to record the last trigger (matches BFS line 252)
+ actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
# Prepare sparse tensor for next iteration
next_nonzero_mask = next_entity_scores_dense > 0
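A minimal, hypothetical contrast of the two dictionary update rules changed in this hunk (the (entity_idx, score, iteration) tuple layout follows the diff; keys and values are toy data):

    def keep_max(d, key, idx, score, iteration):
        # Previous rule: overwrite only when the new score is higher
        if key not in d or d[key][1] < score:
            d[key] = (idx, score, iteration)

    def keep_last(d, key, idx, score, iteration):
        # New rule: always overwrite, so the entry records the last iteration
        # that re-activated the entity (matches the BFS implementation)
        d[key] = (idx, score, iteration)

    a, b = {}, {}
    for it, s in [(0, 0.9), (1, 0.4)]:
        keep_max(a, "e1", 7, s, it)
        keep_last(b, "e1", 7, s, it)
    # a["e1"] == (7, 0.9, 0)   b["e1"] == (7, 0.4, 1)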
@@ -557,30 +595,17 @@ class LinearRAG:
]
def add_edges(self):
- existing_edges = set()
- for edge in self.graph.es:
- source_name = self.graph.vs[edge.source]["name"]
- target_name = self.graph.vs[edge.target]["name"]
- existing_edges.add(frozenset([source_name, target_name]))
-
- new_edges = []
- new_weights = []
+ edges = []
+ weights = []
for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
for neighbor_hash_id, weight in node_to_node_stats.items():
if node_hash_id == neighbor_hash_id:
continue
- edge_key = frozenset([node_hash_id, neighbor_hash_id])
- if edge_key not in existing_edges:
- new_edges.append((node_hash_id, neighbor_hash_id))
- new_weights.append(weight)
- existing_edges.add(edge_key)
-
- if new_edges:
- self.graph.add_edges(new_edges)
- start_idx = len(self.graph.es) - len(new_edges)
- for i, weight in enumerate(new_weights):
- self.graph.es[start_idx + i]['weight'] = weight
+ edges.append((node_hash_id, neighbor_hash_id))
+ weights.append(weight)
+ self.graph.add_edges(edges)
+ self.graph.es['weight'] = weights
def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
passage_to_entity_count ={}
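A toy sketch of the bulk edge insertion pattern the rewritten add_edges() relies on (vertex names, weights, and the node_to_node_stats contents here are illustrative only):

    import igraph as ig

    g = ig.Graph(directed=False)
    g.add_vertices(["e_a", "e_b", "e_c"])   # vertices named by hash id

    node_to_node_stats = {"e_a": {"e_b": 2.0, "e_c": 1.0}, "e_b": {"e_c": 3.0}}

    edges, weights = [], []
    for src, nbrs in node_to_node_stats.items():
        for dst, w in nbrs.items():
            if src == dst:
                continue                    # skip self-loops
            edges.append((src, dst))
            weights.append(w)

    # A single add_edges call is far cheaper than adding edges one by one;
    # assigning es['weight'] for all edges assumes the graph had no prior edges,
    # which is why the old existing-edge bookkeeping could be dropped.
    g.add_edges(edges)
    g.es["weight"] = weights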
diff --git a/src/ner.py b/src/ner.py
index 4ee788e..e758fde 100644
--- a/src/ner.py
+++ b/src/ner.py
@@ -27,7 +27,7 @@ class SpacyNER:
sentence_to_entities = defaultdict(list)
unique_entities = set()
passage_hash_id_to_entities = {}
- pdb.set_trace()
+ # pdb.set_trace()  # debugging breakpoint commented out
for ent in doc.ents:
if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
continue