From 97f593471662a194e57bd07520b898b965f6b013 Mon Sep 17 00:00:00 2001
From: LuyaoZhuang
Date: Sat, 3 Jan 2026 08:04:16 -0500
Subject: remove wrongly placed folder

---
 LinearRAG_upload/src/LinearRAG.py | 623 --------------------------------------
 1 file changed, 623 deletions(-)
 delete mode 100644 LinearRAG_upload/src/LinearRAG.py

(limited to 'LinearRAG_upload')

diff --git a/LinearRAG_upload/src/LinearRAG.py b/LinearRAG_upload/src/LinearRAG.py
deleted file mode 100644
index 76b65ad..0000000
--- a/LinearRAG_upload/src/LinearRAG.py
+++ /dev/null
@@ -1,623 +0,0 @@
-from src.embedding_store import EmbeddingStore
-from src.utils import min_max_normalize
-import os
-import json
-from collections import defaultdict
-import numpy as np
-import math
-from concurrent.futures import ThreadPoolExecutor
-from tqdm import tqdm
-from src.ner import SpacyNER
-import igraph as ig
-import re
-import logging
-import torch
-logger = logging.getLogger(__name__)
-
-
-class LinearRAG:
-    def __init__(self, global_config):
-        self.config = global_config
-        logger.info(f"Initializing LinearRAG with config: {self.config}")
-        retrieval_method = "Vectorized Matrix-based" if self.config.use_vectorized_retrieval else "BFS Iteration"
-        logger.info(f"Using retrieval method: {retrieval_method}")
-
-        # Setup device for GPU acceleration
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if self.config.use_vectorized_retrieval:
-            logger.info(f"Using device: {self.device} for vectorized retrieval")
-
-        self.dataset_name = global_config.dataset_name
-        self.load_embedding_store()
-        self.llm_model = self.config.llm_model
-        self.spacy_ner = SpacyNER(self.config.spacy_model)
-        self.graph = ig.Graph(directed=False)
-
-    def load_embedding_store(self):
-        self.passage_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "passage_embedding.parquet"), batch_size=self.config.batch_size, namespace="passage")
-        self.entity_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "entity_embedding.parquet"), batch_size=self.config.batch_size, namespace="entity")
-        self.sentence_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "sentence_embedding.parquet"), batch_size=self.config.batch_size, namespace="sentence")
-
-    def load_existing_data(self,passage_hash_ids):
-        self.ner_results_path = os.path.join(self.config.working_dir,self.dataset_name, "ner_results.json")
-        if os.path.exists(self.ner_results_path):
-            existing_ner_reuslts = json.load(open(self.ner_results_path))
-            existing_passage_hash_id_to_entities = existing_ner_reuslts["passage_hash_id_to_entities"]
-            existing_sentence_to_entities = existing_ner_reuslts["sentence_to_entities"]
-            existing_passage_hash_ids = set(existing_passage_hash_id_to_entities.keys())
-            new_passage_hash_ids = set(passage_hash_ids) - existing_passage_hash_ids
-            return existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids
-        else:
-            return {}, {}, passage_hash_ids
-
-    def qa(self, questions):
-        retrieval_results = self.retrieve(questions)
-        system_prompt = f"""As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. 
-Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations."""
-        all_messages = []
-        for retrieval_result in retrieval_results:
-            question = retrieval_result["question"]
-            sorted_passage = retrieval_result["sorted_passage"]
-            prompt_user = """"""
-            for passage in sorted_passage:
-                prompt_user += f"{passage}\n"
-            prompt_user += f"Question: {question}\n Thought: "
-            messages = [
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt_user}
-            ]
-            all_messages.append(messages)
-        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
-            all_qa_results = list(tqdm(
-                executor.map(self.llm_model.infer, all_messages),
-                total=len(all_messages),
-                desc="QA Reading (Parallel)"
-            ))
-
-        for qa_result,question_info in zip(all_qa_results,retrieval_results):
-            try:
-                pred_ans = qa_result.split('Answer:')[1].strip()
-            except:
-                pred_ans = qa_result
-            question_info["pred_answer"] = pred_ans
-        return retrieval_results
-
-    def retrieve(self, questions):
-        self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
-        self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
-        self.passage_hash_ids = list(self.passage_embedding_store.hash_id_to_text.keys())
-        self.passage_embeddings = np.array(self.passage_embedding_store.embeddings)
-        self.sentence_hash_ids = list(self.sentence_embedding_store.hash_id_to_text.keys())
-        self.sentence_embeddings = np.array(self.sentence_embedding_store.embeddings)
-        self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
-        self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
-
-        # Precompute sparse matrices for vectorized retrieval if needed
-        if self.config.use_vectorized_retrieval:
-            logger.info("Precomputing sparse adjacency matrices for vectorized retrieval...")
-            self._precompute_sparse_matrices()
-            e2s_shape = self.entity_to_sentence_sparse.shape
-            s2e_shape = self.sentence_to_entity_sparse.shape
-            e2s_nnz = self.entity_to_sentence_sparse._nnz()
-            s2e_nnz = self.sentence_to_entity_sparse._nnz()
-            logger.info(f"Matrices built: Entity-Sentence {e2s_shape}, Sentence-Entity {s2e_shape}")
-            logger.info(f"E2S Sparsity: {(1 - e2s_nnz / (e2s_shape[0] * e2s_shape[1])) * 100:.2f}% (nnz={e2s_nnz})")
-            logger.info(f"S2E Sparsity: {(1 - s2e_nnz / (s2e_shape[0] * s2e_shape[1])) * 100:.2f}% (nnz={s2e_nnz})")
-            logger.info(f"Device: {self.device}")
-
-        retrieval_results = []
-        for question_info in tqdm(questions, desc="Retrieving"):
-            question = question_info["question"]
-            question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
-            seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question)
-            if len(seed_entities) != 0:
-                sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
-                final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
-                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
-                final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
-            else:
-                sorted_passage_indices,sorted_passage_scores = self.dense_passage_retrieval(question_embedding)
-                final_passage_indices = sorted_passage_indices[:self.config.retrieval_top_k]
-                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
-                final_passages = [self.passage_embedding_store.texts[idx] for idx in final_passage_indices]
-            result = {
-                "question": question,
-                "sorted_passage": final_passages,
-                "sorted_passage_scores": final_passage_scores,
-                "gold_answer": question_info["answer"]
-            }
-            retrieval_results.append(result)
-        return retrieval_results
-
-    def _precompute_sparse_matrices(self):
-        """
-        Precompute and cache sparse adjacency matrices for efficient vectorized retrieval using PyTorch.
-        This is called once at the beginning of retrieve() to avoid rebuilding matrices per query.
-        """
-        num_entities = len(self.entity_hash_ids)
-        num_sentences = len(self.sentence_hash_ids)
-
-        # Build entity-to-sentence matrix (Mention matrix) using COO format
-        entity_to_sentence_indices = []
-        entity_to_sentence_values = []
-
-        for entity_hash_id, sentence_hash_ids in self.entity_hash_id_to_sentence_hash_ids.items():
-            entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
-            for sentence_hash_id in sentence_hash_ids:
-                sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
-                entity_to_sentence_indices.append([entity_idx, sentence_idx])
-                entity_to_sentence_values.append(1.0)
-
-        # Build sentence-to-entity matrix
-        sentence_to_entity_indices = []
-        sentence_to_entity_values = []
-
-        for sentence_hash_id, entity_hash_ids in self.sentence_hash_id_to_entity_hash_ids.items():
-            sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
-            for entity_hash_id in entity_hash_ids:
-                entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
-                sentence_to_entity_indices.append([sentence_idx, entity_idx])
-                sentence_to_entity_values.append(1.0)
-
-        # Convert to PyTorch sparse tensors (COO format, then convert to CSR for efficiency)
-        if len(entity_to_sentence_indices) > 0:
-            e2s_indices = torch.tensor(entity_to_sentence_indices, dtype=torch.long).t()
-            e2s_values = torch.tensor(entity_to_sentence_values, dtype=torch.float32)
-            self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
-                e2s_indices, e2s_values, (num_entities, num_sentences), device=self.device
-            ).coalesce()
-        else:
-            self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
-                torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
-                (num_entities, num_sentences), device=self.device
-            )
-
-        if len(sentence_to_entity_indices) > 0:
-            s2e_indices = torch.tensor(sentence_to_entity_indices, dtype=torch.long).t()
-            s2e_values = torch.tensor(sentence_to_entity_values, dtype=torch.float32)
-            self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
-                s2e_indices, s2e_values, (num_sentences, num_entities), device=self.device
-            ).coalesce()
-        else:
-            self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
-                torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
-                (num_sentences, num_entities), device=self.device
-            )
-
-    def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
-        if self.config.use_vectorized_retrieval:
-            entity_weights, actived_entities = self.calculate_entity_scores_vectorized(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
-        else:
-            entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
-        passage_weights = self.calculate_passage_scores(question_embedding,actived_entities)
-        node_weights = entity_weights + passage_weights
-        ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights)
-        return ppr_sorted_passage_indices,ppr_sorted_passage_scores
-
-    def run_ppr(self, node_weights):
-        reset_prob = np.where(np.isnan(node_weights) | (node_weights < 0), 0, node_weights)
-        pagerank_scores = self.graph.personalized_pagerank(
-            vertices=range(len(self.node_name_to_vertex_idx)),
-            damping=self.config.damping,
-            directed=False,
-            weights='weight',
-            reset=reset_prob,
-            implementation='prpack'
-        )
-
-        doc_scores = np.array([pagerank_scores[idx] for idx in self.passage_node_indices])
-        sorted_indices_in_doc_scores = np.argsort(doc_scores)[::-1]
-        sorted_passage_scores = doc_scores[sorted_indices_in_doc_scores]
-
-        sorted_passage_hash_ids = [
-            self.vertex_idx_to_node_name[self.passage_node_indices[i]]
-            for i in sorted_indices_in_doc_scores
-        ]
-
-        return sorted_passage_hash_ids, sorted_passage_scores.tolist()
-
-    def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
-        actived_entities = {}
-        entity_weights = np.zeros(len(self.graph.vs["name"]))
-        for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
-            actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
-            seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
-            entity_weights[seed_entity_node_idx] = seed_entity_score
-        used_sentence_hash_ids = set()
-        current_entities = actived_entities.copy()
-        iteration = 1
-        while len(current_entities) > 0 and iteration < self.config.max_iterations:
-            new_entities = {}
-            for entity_hash_id, (entity_id, entity_score, tier) in current_entities.items():
-                if entity_score < self.config.iteration_threshold:
-                    continue
-                sentence_hash_ids = [sid for sid in list(self.entity_hash_id_to_sentence_hash_ids[entity_hash_id]) if sid not in used_sentence_hash_ids]
-                if not sentence_hash_ids:
-                    continue
-                sentence_indices = [self.sentence_embedding_store.hash_id_to_idx[sid] for sid in sentence_hash_ids]
-                sentence_embeddings = self.sentence_embeddings[sentence_indices]
-                question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
-                sentence_similarities = np.dot(sentence_embeddings, question_emb).flatten()
-                top_sentence_indices = np.argsort(sentence_similarities)[::-1][:self.config.top_k_sentence]
-                for top_sentence_index in top_sentence_indices:
-                    top_sentence_hash_id = sentence_hash_ids[top_sentence_index]
-                    top_sentence_score = sentence_similarities[top_sentence_index]
-                    used_sentence_hash_ids.add(top_sentence_hash_id)
-                    entity_hash_ids_in_sentence = self.sentence_hash_id_to_entity_hash_ids[top_sentence_hash_id]
-                    for next_entity_hash_id in entity_hash_ids_in_sentence:
-                        next_entity_score = entity_score * top_sentence_score
-                        if next_entity_score < self.config.iteration_threshold:
-                            continue
-                        next_enitity_node_idx = self.node_name_to_vertex_idx[next_entity_hash_id]
-                        entity_weights[next_enitity_node_idx] += next_entity_score
-                        new_entities[next_entity_hash_id] = (next_enitity_node_idx, next_entity_score, iteration+1)
-            actived_entities.update(new_entities)
-            current_entities = new_entities.copy()
-            iteration += 1
-        return entity_weights, actived_entities
-
-    def calculate_entity_scores_vectorized(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
-        """
-        GPU-accelerated vectorized version using PyTorch sparse tensors.
-        Uses sparse representation for both matrices and entity score vectors for maximum efficiency.
-        Now includes proper dynamic pruning to match BFS behavior:
-        - Sentence deduplication (tracks used sentences)
-        - Per-entity top-k sentence selection
-        - Proper threshold-based pruning
-        """
-        # Initialize entity weights
-        entity_weights = np.zeros(len(self.graph.vs["name"]))
-        num_entities = len(self.entity_hash_ids)
-        num_sentences = len(self.sentence_hash_ids)
-
-        # Compute all sentence similarities with the question at once
-        question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
-        sentence_similarities_np = np.dot(self.sentence_embeddings, question_emb).flatten()
-
-        # Convert to torch tensors and move to device
-        sentence_similarities = torch.from_numpy(sentence_similarities_np).float().to(self.device)
-
-        # Track used sentences for deduplication (like BFS version)
-        used_sentence_mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
-
-        # Initialize seed entity scores as sparse tensor
-        seed_indices = torch.tensor([[idx] for idx in seed_entity_indices], dtype=torch.long).t()
-        seed_values = torch.tensor(seed_entity_scores, dtype=torch.float32)
-        entity_scores_sparse = torch.sparse_coo_tensor(
-            seed_indices, seed_values, (num_entities,), device=self.device
-        ).coalesce()
-
-        # Also maintain a dense accumulator for total scores
-        entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
-        entity_scores_dense.scatter_(0, torch.tensor(seed_entity_indices, device=self.device),
-                                     torch.tensor(seed_entity_scores, dtype=torch.float32, device=self.device))
-
-        # Initialize actived_entities
-        actived_entities = {}
-        for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(
-            seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores
-        ):
-            actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 0)
-            seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
-            entity_weights[seed_entity_node_idx] = seed_entity_score
-
-        current_entity_scores_sparse = entity_scores_sparse
-
-        # Iterative matrix-based propagation using sparse matrices on GPU
-        for iteration in range(1, self.config.max_iterations):
-            # Convert sparse tensor to dense for threshold operation
-            current_entity_scores_dense = current_entity_scores_sparse.to_dense()
-
-            # Apply threshold to current scores
-            current_entity_scores_dense = torch.where(
-                current_entity_scores_dense >= self.config.iteration_threshold,
-                current_entity_scores_dense,
-                torch.zeros_like(current_entity_scores_dense)
-            )
-
-            # Get non-zero indices for sparse representation
-            nonzero_mask = current_entity_scores_dense > 0
-            nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).squeeze(-1)
-
-            if len(nonzero_indices) == 0:
-                break
-
-            # Extract non-zero values and create sparse tensor
-            nonzero_values = current_entity_scores_dense[nonzero_indices]
-            current_entity_scores_sparse = torch.sparse_coo_tensor(
-                nonzero_indices.unsqueeze(0), nonzero_values, (num_entities,), device=self.device
-            ).coalesce()
-
-            # Step 1: Sparse entity scores @ Sparse E2S matrix
-            # Convert sparse vector to 2D for matrix multiplication
-            current_scores_2d = torch.sparse_coo_tensor(
-                torch.stack([nonzero_indices, torch.zeros_like(nonzero_indices)]),
-                nonzero_values,
-                (num_entities, 1),
-                device=self.device
-            ).coalesce()
-
-            # E @ E2S -> sentence activation scores (sparse @ sparse = dense)
-            sentence_activation = torch.sparse.mm(
-                self.entity_to_sentence_sparse.t(),
-                current_scores_2d
-            )
-            # Convert to dense before squeeze to avoid CUDA sparse tensor issues
-            if sentence_activation.is_sparse:
-                sentence_activation = sentence_activation.to_dense()
-            sentence_activation = sentence_activation.squeeze()
-
-            # Apply sentence deduplication: mask out used sentences
-            sentence_activation = torch.where(
-                used_sentence_mask,
-                torch.zeros_like(sentence_activation),
-                sentence_activation
-            )
-
-            # Step 2: Apply sentence similarities (element-wise on dense vector)
-            weighted_sentence_scores = sentence_activation * sentence_similarities
-
-            # Implement per-entity top-k sentence selection (more aggressive pruning)
-            # For vectorized efficiency, we use a tighter global approximation
-            num_active = len(nonzero_indices)
-            if num_active > 0 and self.config.top_k_sentence > 0:
-                # Calculate adaptive k based on number of active entities
-                # Use per-entity top-k approximation: num_active * top_k_sentence
-                k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
-                if k > 0:
-                    # Get top-k sentences
-                    top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
-                    # Zero out all non-top-k sentences
-                    mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
-                    mask[top_k_indices] = True
-                    weighted_sentence_scores = torch.where(
-                        mask,
-                        weighted_sentence_scores,
-                        torch.zeros_like(weighted_sentence_scores)
-                    )
-
-                    # Mark these sentences as used for deduplication
-                    used_sentence_mask[top_k_indices] = True
-
-            # Step 3: Weighted sentences @ S2E -> propagate to next entities
-            # Convert to sparse for more efficient computation
-            weighted_nonzero_mask = weighted_sentence_scores > 0
-            weighted_nonzero_indices = torch.nonzero(weighted_nonzero_mask, as_tuple=False).squeeze(-1)
-
-            if len(weighted_nonzero_indices) > 0:
-                weighted_nonzero_values = weighted_sentence_scores[weighted_nonzero_indices]
-                weighted_scores_2d = torch.sparse_coo_tensor(
-                    torch.stack([weighted_nonzero_indices, torch.zeros_like(weighted_nonzero_indices)]),
-                    weighted_nonzero_values,
-                    (num_sentences, 1),
-                    device=self.device
-                ).coalesce()
-
-                next_entity_scores_result = torch.sparse.mm(
-                    self.sentence_to_entity_sparse.t(),
-                    weighted_scores_2d
-                )
-                # Convert to dense before squeeze to avoid CUDA sparse tensor issues
-                if next_entity_scores_result.is_sparse:
-                    next_entity_scores_result = next_entity_scores_result.to_dense()
-                next_entity_scores_dense = next_entity_scores_result.squeeze()
-            else:
-                next_entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
-
-            # Update entity scores (accumulate in dense format)
-            entity_scores_dense += next_entity_scores_dense
-
-            # Update actived_entities dictionary (only for entities above threshold)
-            next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
-            active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
-            for entity_idx in active_indices:
-                score = next_entity_scores_np[entity_idx]
-                entity_hash_id = self.entity_hash_ids[entity_idx]
-                if entity_hash_id not in actived_entities or actived_entities[entity_hash_id][1] < score:
-                    actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
-
-            # Prepare sparse tensor for next iteration
-            next_nonzero_mask = next_entity_scores_dense > 0
-            next_nonzero_indices = torch.nonzero(next_nonzero_mask, as_tuple=False).squeeze(-1)
-            if len(next_nonzero_indices) > 0:
-                next_nonzero_values = next_entity_scores_dense[next_nonzero_indices]
-                current_entity_scores_sparse = torch.sparse_coo_tensor(
-                    next_nonzero_indices.unsqueeze(0), next_nonzero_values,
-                    (num_entities,), device=self.device
-                ).coalesce()
-            else:
-                break
-
-        # Convert back to numpy for final processing
-        entity_scores_final = entity_scores_dense.cpu().numpy()
-
-        # Map entity scores to graph node weights (only for non-zero scores)
-        nonzero_indices = np.where(entity_scores_final > 0)[0]
-        for entity_idx in nonzero_indices:
-            score = entity_scores_final[entity_idx]
-            entity_hash_id = self.entity_hash_ids[entity_idx]
-            entity_node_idx = self.node_name_to_vertex_idx[entity_hash_id]
-            entity_weights[entity_node_idx] = float(score)
-
-        return entity_weights, actived_entities
-
-    def calculate_passage_scores(self, question_embedding, actived_entities):
-        passage_weights = np.zeros(len(self.graph.vs["name"]))
-        dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
-        dpr_passage_scores = min_max_normalize(dpr_passage_scores)
-        for i, dpr_passage_index in enumerate(dpr_passage_indices):
-            total_entity_bonus = 0
-            passage_hash_id = self.passage_embedding_store.hash_ids[dpr_passage_index]
-            dpr_passage_score = dpr_passage_scores[i]
-            passage_text_lower = self.passage_embedding_store.hash_id_to_text[passage_hash_id].lower()
-            for entity_hash_id, (entity_id, entity_score, tier) in actived_entities.items():
-                entity_lower = self.entity_embedding_store.hash_id_to_text[entity_hash_id].lower()
-                entity_occurrences = passage_text_lower.count(entity_lower)
-                if entity_occurrences > 0:
-                    denom = tier if tier >= 1 else 1
-                    entity_bonus = entity_score * math.log(1 + entity_occurrences) / denom
-                    total_entity_bonus += entity_bonus
-            passage_score = self.config.passage_ratio * dpr_passage_score + math.log(1 + total_entity_bonus)
-            passage_node_idx = self.node_name_to_vertex_idx[passage_hash_id]
-            passage_weights[passage_node_idx] = passage_score * self.config.passage_node_weight
-        return passage_weights
-
-    def dense_passage_retrieval(self, question_embedding):
-        question_emb = question_embedding.reshape(1, -1)
-        question_passage_similarities = np.dot(self.passage_embeddings, question_emb.T).flatten()
-        sorted_passage_indices = np.argsort(question_passage_similarities)[::-1]
-        sorted_passage_scores = question_passage_similarities[sorted_passage_indices].tolist()
-        return sorted_passage_indices, sorted_passage_scores
-
-    def get_seed_entities(self, question):
-        question_entities = list(self.spacy_ner.question_ner(question))
-        if len(question_entities) == 0:
-            return [],[],[],[]
-        question_entity_embeddings = self.config.embedding_model.encode(question_entities,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
-        similarities = np.dot(self.entity_embeddings, question_entity_embeddings.T)
-        seed_entity_indices = []
-        seed_entity_texts = []
-        seed_entity_hash_ids = []
-        seed_entity_scores = []
-        for query_entity_idx in range(len(question_entities)):
-            entity_scores = similarities[:, query_entity_idx]
-            best_entity_idx = np.argmax(entity_scores)
-            best_entity_score = entity_scores[best_entity_idx]
-            best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
-            best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
-            seed_entity_indices.append(best_entity_idx)
-            seed_entity_texts.append(best_entity_text)
-            seed_entity_hash_ids.append(best_entity_hash_id)
-            seed_entity_scores.append(best_entity_score)
-        return seed_entity_indices, seed_entity_texts, seed_entity_hash_ids, seed_entity_scores
-
-    def index(self, passages):
-        self.node_to_node_stats = defaultdict(dict)
-        self.entity_to_sentence_stats = defaultdict(dict)
-        self.passage_embedding_store.insert_text(passages)
-        hash_id_to_passage = self.passage_embedding_store.get_hash_id_to_text()
-        existing_passage_hash_id_to_entities,existing_sentence_to_entities, new_passage_hash_ids = self.load_existing_data(hash_id_to_passage.keys())
-        if len(new_passage_hash_ids) > 0:
-            new_hash_id_to_passage = {k : hash_id_to_passage[k] for k in new_passage_hash_ids}
-            new_passage_hash_id_to_entities,new_sentence_to_entities = self.spacy_ner.batch_ner(new_hash_id_to_passage, self.config.max_workers)
-            self.merge_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities)
-            self.save_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
-        entity_nodes, sentence_nodes,passage_hash_id_to_entities,self.entity_to_sentence,self.sentence_to_entity = self.extract_nodes_and_edges(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
-        self.sentence_embedding_store.insert_text(list(sentence_nodes))
-        self.entity_embedding_store.insert_text(list(entity_nodes))
-        self.entity_hash_id_to_sentence_hash_ids = {}
-        for entity, sentence in self.entity_to_sentence.items():
-            entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
-            self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] = [self.sentence_embedding_store.text_to_hash_id[s] for s in sentence]
-        self.sentence_hash_id_to_entity_hash_ids = {}
-        for sentence, entities in self.sentence_to_entity.items():
-            sentence_hash_id = self.sentence_embedding_store.text_to_hash_id[sentence]
-            self.sentence_hash_id_to_entity_hash_ids[sentence_hash_id] = [self.entity_embedding_store.text_to_hash_id[e] for e in entities]
-        self.add_entity_to_passage_edges(passage_hash_id_to_entities)
-        self.add_adjacent_passage_edges()
-        self.augment_graph()
-        output_graphml_path = os.path.join(self.config.working_dir,self.dataset_name, "LinearRAG.graphml")
-        os.makedirs(os.path.dirname(output_graphml_path), exist_ok=True)
-        self.graph.write_graphml(output_graphml_path)
-
-    def add_adjacent_passage_edges(self):
-        passage_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
-        index_pattern = re.compile(r'^(\d+):')
-        indexed_items = [
-            (int(match.group(1)), node_key)
-            for node_key, text in passage_id_to_text.items()
-            if (match := index_pattern.match(text.strip()))
-        ]
-        indexed_items.sort(key=lambda x: x[0])
-        for i in range(len(indexed_items) - 1):
-            current_node = indexed_items[i][1]
-            next_node = indexed_items[i + 1][1]
-            self.node_to_node_stats[current_node][next_node] = 1.0
-
-    def augment_graph(self):
-        self.add_nodes()
-        self.add_edges()
-
-    def add_nodes(self):
-        existing_nodes = {v["name"]: v for v in self.graph.vs if "name" in v.attributes()}
-        entity_hash_id_to_text = self.entity_embedding_store.get_hash_id_to_text()
-        passage_hash_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
-        all_hash_id_to_text = {**entity_hash_id_to_text, **passage_hash_id_to_text}
-
-        passage_hash_ids = set(passage_hash_id_to_text.keys())
-
-        for hash_id, text in all_hash_id_to_text.items():
-            if hash_id not in existing_nodes:
-                self.graph.add_vertex(name=hash_id, content=text)
-
-        self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
-        self.passage_node_indices = [
-            self.node_name_to_vertex_idx[passage_id]
-            for passage_id in passage_hash_ids
-            if passage_id in self.node_name_to_vertex_idx
-        ]
-
-    def add_edges(self):
-        existing_edges = set()
-        for edge in self.graph.es:
-            source_name = self.graph.vs[edge.source]["name"]
-            target_name = self.graph.vs[edge.target]["name"]
-            existing_edges.add(frozenset([source_name, target_name]))
-
-        new_edges = []
-        new_weights = []
-
-        for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
-            for neighbor_hash_id, weight in node_to_node_stats.items():
-                if node_hash_id == neighbor_hash_id:
-                    continue
-                edge_key = frozenset([node_hash_id, neighbor_hash_id])
-                if edge_key not in existing_edges:
-                    new_edges.append((node_hash_id, neighbor_hash_id))
-                    new_weights.append(weight)
-                    existing_edges.add(edge_key)
-
-        if new_edges:
-            self.graph.add_edges(new_edges)
-            start_idx = len(self.graph.es) - len(new_edges)
-            for i, weight in enumerate(new_weights):
-                self.graph.es[start_idx + i]['weight'] = weight
-
-    def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
-        passage_to_entity_count ={}
-        passage_to_all_score = defaultdict(int)
-        for passage_hash_id, entities in passage_hash_id_to_entities.items():
-            passage = self.passage_embedding_store.hash_id_to_text[passage_hash_id]
-            for entity in entities:
-                entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
-                count = passage.count(entity)
-                passage_to_entity_count[(passage_hash_id, entity_hash_id)] = count
-                passage_to_all_score[passage_hash_id] += count
-        for (passage_hash_id, entity_hash_id), count in passage_to_entity_count.items():
-            score = count / passage_to_all_score[passage_hash_id]
-            self.node_to_node_stats[passage_hash_id][entity_hash_id] = score
-
-    def extract_nodes_and_edges(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
-        entity_nodes = set()
-        sentence_nodes = set()
-        passage_hash_id_to_entities = defaultdict(set)
-        entity_to_sentence= defaultdict(set)
-        sentence_to_entity = defaultdict(set)
-        for passage_hash_id, entities in existing_passage_hash_id_to_entities.items():
-            for entity in entities:
-                entity_nodes.add(entity)
-                passage_hash_id_to_entities[passage_hash_id].add(entity)
-        for sentence,entities in existing_sentence_to_entities.items():
-            sentence_nodes.add(sentence)
-            for entity in entities:
-                entity_to_sentence[entity].add(sentence)
-                sentence_to_entity[sentence].add(entity)
-        return entity_nodes, sentence_nodes, passage_hash_id_to_entities, entity_to_sentence, sentence_to_entity
-
-    def merge_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities):
-        existing_passage_hash_id_to_entities.update(new_passage_hash_id_to_entities)
-        existing_sentence_to_entities.update(new_sentence_to_entities)
-        return existing_passage_hash_id_to_entities, existing_sentence_to_entities
-
-    def save_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
-        with open(self.ner_results_path, "w") as f:
-            json.dump({"passage_hash_id_to_entities": existing_passage_hash_id_to_entities, "sentence_to_entities": existing_sentence_to_entities}, f)
--
cgit v1.2.3