-rw-r--r--	LinearRAG_upload/src/LinearRAG.py	623
1 file changed, 623 insertions(+), 0 deletions(-)
diff --git a/LinearRAG_upload/src/LinearRAG.py b/LinearRAG_upload/src/LinearRAG.py
new file mode 100644
index 0000000..76b65ad
--- /dev/null
+++ b/LinearRAG_upload/src/LinearRAG.py
@@ -0,0 +1,623 @@
+import os
+import json
+import math
+import re
+import logging
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+
+import numpy as np
+import torch
+import igraph as ig
+from tqdm import tqdm
+
+from src.embedding_store import EmbeddingStore
+from src.ner import SpacyNER
+from src.utils import min_max_normalize
+logger = logging.getLogger(__name__)
+
+
+class LinearRAG:
+ def __init__(self, global_config):
+ self.config = global_config
+ logger.info(f"Initializing LinearRAG with config: {self.config}")
+ retrieval_method = "Vectorized Matrix-based" if self.config.use_vectorized_retrieval else "BFS Iteration"
+ logger.info(f"Using retrieval method: {retrieval_method}")
+
+ # Setup device for GPU acceleration
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ if self.config.use_vectorized_retrieval:
+ logger.info(f"Using device: {self.device} for vectorized retrieval")
+
+ self.dataset_name = global_config.dataset_name
+ self.load_embedding_store()
+ self.llm_model = self.config.llm_model
+ self.spacy_ner = SpacyNER(self.config.spacy_model)
+ self.graph = ig.Graph(directed=False)
+
+ def load_embedding_store(self):
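+        """Create parquet-backed embedding stores for passages, entities, and sentences under the working directory."""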
+        self.passage_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "passage_embedding.parquet"), batch_size=self.config.batch_size, namespace="passage")
+        self.entity_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "entity_embedding.parquet"), batch_size=self.config.batch_size, namespace="entity")
+        self.sentence_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "sentence_embedding.parquet"), batch_size=self.config.batch_size, namespace="sentence")
+
+ def load_existing_data(self,passage_hash_ids):
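+        """Load cached NER results if present; return the cached mappings plus the set of passage hash ids still needing NER."""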
+        self.ner_results_path = os.path.join(self.config.working_dir, self.dataset_name, "ner_results.json")
+        if os.path.exists(self.ner_results_path):
+            existing_ner_results = json.load(open(self.ner_results_path))
+            existing_passage_hash_id_to_entities = existing_ner_results["passage_hash_id_to_entities"]
+            existing_sentence_to_entities = existing_ner_results["sentence_to_entities"]
+            existing_passage_hash_ids = set(existing_passage_hash_id_to_entities.keys())
+            new_passage_hash_ids = set(passage_hash_ids) - existing_passage_hash_ids
+            return existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids
+        else:
+            return {}, {}, set(passage_hash_ids)
+
+ def qa(self, questions):
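+        """Retrieve supporting passages per question, then answer with the LLM in parallel.
+
+        The prompt asks the model to reason after "Thought: " and finish with "Answer: ";
+        the text after "Answer:" is parsed out as the predicted answer.
+        """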
+ retrieval_results = self.retrieve(questions)
+        system_prompt = """As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations."""
+ all_messages = []
+ for retrieval_result in retrieval_results:
+ question = retrieval_result["question"]
+ sorted_passage = retrieval_result["sorted_passage"]
+            prompt_user = ""
+ for passage in sorted_passage:
+ prompt_user += f"{passage}\n"
+ prompt_user += f"Question: {question}\n Thought: "
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": prompt_user}
+ ]
+ all_messages.append(messages)
+ with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
+ all_qa_results = list(tqdm(
+ executor.map(self.llm_model.infer, all_messages),
+ total=len(all_messages),
+ desc="QA Reading (Parallel)"
+ ))
+
+        for qa_result, question_info in zip(all_qa_results, retrieval_results):
+            try:
+                pred_ans = qa_result.split('Answer:')[1].strip()
+            except IndexError:
+                pred_ans = qa_result
+ question_info["pred_answer"] = pred_ans
+ return retrieval_results
+
+ def retrieve(self, questions):
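+        """Retrieve the top-k passages for each question.
+
+        If seed entities can be linked from the question, passages are ranked by graph
+        search (entity propagation + Personalized PageRank); otherwise retrieval falls
+        back to pure dense passage retrieval.
+        """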
+ self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
+ self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
+ self.passage_hash_ids = list(self.passage_embedding_store.hash_id_to_text.keys())
+ self.passage_embeddings = np.array(self.passage_embedding_store.embeddings)
+ self.sentence_hash_ids = list(self.sentence_embedding_store.hash_id_to_text.keys())
+ self.sentence_embeddings = np.array(self.sentence_embedding_store.embeddings)
+ self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+ self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
+
+ # Precompute sparse matrices for vectorized retrieval if needed
+ if self.config.use_vectorized_retrieval:
+ logger.info("Precomputing sparse adjacency matrices for vectorized retrieval...")
+ self._precompute_sparse_matrices()
+ e2s_shape = self.entity_to_sentence_sparse.shape
+ s2e_shape = self.sentence_to_entity_sparse.shape
+ e2s_nnz = self.entity_to_sentence_sparse._nnz()
+ s2e_nnz = self.sentence_to_entity_sparse._nnz()
+ logger.info(f"Matrices built: Entity-Sentence {e2s_shape}, Sentence-Entity {s2e_shape}")
+ logger.info(f"E2S Sparsity: {(1 - e2s_nnz / (e2s_shape[0] * e2s_shape[1])) * 100:.2f}% (nnz={e2s_nnz})")
+ logger.info(f"S2E Sparsity: {(1 - s2e_nnz / (s2e_shape[0] * s2e_shape[1])) * 100:.2f}% (nnz={s2e_nnz})")
+ logger.info(f"Device: {self.device}")
+
+ retrieval_results = []
+ for question_info in tqdm(questions, desc="Retrieving"):
+ question = question_info["question"]
+            question_embedding = self.config.embedding_model.encode(question, normalize_embeddings=True, show_progress_bar=False, batch_size=self.config.batch_size)
+            seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores = self.get_seed_entities(question)
+            if len(seed_entities) != 0:
+                sorted_passage_hash_ids, sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores)
+                final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
+                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+                final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
+            else:
+                sorted_passage_indices, sorted_passage_scores = self.dense_passage_retrieval(question_embedding)
+                final_passage_indices = sorted_passage_indices[:self.config.retrieval_top_k]
+                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+                final_passages = [self.passage_embedding_store.texts[idx] for idx in final_passage_indices]
+ result = {
+ "question": question,
+ "sorted_passage": final_passages,
+ "sorted_passage_scores": final_passage_scores,
+ "gold_answer": question_info["answer"]
+ }
+ retrieval_results.append(result)
+ return retrieval_results
+
+ def _precompute_sparse_matrices(self):
+ """
+ Precompute and cache sparse adjacency matrices for efficient vectorized retrieval using PyTorch.
+ This is called once at the beginning of retrieve() to avoid rebuilding matrices per query.
+ """
+ num_entities = len(self.entity_hash_ids)
+ num_sentences = len(self.sentence_hash_ids)
+
+ # Build entity-to-sentence matrix (Mention matrix) using COO format
+ entity_to_sentence_indices = []
+ entity_to_sentence_values = []
+
+ for entity_hash_id, sentence_hash_ids in self.entity_hash_id_to_sentence_hash_ids.items():
+ entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
+ for sentence_hash_id in sentence_hash_ids:
+ sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
+ entity_to_sentence_indices.append([entity_idx, sentence_idx])
+ entity_to_sentence_values.append(1.0)
+
+ # Build sentence-to-entity matrix
+ sentence_to_entity_indices = []
+ sentence_to_entity_values = []
+
+ for sentence_hash_id, entity_hash_ids in self.sentence_hash_id_to_entity_hash_ids.items():
+ sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
+ for entity_hash_id in entity_hash_ids:
+ entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
+ sentence_to_entity_indices.append([sentence_idx, entity_idx])
+ sentence_to_entity_values.append(1.0)
+
+ # Convert to PyTorch sparse tensors (COO format, then convert to CSR for efficiency)
+ if len(entity_to_sentence_indices) > 0:
+ e2s_indices = torch.tensor(entity_to_sentence_indices, dtype=torch.long).t()
+ e2s_values = torch.tensor(entity_to_sentence_values, dtype=torch.float32)
+ self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
+ e2s_indices, e2s_values, (num_entities, num_sentences), device=self.device
+ ).coalesce()
+ else:
+ self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
+ torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
+ (num_entities, num_sentences), device=self.device
+ )
+
+ if len(sentence_to_entity_indices) > 0:
+ s2e_indices = torch.tensor(sentence_to_entity_indices, dtype=torch.long).t()
+ s2e_values = torch.tensor(sentence_to_entity_values, dtype=torch.float32)
+ self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
+ s2e_indices, s2e_values, (num_sentences, num_entities), device=self.device
+ ).coalesce()
+ else:
+ self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
+ torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
+ (num_sentences, num_entities), device=self.device
+ )
+
+ def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
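+        """Combine propagated entity weights and passage weights into one reset-probability vector and rank passages via PPR."""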
+        if self.config.use_vectorized_retrieval:
+            entity_weights, activated_entities = self.calculate_entity_scores_vectorized(question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores)
+        else:
+            entity_weights, activated_entities = self.calculate_entity_scores(question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores)
+        passage_weights = self.calculate_passage_scores(question_embedding, activated_entities)
+        node_weights = entity_weights + passage_weights
+        ppr_sorted_passage_hash_ids, ppr_sorted_passage_scores = self.run_ppr(node_weights)
+        return ppr_sorted_passage_hash_ids, ppr_sorted_passage_scores
+
+ def run_ppr(self, node_weights):
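+        """Run Personalized PageRank with `node_weights` as reset probabilities; return passage hash ids and scores, best first."""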
+ reset_prob = np.where(np.isnan(node_weights) | (node_weights < 0), 0, node_weights)
+ pagerank_scores = self.graph.personalized_pagerank(
+ vertices=range(len(self.node_name_to_vertex_idx)),
+ damping=self.config.damping,
+ directed=False,
+ weights='weight',
+ reset=reset_prob,
+ implementation='prpack'
+ )
+
+ doc_scores = np.array([pagerank_scores[idx] for idx in self.passage_node_indices])
+ sorted_indices_in_doc_scores = np.argsort(doc_scores)[::-1]
+ sorted_passage_scores = doc_scores[sorted_indices_in_doc_scores]
+
+ sorted_passage_hash_ids = [
+ self.vertex_idx_to_node_name[self.passage_node_indices[i]]
+ for i in sorted_indices_in_doc_scores
+ ]
+
+ return sorted_passage_hash_ids, sorted_passage_scores.tolist()
+
+ def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
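+        """BFS-style propagation of relevance scores from the seed entities.
+
+        Each active entity activates its top-k most question-similar unused sentences;
+        every entity in such a sentence receives entity_score * sentence_similarity.
+        Scores below `iteration_threshold` are pruned and each sentence is used once.
+        """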
+        activated_entities = {}
+        entity_weights = np.zeros(len(self.graph.vs["name"]))
+        for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+            activated_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
+            seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
+            entity_weights[seed_entity_node_idx] = seed_entity_score
+        used_sentence_hash_ids = set()
+        current_entities = activated_entities.copy()
+        iteration = 1
+        while len(current_entities) > 0 and iteration < self.config.max_iterations:
+            new_entities = {}
+            for entity_hash_id, (entity_id, entity_score, tier) in current_entities.items():
+                if entity_score < self.config.iteration_threshold:
+                    continue
+                sentence_hash_ids = [sid for sid in self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] if sid not in used_sentence_hash_ids]
+                if not sentence_hash_ids:
+                    continue
+                sentence_indices = [self.sentence_embedding_store.hash_id_to_idx[sid] for sid in sentence_hash_ids]
+                sentence_embeddings = self.sentence_embeddings[sentence_indices]
+                # reshape(-1, 1) handles both 1-D (d,) and 2-D (1, d) question embeddings
+                question_emb = question_embedding.reshape(-1, 1)
+                sentence_similarities = np.dot(sentence_embeddings, question_emb).flatten()
+                top_sentence_indices = np.argsort(sentence_similarities)[::-1][:self.config.top_k_sentence]
+                for top_sentence_index in top_sentence_indices:
+                    top_sentence_hash_id = sentence_hash_ids[top_sentence_index]
+                    top_sentence_score = sentence_similarities[top_sentence_index]
+                    used_sentence_hash_ids.add(top_sentence_hash_id)
+                    entity_hash_ids_in_sentence = self.sentence_hash_id_to_entity_hash_ids[top_sentence_hash_id]
+                    for next_entity_hash_id in entity_hash_ids_in_sentence:
+                        next_entity_score = entity_score * top_sentence_score
+                        if next_entity_score < self.config.iteration_threshold:
+                            continue
+                        next_entity_node_idx = self.node_name_to_vertex_idx[next_entity_hash_id]
+                        entity_weights[next_entity_node_idx] += next_entity_score
+                        new_entities[next_entity_hash_id] = (next_entity_node_idx, next_entity_score, iteration + 1)
+            activated_entities.update(new_entities)
+            current_entities = new_entities.copy()
+            iteration += 1
+        return entity_weights, activated_entities
+
+    def calculate_entity_scores_vectorized(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+        """
+        GPU-accelerated vectorized version of calculate_entity_scores using PyTorch sparse tensors.
+        Uses sparse representations for both the adjacency matrices and the entity score vectors.
+        Includes dynamic pruning that approximates the BFS behavior:
+        - Sentence deduplication (tracks used sentences)
+        - Top-k sentence selection (a global approximation of the per-entity top-k)
+        - Threshold-based pruning
+        """
+ # Initialize entity weights
+ entity_weights = np.zeros(len(self.graph.vs["name"]))
+ num_entities = len(self.entity_hash_ids)
+ num_sentences = len(self.sentence_hash_ids)
+
+ # Compute all sentence similarities with the question at once
+        # reshape(-1, 1) handles both 1-D (d,) and 2-D (1, d) question embeddings
+        question_emb = question_embedding.reshape(-1, 1)
+        sentence_similarities_np = np.dot(self.sentence_embeddings, question_emb).flatten()
+
+ # Convert to torch tensors and move to device
+ sentence_similarities = torch.from_numpy(sentence_similarities_np).float().to(self.device)
+
+ # Track used sentences for deduplication (like BFS version)
+ used_sentence_mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
+
+ # Initialize seed entity scores as sparse tensor
+ seed_indices = torch.tensor([[idx] for idx in seed_entity_indices], dtype=torch.long).t()
+ seed_values = torch.tensor(seed_entity_scores, dtype=torch.float32)
+ entity_scores_sparse = torch.sparse_coo_tensor(
+ seed_indices, seed_values, (num_entities,), device=self.device
+ ).coalesce()
+
+ # Also maintain a dense accumulator for total scores
+ entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
+ entity_scores_dense.scatter_(0, torch.tensor(seed_entity_indices, device=self.device),
+ torch.tensor(seed_entity_scores, dtype=torch.float32, device=self.device))
+
+        # Initialize activated_entities with the seed entities
+        activated_entities = {}
+        for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(
+            seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores
+        ):
+            activated_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 0)
+            seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
+            entity_weights[seed_entity_node_idx] = seed_entity_score
+
+ current_entity_scores_sparse = entity_scores_sparse
+
+ # Iterative matrix-based propagation using sparse matrices on GPU
+ for iteration in range(1, self.config.max_iterations):
+ # Convert sparse tensor to dense for threshold operation
+ current_entity_scores_dense = current_entity_scores_sparse.to_dense()
+
+ # Apply threshold to current scores
+ current_entity_scores_dense = torch.where(
+ current_entity_scores_dense >= self.config.iteration_threshold,
+ current_entity_scores_dense,
+ torch.zeros_like(current_entity_scores_dense)
+ )
+
+ # Get non-zero indices for sparse representation
+ nonzero_mask = current_entity_scores_dense > 0
+ nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).squeeze(-1)
+
+ if len(nonzero_indices) == 0:
+ break
+
+ # Extract non-zero values and create sparse tensor
+ nonzero_values = current_entity_scores_dense[nonzero_indices]
+ current_entity_scores_sparse = torch.sparse_coo_tensor(
+ nonzero_indices.unsqueeze(0), nonzero_values, (num_entities,), device=self.device
+ ).coalesce()
+
+ # Step 1: Sparse entity scores @ Sparse E2S matrix
+ # Convert sparse vector to 2D for matrix multiplication
+ current_scores_2d = torch.sparse_coo_tensor(
+ torch.stack([nonzero_indices, torch.zeros_like(nonzero_indices)]),
+ nonzero_values,
+ (num_entities, 1),
+ device=self.device
+ ).coalesce()
+
+ # E @ E2S -> sentence activation scores (sparse @ sparse = dense)
+ sentence_activation = torch.sparse.mm(
+ self.entity_to_sentence_sparse.t(),
+ current_scores_2d
+ )
+ # Convert to dense before squeeze to avoid CUDA sparse tensor issues
+ if sentence_activation.is_sparse:
+ sentence_activation = sentence_activation.to_dense()
+ sentence_activation = sentence_activation.squeeze()
+
+ # Apply sentence deduplication: mask out used sentences
+ sentence_activation = torch.where(
+ used_sentence_mask,
+ torch.zeros_like(sentence_activation),
+ sentence_activation
+ )
+
+ # Step 2: Apply sentence similarities (element-wise on dense vector)
+ weighted_sentence_scores = sentence_activation * sentence_similarities
+
+            # Approximate the BFS version's per-entity top-k sentence selection with a single
+            # global top-k of size num_active * top_k_sentence (cheaper to vectorize)
+            num_active = len(nonzero_indices)
+            if num_active > 0 and self.config.top_k_sentence > 0:
+ k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
+ if k > 0:
+ # Get top-k sentences
+ top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
+ # Zero out all non-top-k sentences
+ mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
+ mask[top_k_indices] = True
+ weighted_sentence_scores = torch.where(
+ mask,
+ weighted_sentence_scores,
+ torch.zeros_like(weighted_sentence_scores)
+ )
+
+ # Mark these sentences as used for deduplication
+ used_sentence_mask[top_k_indices] = True
+
+ # Step 3: Weighted sentences @ S2E -> propagate to next entities
+ # Convert to sparse for more efficient computation
+ weighted_nonzero_mask = weighted_sentence_scores > 0
+ weighted_nonzero_indices = torch.nonzero(weighted_nonzero_mask, as_tuple=False).squeeze(-1)
+
+ if len(weighted_nonzero_indices) > 0:
+ weighted_nonzero_values = weighted_sentence_scores[weighted_nonzero_indices]
+ weighted_scores_2d = torch.sparse_coo_tensor(
+ torch.stack([weighted_nonzero_indices, torch.zeros_like(weighted_nonzero_indices)]),
+ weighted_nonzero_values,
+ (num_sentences, 1),
+ device=self.device
+ ).coalesce()
+
+ next_entity_scores_result = torch.sparse.mm(
+ self.sentence_to_entity_sparse.t(),
+ weighted_scores_2d
+ )
+ # Convert to dense before squeeze to avoid CUDA sparse tensor issues
+ if next_entity_scores_result.is_sparse:
+ next_entity_scores_result = next_entity_scores_result.to_dense()
+ next_entity_scores_dense = next_entity_scores_result.squeeze()
+ else:
+ next_entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
+
+ # Update entity scores (accumulate in dense format)
+ entity_scores_dense += next_entity_scores_dense
+
+            # Update activated_entities dictionary (only for entities above threshold)
+            next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
+            active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
+            for entity_idx in active_indices:
+                score = next_entity_scores_np[entity_idx]
+                entity_hash_id = self.entity_hash_ids[entity_idx]
+                if entity_hash_id not in activated_entities or activated_entities[entity_hash_id][1] < score:
+                    activated_entities[entity_hash_id] = (entity_idx, float(score), iteration)
+
+ # Prepare sparse tensor for next iteration
+ next_nonzero_mask = next_entity_scores_dense > 0
+ next_nonzero_indices = torch.nonzero(next_nonzero_mask, as_tuple=False).squeeze(-1)
+ if len(next_nonzero_indices) > 0:
+ next_nonzero_values = next_entity_scores_dense[next_nonzero_indices]
+ current_entity_scores_sparse = torch.sparse_coo_tensor(
+ next_nonzero_indices.unsqueeze(0), next_nonzero_values,
+ (num_entities,), device=self.device
+ ).coalesce()
+ else:
+ break
+
+ # Convert back to numpy for final processing
+ entity_scores_final = entity_scores_dense.cpu().numpy()
+
+ # Map entity scores to graph node weights (only for non-zero scores)
+ nonzero_indices = np.where(entity_scores_final > 0)[0]
+ for entity_idx in nonzero_indices:
+ score = entity_scores_final[entity_idx]
+ entity_hash_id = self.entity_hash_ids[entity_idx]
+ entity_node_idx = self.node_name_to_vertex_idx[entity_hash_id]
+ entity_weights[entity_node_idx] = float(score)
+
+        return entity_weights, activated_entities
+
+    def calculate_passage_scores(self, question_embedding, activated_entities):
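+        """Score passages by blending dense retrieval with activated-entity evidence.
+
+        passage_score = passage_ratio * dpr_score
+                        + log(1 + sum(entity_score * log(1 + occurrences) / tier)),
+        then scaled by `passage_node_weight` before entering the PPR reset vector.
+        """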
+ passage_weights = np.zeros(len(self.graph.vs["name"]))
+ dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
+ dpr_passage_scores = min_max_normalize(dpr_passage_scores)
+ for i, dpr_passage_index in enumerate(dpr_passage_indices):
+ total_entity_bonus = 0
+ passage_hash_id = self.passage_embedding_store.hash_ids[dpr_passage_index]
+ dpr_passage_score = dpr_passage_scores[i]
+ passage_text_lower = self.passage_embedding_store.hash_id_to_text[passage_hash_id].lower()
+            for entity_hash_id, (entity_id, entity_score, tier) in activated_entities.items():
+ entity_lower = self.entity_embedding_store.hash_id_to_text[entity_hash_id].lower()
+ entity_occurrences = passage_text_lower.count(entity_lower)
+ if entity_occurrences > 0:
+ denom = tier if tier >= 1 else 1
+ entity_bonus = entity_score * math.log(1 + entity_occurrences) / denom
+ total_entity_bonus += entity_bonus
+ passage_score = self.config.passage_ratio * dpr_passage_score + math.log(1 + total_entity_bonus)
+ passage_node_idx = self.node_name_to_vertex_idx[passage_hash_id]
+ passage_weights[passage_node_idx] = passage_score * self.config.passage_node_weight
+ return passage_weights
+
+ def dense_passage_retrieval(self, question_embedding):
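+        """Rank all passages by inner-product similarity to the question embedding (cosine similarity for normalized embeddings)."""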
+ question_emb = question_embedding.reshape(1, -1)
+ question_passage_similarities = np.dot(self.passage_embeddings, question_emb.T).flatten()
+ sorted_passage_indices = np.argsort(question_passage_similarities)[::-1]
+ sorted_passage_scores = question_passage_similarities[sorted_passage_indices].tolist()
+ return sorted_passage_indices, sorted_passage_scores
+
+ def get_seed_entities(self, question):
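+        """Run spaCy NER on the question and link each mention to its most similar entity in the entity store."""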
+ question_entities = list(self.spacy_ner.question_ner(question))
+        if len(question_entities) == 0:
+            return [], [], [], []
+        question_entity_embeddings = self.config.embedding_model.encode(question_entities, normalize_embeddings=True, show_progress_bar=False, batch_size=self.config.batch_size)
+ similarities = np.dot(self.entity_embeddings, question_entity_embeddings.T)
+ seed_entity_indices = []
+ seed_entity_texts = []
+ seed_entity_hash_ids = []
+ seed_entity_scores = []
+ for query_entity_idx in range(len(question_entities)):
+ entity_scores = similarities[:, query_entity_idx]
+ best_entity_idx = np.argmax(entity_scores)
+ best_entity_score = entity_scores[best_entity_idx]
+ best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
+ best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
+ seed_entity_indices.append(best_entity_idx)
+ seed_entity_texts.append(best_entity_text)
+ seed_entity_hash_ids.append(best_entity_hash_id)
+ seed_entity_scores.append(best_entity_score)
+ return seed_entity_indices, seed_entity_texts, seed_entity_hash_ids, seed_entity_scores
+
+ def index(self, passages):
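+        """Build the retrieval graph from raw passages.
+
+        Runs (cached) NER, embeds passages, sentences, and entities, builds the
+        entity<->sentence mappings plus entity-passage and adjacent-passage edges,
+        and writes the result to LinearRAG.graphml.
+        """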
+ self.node_to_node_stats = defaultdict(dict)
+ self.entity_to_sentence_stats = defaultdict(dict)
+ self.passage_embedding_store.insert_text(passages)
+ hash_id_to_passage = self.passage_embedding_store.get_hash_id_to_text()
+        existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids = self.load_existing_data(hash_id_to_passage.keys())
+        if len(new_passage_hash_ids) > 0:
+            new_hash_id_to_passage = {k: hash_id_to_passage[k] for k in new_passage_hash_ids}
+            new_passage_hash_id_to_entities, new_sentence_to_entities = self.spacy_ner.batch_ner(new_hash_id_to_passage, self.config.max_workers)
+            self.merge_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities)
+            self.save_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+        entity_nodes, sentence_nodes, passage_hash_id_to_entities, self.entity_to_sentence, self.sentence_to_entity = self.extract_nodes_and_edges(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+ self.sentence_embedding_store.insert_text(list(sentence_nodes))
+ self.entity_embedding_store.insert_text(list(entity_nodes))
+ self.entity_hash_id_to_sentence_hash_ids = {}
+        for entity, sentences in self.entity_to_sentence.items():
+            entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+            self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] = [self.sentence_embedding_store.text_to_hash_id[s] for s in sentences]
+ self.sentence_hash_id_to_entity_hash_ids = {}
+ for sentence, entities in self.sentence_to_entity.items():
+ sentence_hash_id = self.sentence_embedding_store.text_to_hash_id[sentence]
+ self.sentence_hash_id_to_entity_hash_ids[sentence_hash_id] = [self.entity_embedding_store.text_to_hash_id[e] for e in entities]
+ self.add_entity_to_passage_edges(passage_hash_id_to_entities)
+ self.add_adjacent_passage_edges()
+ self.augment_graph()
+ output_graphml_path = os.path.join(self.config.working_dir,self.dataset_name, "LinearRAG.graphml")
+ os.makedirs(os.path.dirname(output_graphml_path), exist_ok=True)
+ self.graph.write_graphml(output_graphml_path)
+
+ def add_adjacent_passage_edges(self):
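+        """Link consecutive passages, assuming each passage text starts with an "N:" index prefix."""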
+ passage_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+ index_pattern = re.compile(r'^(\d+):')
+ indexed_items = [
+ (int(match.group(1)), node_key)
+ for node_key, text in passage_id_to_text.items()
+ if (match := index_pattern.match(text.strip()))
+ ]
+ indexed_items.sort(key=lambda x: x[0])
+ for i in range(len(indexed_items) - 1):
+ current_node = indexed_items[i][1]
+ next_node = indexed_items[i + 1][1]
+ self.node_to_node_stats[current_node][next_node] = 1.0
+
+ def augment_graph(self):
+ self.add_nodes()
+ self.add_edges()
+
+ def add_nodes(self):
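+        """Add a vertex per entity and passage, and record the passage vertex indices used for PPR scoring."""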
+ existing_nodes = {v["name"]: v for v in self.graph.vs if "name" in v.attributes()}
+ entity_hash_id_to_text = self.entity_embedding_store.get_hash_id_to_text()
+ passage_hash_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+ all_hash_id_to_text = {**entity_hash_id_to_text, **passage_hash_id_to_text}
+
+ passage_hash_ids = set(passage_hash_id_to_text.keys())
+
+ for hash_id, text in all_hash_id_to_text.items():
+ if hash_id not in existing_nodes:
+ self.graph.add_vertex(name=hash_id, content=text)
+
+ self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+ self.passage_node_indices = [
+ self.node_name_to_vertex_idx[passage_id]
+ for passage_id in passage_hash_ids
+ if passage_id in self.node_name_to_vertex_idx
+ ]
+
+ def add_edges(self):
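+        """Materialize accumulated node-to-node weights as edges, skipping self-loops and existing undirected edges."""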
+ existing_edges = set()
+ for edge in self.graph.es:
+ source_name = self.graph.vs[edge.source]["name"]
+ target_name = self.graph.vs[edge.target]["name"]
+ existing_edges.add(frozenset([source_name, target_name]))
+
+ new_edges = []
+ new_weights = []
+
+ for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
+ for neighbor_hash_id, weight in node_to_node_stats.items():
+ if node_hash_id == neighbor_hash_id:
+ continue
+ edge_key = frozenset([node_hash_id, neighbor_hash_id])
+ if edge_key not in existing_edges:
+ new_edges.append((node_hash_id, neighbor_hash_id))
+ new_weights.append(weight)
+ existing_edges.add(edge_key)
+
+ if new_edges:
+ self.graph.add_edges(new_edges)
+ start_idx = len(self.graph.es) - len(new_edges)
+ for i, weight in enumerate(new_weights):
+ self.graph.es[start_idx + i]['weight'] = weight
+
+ def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
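+        """Weight each passage->entity edge by the entity's share of all entity occurrences in that passage."""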
+        passage_to_entity_count = {}
+        passage_to_all_score = defaultdict(int)
+        for passage_hash_id, entities in passage_hash_id_to_entities.items():
+            passage = self.passage_embedding_store.hash_id_to_text[passage_hash_id]
+            for entity in entities:
+                entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+                count = passage.count(entity)
+                if count == 0:
+                    # Entity text may not appear verbatim (e.g. case differences); skip to avoid a zero denominator below
+                    continue
+                passage_to_entity_count[(passage_hash_id, entity_hash_id)] = count
+                passage_to_all_score[passage_hash_id] += count
+        for (passage_hash_id, entity_hash_id), count in passage_to_entity_count.items():
+            score = count / passage_to_all_score[passage_hash_id]
+            self.node_to_node_stats[passage_hash_id][entity_hash_id] = score
+
+ def extract_nodes_and_edges(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
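+        """Collect entity and sentence node sets plus passage->entities and entity<->sentence mappings from NER results."""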
+ entity_nodes = set()
+ sentence_nodes = set()
+ passage_hash_id_to_entities = defaultdict(set)
+        entity_to_sentence = defaultdict(set)
+ sentence_to_entity = defaultdict(set)
+ for passage_hash_id, entities in existing_passage_hash_id_to_entities.items():
+ for entity in entities:
+ entity_nodes.add(entity)
+ passage_hash_id_to_entities[passage_hash_id].add(entity)
+        for sentence, entities in existing_sentence_to_entities.items():
+ sentence_nodes.add(sentence)
+ for entity in entities:
+ entity_to_sentence[entity].add(sentence)
+ sentence_to_entity[sentence].add(entity)
+ return entity_nodes, sentence_nodes, passage_hash_id_to_entities, entity_to_sentence, sentence_to_entity
+
+ def merge_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities):
+ existing_passage_hash_id_to_entities.update(new_passage_hash_id_to_entities)
+ existing_sentence_to_entities.update(new_sentence_to_entities)
+ return existing_passage_hash_id_to_entities, existing_sentence_to_entities
+
+ def save_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
+ with open(self.ner_results_path, "w") as f:
+ json.dump({"passage_hash_id_to_entities": existing_passage_hash_id_to_entities, "sentence_to_entities": existing_sentence_to_entities}, f)
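+
+
+# Minimal usage sketch (illustrative only, not part of the original file). Call order
+# matters: index() must run before retrieve()/qa(), since retrieval depends on the graph
+# and the passage-node indices built during indexing. The config object and the "N: ..."
+# passage prefix consumed by add_adjacent_passage_edges are assumptions inferred from
+# how this class uses them above.
+#
+#   rag = LinearRAG(global_config)
+#   rag.index(["0: Rome is the capital of Italy.", "1: Italy is a country in Europe."])
+#   results = rag.qa([{"question": "What is the capital of Italy?", "answer": "Rome"}])
+#   print(results[0]["pred_answer"])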