From 85c330e0c0d5f560a20d950a203b368bb35ed45b Mon Sep 17 00:00:00 2001
From: chensyCN
Date: Sat, 10 Jan 2026 13:32:19 +0800
Subject: update vectorized retrieval

---
 src/LinearRAG.py | 105 ++++++++++++++++++++++++++++++++++++++++++++-------------------------
 src/ner.py       |   2 +-
 2 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/src/LinearRAG.py b/src/LinearRAG.py
index 76b65ad..d4fed46 100644
--- a/src/LinearRAG.py
+++ b/src/LinearRAG.py
@@ -353,30 +353,67 @@ class LinearRAG:
                 sentence_activation
             )
 
-            # Step 2: Apply sentence similarities (element-wise on dense vector)
-            weighted_sentence_scores = sentence_activation * sentence_similarities
+            # Step 2: Per-entity top-k sentence selection
+            # This matches BFS behavior: each entity independently selects its top-k sentences
+            selected_sentence_indices_list = []
 
-            # Implement per-entity top-k sentence selection (more aggressive pruning)
-            # For vectorized efficiency, we use a tighter global approximation
-            num_active = len(nonzero_indices)
-            if num_active > 0 and self.config.top_k_sentence > 0:
-                # Calculate adaptive k based on number of active entities
-                # Use per-entity top-k approximation: num_active * top_k_sentence
-                k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
-                if k > 0:
-                    # Get top-k sentences
-                    top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
-                    # Zero out all non-top-k sentences
-                    mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
-                    mask[top_k_indices] = True
+            if len(nonzero_indices) > 0 and self.config.top_k_sentence > 0:
+                # Iterate through each active entity
+                for i, entity_idx in enumerate(nonzero_indices):
+                    entity_score = nonzero_values[i]
+
+                    # Get sentences connected to this entity from the sparse matrix
+                    # entity_to_sentence_sparse shape: (num_entities, num_sentences)
+                    entity_row = self.entity_to_sentence_sparse[entity_idx].coalesce()
+                    entity_sentence_indices = entity_row.indices()[0]  # Get column indices
+
+                    if len(entity_sentence_indices) == 0:
+                        continue
+
+                    # Filter out already used sentences
+                    sentence_mask = ~used_sentence_mask[entity_sentence_indices]
+                    available_sentence_indices = entity_sentence_indices[sentence_mask]
+
+                    if len(available_sentence_indices) == 0:
+                        continue
+
+                    # Get sentence similarities (for ranking)
+                    sentence_sims = sentence_similarities[available_sentence_indices]
+
+                    # Select top-k sentences based ONLY on sentence similarity (matches BFS line 240)
+                    # NOT weighted by entity_score at selection time
+                    k = min(self.config.top_k_sentence, len(sentence_sims))
+                    if k > 0:
+                        top_k_values, top_k_local_indices = torch.topk(sentence_sims, k)
+                        top_k_sentence_indices = available_sentence_indices[top_k_local_indices]
+                        selected_sentence_indices_list.append(top_k_sentence_indices)
+
+                # Merge all selected sentences (with deduplication via unique)
+                if len(selected_sentence_indices_list) > 0:
+                    all_selected_sentences = torch.cat(selected_sentence_indices_list)
+                    unique_selected_sentences = torch.unique(all_selected_sentences)
+
+                    # Mark selected sentences as used
+                    used_sentence_mask[unique_selected_sentences] = True
+
+                    # Compute weighted sentence scores for propagation
+                    # weighted_score = sentence_activation * sentence_similarity
+                    weighted_sentence_scores = sentence_activation * sentence_similarities
+
+                    # Zero out non-selected sentences
+                    mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
+                    mask[unique_selected_sentences] = True
                     weighted_sentence_scores = torch.where(
                         mask,
                         weighted_sentence_scores,
                        torch.zeros_like(weighted_sentence_scores)
                    )
-
-                    # Mark these sentences as used for deduplication
-                    used_sentence_mask[top_k_indices] = True
+                else:
+                    # No sentences selected, create zero vector
+                    weighted_sentence_scores = torch.zeros(num_sentences, dtype=torch.float32, device=self.device)
+            else:
+                # No active entities or top_k_sentence is 0
+                weighted_sentence_scores = torch.zeros(num_sentences, dtype=torch.float32, device=self.device)
 
             # Step 3: Weighted sentences @ S2E -> propagate to next entities
             # Convert to sparse for more efficient computation
@@ -406,14 +443,15 @@ class LinearRAG:
             # Update entity scores (accumulate in dense format)
             entity_scores_dense += next_entity_scores_dense
 
-            # Update actived_entities dictionary (only for entities above threshold)
+            # Update actived_entities dictionary (record last trigger like BFS)
+            # This matches BFS behavior: unconditionally update for entities above threshold
             next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
             active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
             for entity_idx in active_indices:
                 score = next_entity_scores_np[entity_idx]
                 entity_hash_id = self.entity_hash_ids[entity_idx]
-                if entity_hash_id not in actived_entities or actived_entities[entity_hash_id][1] < score:
-                    actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
+                # Unconditionally update to record the last trigger (matches BFS line 252)
+                actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
 
             # Prepare sparse tensor for next iteration
             next_nonzero_mask = next_entity_scores_dense > 0
@@ -557,30 +595,17 @@ class LinearRAG:
         ]
 
     def add_edges(self):
-        existing_edges = set()
-        for edge in self.graph.es:
-            source_name = self.graph.vs[edge.source]["name"]
-            target_name = self.graph.vs[edge.target]["name"]
-            existing_edges.add(frozenset([source_name, target_name]))
-
-        new_edges = []
-        new_weights = []
+        edges = []
+        weights = []
         for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
             for neighbor_hash_id, weight in node_to_node_stats.items():
                 if node_hash_id == neighbor_hash_id:
                     continue
-                edge_key = frozenset([node_hash_id, neighbor_hash_id])
-                if edge_key not in existing_edges:
-                    new_edges.append((node_hash_id, neighbor_hash_id))
-                    new_weights.append(weight)
-                    existing_edges.add(edge_key)
-
-        if new_edges:
-            self.graph.add_edges(new_edges)
-            start_idx = len(self.graph.es) - len(new_edges)
-            for i, weight in enumerate(new_weights):
-                self.graph.es[start_idx + i]['weight'] = weight
-
+                edges.append((node_hash_id, neighbor_hash_id))
+                weights.append(weight)
+
+        self.graph.add_edges(edges)
+        self.graph.es['weight'] = weights
 
     def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
         passage_to_entity_count ={}

diff --git a/src/ner.py b/src/ner.py
index 4ee788e..e758fde 100644
--- a/src/ner.py
+++ b/src/ner.py
@@ -27,7 +27,7 @@ class SpacyNER:
         sentence_to_entities = defaultdict(list)
         unique_entities = set()
         passage_hash_id_to_entities = {}
-        pdb.set_trace()
+        # pdb.set_trace()  # debugging breakpoint commented out
         for ent in doc.ents:
             if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
                 continue
--
cgit v1.2.3
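For reference, a minimal standalone sketch (not part of the patch) of the per-entity top-k selection introduced in the first hunk. The toy tensors entity_to_sentence, sentence_similarities, used_sentence_mask, active_entities, and top_k_sentence are hypothetical stand-ins for the corresponding LinearRAG fields and config values; the logic mirrors the added code, where each active entity keeps its own top-k sentences ranked by similarity alone before the union is marked as used.

import torch

# Hypothetical toy inputs standing in for LinearRAG state.
entity_to_sentence = torch.tensor([      # (num_entities, num_sentences) incidence
    [1., 0., 1., 1., 0.],
    [0., 1., 1., 0., 1.],
]).to_sparse()
sentence_similarities = torch.tensor([0.9, 0.2, 0.7, 0.4, 0.8])
used_sentence_mask = torch.zeros(5, dtype=torch.bool)
active_entities = [0, 1]                 # entity indices with nonzero activation
top_k_sentence = 2

selected = []
for entity_idx in active_entities:
    # Sentences linked to this entity, minus ones already consumed.
    row = entity_to_sentence[entity_idx].coalesce()
    cand = row.indices()[0]
    cand = cand[~used_sentence_mask[cand]]
    if len(cand) == 0:
        continue
    # Each entity independently keeps its top-k sentences by similarity alone.
    k = min(top_k_sentence, len(cand))
    _, local = torch.topk(sentence_similarities[cand], k)
    selected.append(cand[local])

if selected:
    chosen = torch.unique(torch.cat(selected))
    used_sentence_mask[chosen] = True
    print(chosen)  # tensor([0, 2, 4]) for the toy inputs above

This differs from the replaced code, which took a single global top-(num_active * top_k_sentence) over all weighted scores: an entity with only weakly similar sentences still contributes its own top-k here, instead of being crowded out by sentences of other entities.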
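Likewise, a small hypothetical example of the bulk edge insertion used by the rewritten add_edges. python-igraph accepts vertex-name pairs in Graph.add_edges when vertices carry a "name" attribute, and edge attributes can be assigned in one shot through es["weight"]; the stats dictionary below is an invented stand-in for node_to_node_stats.

import igraph as ig

g = ig.Graph()
g.add_vertices(["e1", "e2", "e3"])       # sets the vertex "name" attribute

# Hypothetical co-occurrence stats, analogous to node_to_node_stats.
stats = {"e1": {"e2": 2.0, "e3": 1.0}, "e2": {"e3": 3.0}}

edges, weights = [], []
for src, neighbors in stats.items():
    for dst, w in neighbors.items():
        if src == dst:
            continue
        edges.append((src, dst))
        weights.append(w)

# One vectorized add_edges call plus a bulk attribute assignment,
# instead of per-edge inserts with frozenset-based deduplication.
g.add_edges(edges)
g.es["weight"] = weights
print(g.ecount(), g.es["weight"])        # 3 [2.0, 1.0, 3.0]

Assigning es["weight"] wholesale is only valid when the graph holds no pre-existing edges at this point, which is the assumption the simplified add_edges in the patch appears to make by dropping the old existing_edges check.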