From 7a4ad471c9026c7882504b1c8b730045b4bb74af Mon Sep 17 00:00:00 2001 From: CHEN SHENGYUAN Date: Thu, 18 Dec 2025 15:35:33 +0800 Subject: enable vectorized retrieval with sparse matrix operations --- src/LinearRAG.py | 274 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 268 insertions(+), 6 deletions(-) (limited to 'src/LinearRAG.py') diff --git a/src/LinearRAG.py b/src/LinearRAG.py index 0f4f6f8..9435883 100644 --- a/src/LinearRAG.py +++ b/src/LinearRAG.py @@ -11,11 +11,22 @@ from src.ner import SpacyNER import igraph as ig import re import logging +import torch logger = logging.getLogger(__name__) + + class LinearRAG: def __init__(self, global_config): self.config = global_config logger.info(f"Initializing LinearRAG with config: {self.config}") + retrieval_method = "Vectorized Matrix-based" if self.config.use_vectorized_retrieval else "BFS Iteration" + logger.info(f"Using retrieval method: {retrieval_method}") + + # Setup device for GPU acceleration + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if self.config.use_vectorized_retrieval: + logger.info(f"Using device: {self.device} for vectorized retrieval") + self.dataset_name = global_config.dataset_name self.load_embedding_store() self.llm_model = self.config.llm_model @@ -70,7 +81,6 @@ class LinearRAG: question_info["pred_answer"] = pred_ans return retrieval_results - def retrieve(self, questions): self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys()) self.entity_embeddings = np.array(self.entity_embedding_store.embeddings) @@ -81,6 +91,19 @@ class LinearRAG: self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()} self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()} + # Precompute sparse matrices for vectorized retrieval if needed + if self.config.use_vectorized_retrieval: + logger.info("Precomputing sparse adjacency matrices for 
def _precompute_sparse_matrices(self):
    """Build and cache the entity/sentence adjacency matrices as sparse tensors.

    Reads ``self.entity_hash_id_to_sentence_hash_ids`` and
    ``self.sentence_hash_id_to_entity_hash_ids`` and translates hash ids to
    row/column positions through the ``hash_id_to_idx`` mapping of the
    entity and sentence embedding stores.

    Sets:
        self.entity_to_sentence_sparse: coalesced COO tensor of shape
            ``(num_entities, num_sentences)`` on ``self.device``.
        self.sentence_to_entity_sparse: coalesced COO tensor of shape
            ``(num_sentences, num_entities)`` on ``self.device``.

    Every listed (row, column) pair contributes 1.0; ``coalesce()`` sums
    repeated pairs. Both tensors stay in COO format (no CSR conversion is
    performed) and are always coalesced, including when there are no edges.

    Raises:
        KeyError: if a hash id in either mapping is unknown to the
            corresponding embedding store.
    """
    num_entities = len(self.entity_hash_ids)
    num_sentences = len(self.sentence_hash_ids)
    entity_index = self.entity_embedding_store.hash_id_to_idx
    sentence_index = self.sentence_embedding_store.hash_id_to_idx

    def build_adjacency(links, row_index, col_index, shape):
        # Collect COO coordinates for one direction of the bipartite graph.
        rows, cols = [], []
        for row_hash_id, col_hash_ids in links.items():
            row = row_index[row_hash_id]
            for col_hash_id in col_hash_ids:
                rows.append(row)
                cols.append(col_index[col_hash_id])
        # With no edges this yields a (2, 0) index tensor, so the empty case
        # goes through exactly the same path as the populated one.
        indices = torch.tensor([rows, cols], dtype=torch.long)
        values = torch.ones(len(rows), dtype=torch.float32)
        return torch.sparse_coo_tensor(
            indices, values, shape, device=self.device
        ).coalesce()

    # Entity -> sentence ("mention") matrix.
    self.entity_to_sentence_sparse = build_adjacency(
        self.entity_hash_id_to_sentence_hash_ids,
        entity_index,
        sentence_index,
        (num_entities, num_sentences),
    )
    # Sentence -> entity matrix.
    self.sentence_to_entity_sparse = build_adjacency(
        self.sentence_hash_id_to_entity_hash_ids,
        sentence_index,
        entity_index,
        (num_sentences, num_entities),
    )
def calculate_entity_scores_vectorized(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
    """Propagate seed entity scores through the entity/sentence graph with matrix ops.

    Each iteration performs one hop: active entities -> sentences (via
    ``self.entity_to_sentence_sparse``), weighting by question/sentence
    similarity, then sentences -> entities (via
    ``self.sentence_to_entity_sparse``). Requires
    ``_precompute_sparse_matrices`` to have populated both matrices.

    Args:
        question_embedding: NumPy array; a 1-D vector is reshaped to a
            column before being multiplied with ``self.sentence_embeddings``.
        seed_entity_indices: positions of the seed entities in
            ``self.entity_hash_ids``.
        seed_entities: seed entity texts (only used to keep the seed
            sequences aligned).
        seed_entity_hash_ids: hash ids of the seed entities; each must be a
            key of ``self.node_name_to_vertex_idx``.
        seed_entity_scores: initial score of each seed entity.

    Returns:
        tuple: ``(entity_weights, actived_entities)`` where
        ``entity_weights`` is a NumPy array with one slot per graph vertex
        holding the accumulated entity scores, and ``actived_entities``
        maps entity hash id -> ``(entity_idx, score, iteration)``; seeds
        are recorded with iteration 0.

    Pruning rules:
        * An entity only propagates when its current score is positive and
          at least ``config.iteration_threshold``.
        * A sentence propagates at most once per call.
        * Sentence selection is a *global* top-k with
          ``k = num_active_entities * config.top_k_sentence``, an
          approximation of per-entity top-k. Only sentences actually
          activated in the current hop are eligible, so untouched
          sentences are never consumed by the deduplication mask.
        * Only strictly positive weighted sentence scores are propagated.
    """
    device = self.device
    threshold = self.config.iteration_threshold
    top_k_sentence = self.config.top_k_sentence
    num_entities = len(self.entity_hash_ids)
    num_sentences = len(self.sentence_hash_ids)
    entity_weights = np.zeros(len(self.graph.vs["name"]))

    # Similarity of every sentence to the question, computed once.
    if len(question_embedding.shape) == 1:
        question_emb = question_embedding.reshape(-1, 1)
    else:
        question_emb = question_embedding
    similarities_np = np.dot(self.sentence_embeddings, question_emb).flatten()
    sentence_similarities = torch.from_numpy(similarities_np).float().to(device)

    # Sentences that already propagated once are masked out in later hops.
    used_sentence_mask = torch.zeros(num_sentences, dtype=torch.bool, device=device)

    # Seed bookkeeping. A repeated seed index keeps its last score in every
    # structure (dict, graph weights and score vector), so they always agree.
    actived_entities = {}
    seed_vector = np.zeros(num_entities, dtype=np.float32)
    for seed_idx, _, seed_hash_id, seed_score in zip(
        seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores
    ):
        actived_entities[seed_hash_id] = (seed_idx, seed_score, 0)
        entity_weights[self.node_name_to_vertex_idx[seed_hash_id]] = seed_score
        seed_vector[seed_idx] = seed_score

    # Running total of all scores an entity has received (seeds included).
    entity_scores_dense = torch.from_numpy(seed_vector).to(device)
    # Scores produced by the previous hop; clone so the in-place
    # accumulation below never aliases the frontier.
    current_scores = entity_scores_dense.clone()

    # Loop-invariant transposes: (sentences x entities) and (entities x sentences).
    entity_to_sentence_t = self.entity_to_sentence_sparse.t()
    sentence_to_entity_t = self.sentence_to_entity_sparse.t()

    for iteration in range(1, self.config.max_iterations):
        # Only positive scores at or above the threshold keep propagating.
        active_mask = (current_scores >= threshold) & (current_scores > 0)
        num_active = int(active_mask.sum())
        if num_active == 0:
            break
        frontier = torch.where(active_mask, current_scores, torch.zeros_like(current_scores))

        # Hop 1: entities -> sentences. squeeze(1) keeps the result 1-D even
        # when there is a single sentence.
        sentence_activation = torch.sparse.mm(
            entity_to_sentence_t, frontier.unsqueeze(1)
        ).squeeze(1)
        sentence_activation = sentence_activation.masked_fill(used_sentence_mask, 0.0)

        weighted_sentence_scores = sentence_activation * sentence_similarities

        if top_k_sentence > 0:
            budget = int(num_active * top_k_sentence)
            if budget > 0:
                # Rank only sentences reached in this hop; everything else
                # is pushed to -inf so it can neither be selected nor be
                # marked as used.
                candidate_mask = sentence_activation > 0
                k = min(budget, int(candidate_mask.sum()))
                if k > 0:
                    ranked = weighted_sentence_scores.masked_fill(
                        ~candidate_mask, float("-inf")
                    )
                    _, top_k_indices = torch.topk(ranked, k)
                    selected = torch.zeros_like(used_sentence_mask)
                    selected[top_k_indices] = True
                    weighted_sentence_scores = torch.where(
                        selected,
                        weighted_sentence_scores,
                        torch.zeros_like(weighted_sentence_scores),
                    )
                    used_sentence_mask |= selected

        # Hop 2: sentences -> entities, propagating positive scores only.
        positive_mask = weighted_sentence_scores > 0
        if bool(positive_mask.any()):
            propagated = torch.where(
                positive_mask,
                weighted_sentence_scores,
                torch.zeros_like(weighted_sentence_scores),
            )
            next_entity_scores = torch.sparse.mm(
                sentence_to_entity_t, propagated.unsqueeze(1)
            ).squeeze(1)
        else:
            next_entity_scores = torch.zeros(num_entities, dtype=torch.float32, device=device)

        entity_scores_dense += next_entity_scores

        # Record entities whose score from this hop clears the threshold,
        # keeping the best single-hop score seen for each entity.
        next_entity_scores_np = next_entity_scores.cpu().numpy()
        for entity_idx in np.flatnonzero(next_entity_scores_np >= threshold):
            score = float(next_entity_scores_np[entity_idx])
            entity_hash_id = self.entity_hash_ids[entity_idx]
            previous = actived_entities.get(entity_hash_id)
            if previous is None or previous[1] < score:
                actived_entities[entity_hash_id] = (int(entity_idx), score, iteration)

        if not bool((next_entity_scores > 0).any()):
            break
        current_scores = next_entity_scores

    # Write accumulated scores onto the graph vertices of scored entities.
    entity_scores_final = entity_scores_dense.cpu().numpy()
    for entity_idx in np.flatnonzero(entity_scores_final > 0):
        entity_hash_id = self.entity_hash_ids[entity_idx]
        entity_node_idx = self.node_name_to_vertex_idx[entity_hash_id]
        entity_weights[entity_node_idx] = float(entity_scores_final[entity_idx])

    return entity_weights, actived_entities
@@ -305,8 +569,6 @@ class LinearRAG: self.graph.add_edges(edges) self.graph.es['weight'] = weights - - def add_entity_to_passage_edges(self, passage_hash_id_to_entities): passage_to_entity_count ={} passage_to_all_score = defaultdict(int) -- cgit v1.2.3