path: root/LinearRAG_upload/src
author     LuyaoZhuang <zhuangluyao523@gmail.com>  2026-01-03 08:04:16 -0500
committer  LuyaoZhuang <zhuangluyao523@gmail.com>  2026-01-03 08:04:16 -0500
commit     97f593471662a194e57bd07520b898b965f6b013 (patch)
tree       38ce6779abc74c6de9157569ce690b4e2a4d9ae7 /LinearRAG_upload/src
parent     ef9c7f907b67428806f76cfe33764d2960259447 (diff)
remove wrongly placed folder
Diffstat (limited to 'LinearRAG_upload/src')
-rw-r--r--  LinearRAG_upload/src/LinearRAG.py  623
1 file changed, 0 insertions, 623 deletions
diff --git a/LinearRAG_upload/src/LinearRAG.py b/LinearRAG_upload/src/LinearRAG.py
deleted file mode 100644
index 76b65ad..0000000
--- a/LinearRAG_upload/src/LinearRAG.py
+++ /dev/null
@@ -1,623 +0,0 @@
-from src.embedding_store import EmbeddingStore
-from src.utils import min_max_normalize
-import os
-import json
-from collections import defaultdict
-import numpy as np
-import math
-from concurrent.futures import ThreadPoolExecutor
-from tqdm import tqdm
-from src.ner import SpacyNER
-import igraph as ig
-import re
-import logging
-import torch
-logger = logging.getLogger(__name__)
-
-
-class LinearRAG:
- def __init__(self, global_config):
- self.config = global_config
- logger.info(f"Initializing LinearRAG with config: {self.config}")
- retrieval_method = "Vectorized Matrix-based" if self.config.use_vectorized_retrieval else "BFS Iteration"
- logger.info(f"Using retrieval method: {retrieval_method}")
-
- # Setup device for GPU acceleration
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- if self.config.use_vectorized_retrieval:
- logger.info(f"Using device: {self.device} for vectorized retrieval")
-
- self.dataset_name = global_config.dataset_name
- self.load_embedding_store()
- self.llm_model = self.config.llm_model
- self.spacy_ner = SpacyNER(self.config.spacy_model)
- self.graph = ig.Graph(directed=False)
-
- def load_embedding_store(self):
- self.passage_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "passage_embedding.parquet"), batch_size=self.config.batch_size, namespace="passage")
- self.entity_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "entity_embedding.parquet"), batch_size=self.config.batch_size, namespace="entity")
- self.sentence_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "sentence_embedding.parquet"), batch_size=self.config.batch_size, namespace="sentence")
-
- def load_existing_data(self,passage_hash_ids):
- self.ner_results_path = os.path.join(self.config.working_dir,self.dataset_name, "ner_results.json")
- if os.path.exists(self.ner_results_path):
- with open(self.ner_results_path) as f:
- existing_ner_results = json.load(f)
- existing_passage_hash_id_to_entities = existing_ner_results["passage_hash_id_to_entities"]
- existing_sentence_to_entities = existing_ner_results["sentence_to_entities"]
- existing_passage_hash_ids = set(existing_passage_hash_id_to_entities.keys())
- new_passage_hash_ids = set(passage_hash_ids) - existing_passage_hash_ids
- return existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids
- else:
- return {}, {}, set(passage_hash_ids)
-
- def qa(self, questions):
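- """
- Answer each question end-to-end: retrieve supporting passages, build a
- "Thought: ... Answer: ..." prompt per question, run LLM inference in parallel,
- and take the text after "Answer:" as the predicted answer.
- """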
- retrieval_results = self.retrieve(questions)
- system_prompt = """As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaboration."""
- all_messages = []
- for retrieval_result in retrieval_results:
- question = retrieval_result["question"]
- sorted_passage = retrieval_result["sorted_passage"]
- prompt_user = ""
- for passage in sorted_passage:
- prompt_user += f"{passage}\n"
- prompt_user += f"Question: {question}\n Thought: "
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": prompt_user}
- ]
- all_messages.append(messages)
- with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
- all_qa_results = list(tqdm(
- executor.map(self.llm_model.infer, all_messages),
- total=len(all_messages),
- desc="QA Reading (Parallel)"
- ))
-
- for qa_result,question_info in zip(all_qa_results,retrieval_results):
- try:
- pred_ans = qa_result.split('Answer:')[1].strip()
- except (IndexError, AttributeError):
- pred_ans = qa_result
- question_info["pred_answer"] = pred_ans
- return retrieval_results
-
- def retrieve(self, questions):
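- """
- Retrieve top-k passages per question. When seed entities are found in the
- question, rank passages via graph search (entity propagation + personalized
- PageRank); otherwise fall back to dense passage retrieval.
- """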
- self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
- self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
- self.passage_hash_ids = list(self.passage_embedding_store.hash_id_to_text.keys())
- self.passage_embeddings = np.array(self.passage_embedding_store.embeddings)
- self.sentence_hash_ids = list(self.sentence_embedding_store.hash_id_to_text.keys())
- self.sentence_embeddings = np.array(self.sentence_embedding_store.embeddings)
- self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
- self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
-
- # Precompute sparse matrices for vectorized retrieval if needed
- if self.config.use_vectorized_retrieval:
- logger.info("Precomputing sparse adjacency matrices for vectorized retrieval...")
- self._precompute_sparse_matrices()
- e2s_shape = self.entity_to_sentence_sparse.shape
- s2e_shape = self.sentence_to_entity_sparse.shape
- e2s_nnz = self.entity_to_sentence_sparse._nnz()
- s2e_nnz = self.sentence_to_entity_sparse._nnz()
- logger.info(f"Matrices built: Entity-Sentence {e2s_shape}, Sentence-Entity {s2e_shape}")
- logger.info(f"E2S Sparsity: {(1 - e2s_nnz / (e2s_shape[0] * e2s_shape[1])) * 100:.2f}% (nnz={e2s_nnz})")
- logger.info(f"S2E Sparsity: {(1 - s2e_nnz / (s2e_shape[0] * s2e_shape[1])) * 100:.2f}% (nnz={s2e_nnz})")
- logger.info(f"Device: {self.device}")
-
- retrieval_results = []
- for question_info in tqdm(questions, desc="Retrieving"):
- question = question_info["question"]
- question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
- seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question)
- if len(seed_entities) != 0:
- sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
- final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
- final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
- final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
- else:
- sorted_passage_indices,sorted_passage_scores = self.dense_passage_retrieval(question_embedding)
- final_passage_indices = sorted_passage_indices[:self.config.retrieval_top_k]
- final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
- final_passages = [self.passage_embedding_store.texts[idx] for idx in final_passage_indices]
- result = {
- "question": question,
- "sorted_passage": final_passages,
- "sorted_passage_scores": final_passage_scores,
- "gold_answer": question_info["answer"]
- }
- retrieval_results.append(result)
- return retrieval_results
-
- def _precompute_sparse_matrices(self):
- """
- Precompute and cache sparse adjacency matrices for efficient vectorized retrieval using PyTorch.
- This is called once at the beginning of retrieve() to avoid rebuilding matrices per query.
- """
- num_entities = len(self.entity_hash_ids)
- num_sentences = len(self.sentence_hash_ids)
-
- # Build entity-to-sentence matrix (Mention matrix) using COO format
- entity_to_sentence_indices = []
- entity_to_sentence_values = []
-
- for entity_hash_id, sentence_hash_ids in self.entity_hash_id_to_sentence_hash_ids.items():
- entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
- for sentence_hash_id in sentence_hash_ids:
- sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
- entity_to_sentence_indices.append([entity_idx, sentence_idx])
- entity_to_sentence_values.append(1.0)
-
- # Build sentence-to-entity matrix
- sentence_to_entity_indices = []
- sentence_to_entity_values = []
-
- for sentence_hash_id, entity_hash_ids in self.sentence_hash_id_to_entity_hash_ids.items():
- sentence_idx = self.sentence_embedding_store.hash_id_to_idx[sentence_hash_id]
- for entity_hash_id in entity_hash_ids:
- entity_idx = self.entity_embedding_store.hash_id_to_idx[entity_hash_id]
- sentence_to_entity_indices.append([sentence_idx, entity_idx])
- sentence_to_entity_values.append(1.0)
-
- # Convert to PyTorch sparse tensors (COO format, then convert to CSR for efficiency)
- if len(entity_to_sentence_indices) > 0:
- e2s_indices = torch.tensor(entity_to_sentence_indices, dtype=torch.long).t()
- e2s_values = torch.tensor(entity_to_sentence_values, dtype=torch.float32)
- self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
- e2s_indices, e2s_values, (num_entities, num_sentences), device=self.device
- ).coalesce()
- else:
- self.entity_to_sentence_sparse = torch.sparse_coo_tensor(
- torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
- (num_entities, num_sentences), device=self.device
- )
-
- if len(sentence_to_entity_indices) > 0:
- s2e_indices = torch.tensor(sentence_to_entity_indices, dtype=torch.long).t()
- s2e_values = torch.tensor(sentence_to_entity_values, dtype=torch.float32)
- self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
- s2e_indices, s2e_values, (num_sentences, num_entities), device=self.device
- ).coalesce()
- else:
- self.sentence_to_entity_sparse = torch.sparse_coo_tensor(
- torch.zeros((2, 0), dtype=torch.long), torch.zeros(0, dtype=torch.float32),
- (num_sentences, num_entities), device=self.device
- )
-
- def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
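- """
- Combine entity weights (BFS or vectorized propagation) with passage weights
- (dense retrieval score plus entity-occurrence bonus) and rank passages with
- personalized PageRank over the graph.
- """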
- if self.config.use_vectorized_retrieval:
- entity_weights, actived_entities = self.calculate_entity_scores_vectorized(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
- else:
- entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
- passage_weights = self.calculate_passage_scores(question_embedding,actived_entities)
- node_weights = entity_weights + passage_weights
- ppr_sorted_passage_hash_ids,ppr_sorted_passage_scores = self.run_ppr(node_weights)
- return ppr_sorted_passage_hash_ids,ppr_sorted_passage_scores
-
- def run_ppr(self, node_weights):
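- """
- Run personalized PageRank using the accumulated node weights (NaNs and
- negative values zeroed) as the reset vector, then rank passage nodes by score.
- """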
- reset_prob = np.where(np.isnan(node_weights) | (node_weights < 0), 0, node_weights)
- pagerank_scores = self.graph.personalized_pagerank(
- vertices=range(len(self.node_name_to_vertex_idx)),
- damping=self.config.damping,
- directed=False,
- weights='weight',
- reset=reset_prob,
- implementation='prpack'
- )
-
- doc_scores = np.array([pagerank_scores[idx] for idx in self.passage_node_indices])
- sorted_indices_in_doc_scores = np.argsort(doc_scores)[::-1]
- sorted_passage_scores = doc_scores[sorted_indices_in_doc_scores]
-
- sorted_passage_hash_ids = [
- self.vertex_idx_to_node_name[self.passage_node_indices[i]]
- for i in sorted_indices_in_doc_scores
- ]
-
- return sorted_passage_hash_ids, sorted_passage_scores.tolist()
-
- def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
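- """
- BFS-style propagation: starting from the seed entities, pick the top-k
- question-similar sentences mentioning each active entity, propagate
- score = entity_score * sentence_similarity to co-occurring entities, and stop
- once scores drop below iteration_threshold or max_iterations is reached.
- """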
- actived_entities = {}
- entity_weights = np.zeros(len(self.graph.vs["name"]))
- for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
- actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
- seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
- entity_weights[seed_entity_node_idx] = seed_entity_score
- used_sentence_hash_ids = set()
- current_entities = actived_entities.copy()
- iteration = 1
- while len(current_entities) > 0 and iteration < self.config.max_iterations:
- new_entities = {}
- for entity_hash_id, (entity_id, entity_score, tier) in current_entities.items():
- if entity_score < self.config.iteration_threshold:
- continue
- sentence_hash_ids = [sid for sid in list(self.entity_hash_id_to_sentence_hash_ids[entity_hash_id]) if sid not in used_sentence_hash_ids]
- if not sentence_hash_ids:
- continue
- sentence_indices = [self.sentence_embedding_store.hash_id_to_idx[sid] for sid in sentence_hash_ids]
- sentence_embeddings = self.sentence_embeddings[sentence_indices]
- question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
- sentence_similarities = np.dot(sentence_embeddings, question_emb).flatten()
- top_sentence_indices = np.argsort(sentence_similarities)[::-1][:self.config.top_k_sentence]
- for top_sentence_index in top_sentence_indices:
- top_sentence_hash_id = sentence_hash_ids[top_sentence_index]
- top_sentence_score = sentence_similarities[top_sentence_index]
- used_sentence_hash_ids.add(top_sentence_hash_id)
- entity_hash_ids_in_sentence = self.sentence_hash_id_to_entity_hash_ids[top_sentence_hash_id]
- for next_entity_hash_id in entity_hash_ids_in_sentence:
- next_entity_score = entity_score * top_sentence_score
- if next_entity_score < self.config.iteration_threshold:
- continue
- next_entity_node_idx = self.node_name_to_vertex_idx[next_entity_hash_id]
- entity_weights[next_entity_node_idx] += next_entity_score
- new_entities[next_entity_hash_id] = (next_entity_node_idx, next_entity_score, iteration+1)
- actived_entities.update(new_entities)
- current_entities = new_entities.copy()
- iteration += 1
- return entity_weights, actived_entities
-
- def calculate_entity_scores_vectorized(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
- """
- GPU-accelerated vectorized version using PyTorch sparse tensors.
- Uses sparse representation for both matrices and entity score vectors for maximum efficiency.
- Now includes proper dynamic pruning to match BFS behavior:
- - Sentence deduplication (tracks used sentences)
- - Per-entity top-k sentence selection
- - Proper threshold-based pruning
- """
- # Initialize entity weights
- entity_weights = np.zeros(len(self.graph.vs["name"]))
- num_entities = len(self.entity_hash_ids)
- num_sentences = len(self.sentence_hash_ids)
-
- # Compute all sentence similarities with the question at once
- question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
- sentence_similarities_np = np.dot(self.sentence_embeddings, question_emb).flatten()
-
- # Convert to torch tensors and move to device
- sentence_similarities = torch.from_numpy(sentence_similarities_np).float().to(self.device)
-
- # Track used sentences for deduplication (like BFS version)
- used_sentence_mask = torch.zeros(num_sentences, dtype=torch.bool, device=self.device)
-
- # Initialize seed entity scores as sparse tensor
- seed_indices = torch.tensor([[idx] for idx in seed_entity_indices], dtype=torch.long).t()
- seed_values = torch.tensor(seed_entity_scores, dtype=torch.float32)
- entity_scores_sparse = torch.sparse_coo_tensor(
- seed_indices, seed_values, (num_entities,), device=self.device
- ).coalesce()
-
- # Also maintain a dense accumulator for total scores
- entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
- entity_scores_dense.scatter_(0, torch.tensor(seed_entity_indices, device=self.device),
- torch.tensor(seed_entity_scores, dtype=torch.float32, device=self.device))
-
- # Initialize actived_entities
- actived_entities = {}
- for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(
- seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores
- ):
- actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 0)
- seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
- entity_weights[seed_entity_node_idx] = seed_entity_score
-
- current_entity_scores_sparse = entity_scores_sparse
-
- # Iterative matrix-based propagation using sparse matrices on GPU
- for iteration in range(1, self.config.max_iterations):
- # Convert sparse tensor to dense for threshold operation
- current_entity_scores_dense = current_entity_scores_sparse.to_dense()
-
- # Apply threshold to current scores
- current_entity_scores_dense = torch.where(
- current_entity_scores_dense >= self.config.iteration_threshold,
- current_entity_scores_dense,
- torch.zeros_like(current_entity_scores_dense)
- )
-
- # Get non-zero indices for sparse representation
- nonzero_mask = current_entity_scores_dense > 0
- nonzero_indices = torch.nonzero(nonzero_mask, as_tuple=False).squeeze(-1)
-
- if len(nonzero_indices) == 0:
- break
-
- # Extract non-zero values and create sparse tensor
- nonzero_values = current_entity_scores_dense[nonzero_indices]
- current_entity_scores_sparse = torch.sparse_coo_tensor(
- nonzero_indices.unsqueeze(0), nonzero_values, (num_entities,), device=self.device
- ).coalesce()
-
- # Step 1: Sparse entity scores @ Sparse E2S matrix
- # Convert sparse vector to 2D for matrix multiplication
- current_scores_2d = torch.sparse_coo_tensor(
- torch.stack([nonzero_indices, torch.zeros_like(nonzero_indices)]),
- nonzero_values,
- (num_entities, 1),
- device=self.device
- ).coalesce()
-
- # E @ E2S -> sentence activation scores (sparse @ sparse = dense)
- sentence_activation = torch.sparse.mm(
- self.entity_to_sentence_sparse.t(),
- current_scores_2d
- )
- # Convert to dense before squeeze to avoid CUDA sparse tensor issues
- if sentence_activation.is_sparse:
- sentence_activation = sentence_activation.to_dense()
- sentence_activation = sentence_activation.squeeze()
-
- # Apply sentence deduplication: mask out used sentences
- sentence_activation = torch.where(
- used_sentence_mask,
- torch.zeros_like(sentence_activation),
- sentence_activation
- )
-
- # Step 2: Apply sentence similarities (element-wise on dense vector)
- weighted_sentence_scores = sentence_activation * sentence_similarities
-
- # Implement per-entity top-k sentence selection (more aggressive pruning)
- # For vectorized efficiency, we use a tighter global approximation
- num_active = len(nonzero_indices)
- if num_active > 0 and self.config.top_k_sentence > 0:
- # Calculate adaptive k based on number of active entities
- # Use per-entity top-k approximation: num_active * top_k_sentence
- k = min(int(num_active * self.config.top_k_sentence), len(weighted_sentence_scores))
- if k > 0:
- # Get top-k sentences
- top_k_values, top_k_indices = torch.topk(weighted_sentence_scores, k)
- # Zero out all non-top-k sentences
- mask = torch.zeros_like(weighted_sentence_scores, dtype=torch.bool)
- mask[top_k_indices] = True
- weighted_sentence_scores = torch.where(
- mask,
- weighted_sentence_scores,
- torch.zeros_like(weighted_sentence_scores)
- )
-
- # Mark these sentences as used for deduplication
- used_sentence_mask[top_k_indices] = True
-
- # Step 3: Weighted sentences @ S2E -> propagate to next entities
- # Convert to sparse for more efficient computation
- weighted_nonzero_mask = weighted_sentence_scores > 0
- weighted_nonzero_indices = torch.nonzero(weighted_nonzero_mask, as_tuple=False).squeeze(-1)
-
- if len(weighted_nonzero_indices) > 0:
- weighted_nonzero_values = weighted_sentence_scores[weighted_nonzero_indices]
- weighted_scores_2d = torch.sparse_coo_tensor(
- torch.stack([weighted_nonzero_indices, torch.zeros_like(weighted_nonzero_indices)]),
- weighted_nonzero_values,
- (num_sentences, 1),
- device=self.device
- ).coalesce()
-
- next_entity_scores_result = torch.sparse.mm(
- self.sentence_to_entity_sparse.t(),
- weighted_scores_2d
- )
- # Convert to dense before squeeze to avoid CUDA sparse tensor issues
- if next_entity_scores_result.is_sparse:
- next_entity_scores_result = next_entity_scores_result.to_dense()
- next_entity_scores_dense = next_entity_scores_result.squeeze()
- else:
- next_entity_scores_dense = torch.zeros(num_entities, dtype=torch.float32, device=self.device)
-
- # Update entity scores (accumulate in dense format)
- entity_scores_dense += next_entity_scores_dense
-
- # Update actived_entities dictionary (only for entities above threshold)
- next_entity_scores_np = next_entity_scores_dense.cpu().numpy()
- active_indices = np.where(next_entity_scores_np >= self.config.iteration_threshold)[0]
- for entity_idx in active_indices:
- score = next_entity_scores_np[entity_idx]
- entity_hash_id = self.entity_hash_ids[entity_idx]
- if entity_hash_id not in actived_entities or actived_entities[entity_hash_id][1] < score:
- actived_entities[entity_hash_id] = (entity_idx, float(score), iteration)
-
- # Prepare sparse tensor for next iteration
- next_nonzero_mask = next_entity_scores_dense > 0
- next_nonzero_indices = torch.nonzero(next_nonzero_mask, as_tuple=False).squeeze(-1)
- if len(next_nonzero_indices) > 0:
- next_nonzero_values = next_entity_scores_dense[next_nonzero_indices]
- current_entity_scores_sparse = torch.sparse_coo_tensor(
- next_nonzero_indices.unsqueeze(0), next_nonzero_values,
- (num_entities,), device=self.device
- ).coalesce()
- else:
- break
-
- # Convert back to numpy for final processing
- entity_scores_final = entity_scores_dense.cpu().numpy()
-
- # Map entity scores to graph node weights (only for non-zero scores)
- nonzero_indices = np.where(entity_scores_final > 0)[0]
- for entity_idx in nonzero_indices:
- score = entity_scores_final[entity_idx]
- entity_hash_id = self.entity_hash_ids[entity_idx]
- entity_node_idx = self.node_name_to_vertex_idx[entity_hash_id]
- entity_weights[entity_node_idx] = float(score)
-
- return entity_weights, actived_entities
-
- def calculate_passage_scores(self, question_embedding, actived_entities):
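- """
- Score each passage as passage_ratio * normalized dense-retrieval score plus a
- log-scaled bonus for activated entities occurring in the passage text,
- discounted by the iteration tier at which the entity was reached, then scale
- by passage_node_weight.
- """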
- passage_weights = np.zeros(len(self.graph.vs["name"]))
- dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
- dpr_passage_scores = min_max_normalize(dpr_passage_scores)
- for i, dpr_passage_index in enumerate(dpr_passage_indices):
- total_entity_bonus = 0
- passage_hash_id = self.passage_embedding_store.hash_ids[dpr_passage_index]
- dpr_passage_score = dpr_passage_scores[i]
- passage_text_lower = self.passage_embedding_store.hash_id_to_text[passage_hash_id].lower()
- for entity_hash_id, (entity_id, entity_score, tier) in actived_entities.items():
- entity_lower = self.entity_embedding_store.hash_id_to_text[entity_hash_id].lower()
- entity_occurrences = passage_text_lower.count(entity_lower)
- if entity_occurrences > 0:
- denom = tier if tier >= 1 else 1
- entity_bonus = entity_score * math.log(1 + entity_occurrences) / denom
- total_entity_bonus += entity_bonus
- passage_score = self.config.passage_ratio * dpr_passage_score + math.log(1 + total_entity_bonus)
- passage_node_idx = self.node_name_to_vertex_idx[passage_hash_id]
- passage_weights[passage_node_idx] = passage_score * self.config.passage_node_weight
- return passage_weights
-
- def dense_passage_retrieval(self, question_embedding):
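- """
- Rank all passages by dot product between the question embedding and the
- precomputed passage embeddings.
- """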
- question_emb = question_embedding.reshape(1, -1)
- question_passage_similarities = np.dot(self.passage_embeddings, question_emb.T).flatten()
- sorted_passage_indices = np.argsort(question_passage_similarities)[::-1]
- sorted_passage_scores = question_passage_similarities[sorted_passage_indices].tolist()
- return sorted_passage_indices, sorted_passage_scores
-
- def get_seed_entities(self, question):
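- """
- Run NER on the question and map each extracted mention to its most similar
- indexed entity (by embedding dot product), keeping the similarity as the
- seed score.
- """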
- question_entities = list(self.spacy_ner.question_ner(question))
- if len(question_entities) == 0:
- return [],[],[],[]
- question_entity_embeddings = self.config.embedding_model.encode(question_entities,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
- similarities = np.dot(self.entity_embeddings, question_entity_embeddings.T)
- seed_entity_indices = []
- seed_entity_texts = []
- seed_entity_hash_ids = []
- seed_entity_scores = []
- for query_entity_idx in range(len(question_entities)):
- entity_scores = similarities[:, query_entity_idx]
- best_entity_idx = np.argmax(entity_scores)
- best_entity_score = entity_scores[best_entity_idx]
- best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
- best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
- seed_entity_indices.append(best_entity_idx)
- seed_entity_texts.append(best_entity_text)
- seed_entity_hash_ids.append(best_entity_hash_id)
- seed_entity_scores.append(best_entity_score)
- return seed_entity_indices, seed_entity_texts, seed_entity_hash_ids, seed_entity_scores
-
- def index(self, passages):
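- """
- Build the index: embed passages, run (cached) NER, embed entities and
- sentences, record entity<->sentence mappings, add passage-entity and
- adjacent-passage edges, and persist the graph as GraphML.
- """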
- self.node_to_node_stats = defaultdict(dict)
- self.entity_to_sentence_stats = defaultdict(dict)
- self.passage_embedding_store.insert_text(passages)
- hash_id_to_passage = self.passage_embedding_store.get_hash_id_to_text()
- existing_passage_hash_id_to_entities,existing_sentence_to_entities, new_passage_hash_ids = self.load_existing_data(hash_id_to_passage.keys())
- if len(new_passage_hash_ids) > 0:
- new_hash_id_to_passage = {k : hash_id_to_passage[k] for k in new_passage_hash_ids}
- new_passage_hash_id_to_entities,new_sentence_to_entities = self.spacy_ner.batch_ner(new_hash_id_to_passage, self.config.max_workers)
- self.merge_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities)
- self.save_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
- entity_nodes, sentence_nodes,passage_hash_id_to_entities,self.entity_to_sentence,self.sentence_to_entity = self.extract_nodes_and_edges(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
- self.sentence_embedding_store.insert_text(list(sentence_nodes))
- self.entity_embedding_store.insert_text(list(entity_nodes))
- self.entity_hash_id_to_sentence_hash_ids = {}
- for entity, sentences in self.entity_to_sentence.items():
- entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
- self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] = [self.sentence_embedding_store.text_to_hash_id[s] for s in sentences]
- self.sentence_hash_id_to_entity_hash_ids = {}
- for sentence, entities in self.sentence_to_entity.items():
- sentence_hash_id = self.sentence_embedding_store.text_to_hash_id[sentence]
- self.sentence_hash_id_to_entity_hash_ids[sentence_hash_id] = [self.entity_embedding_store.text_to_hash_id[e] for e in entities]
- self.add_entity_to_passage_edges(passage_hash_id_to_entities)
- self.add_adjacent_passage_edges()
- self.augment_graph()
- output_graphml_path = os.path.join(self.config.working_dir,self.dataset_name, "LinearRAG.graphml")
- os.makedirs(os.path.dirname(output_graphml_path), exist_ok=True)
- self.graph.write_graphml(output_graphml_path)
-
- def add_adjacent_passage_edges(self):
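- """
- Passages whose text starts with a numeric prefix "N:" are treated as ordered;
- each passage is linked to the next one in that order with edge weight 1.0.
- """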
- passage_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
- index_pattern = re.compile(r'^(\d+):')
- indexed_items = [
- (int(match.group(1)), node_key)
- for node_key, text in passage_id_to_text.items()
- if (match := index_pattern.match(text.strip()))
- ]
- indexed_items.sort(key=lambda x: x[0])
- for i in range(len(indexed_items) - 1):
- current_node = indexed_items[i][1]
- next_node = indexed_items[i + 1][1]
- self.node_to_node_stats[current_node][next_node] = 1.0
-
- def augment_graph(self):
- self.add_nodes()
- self.add_edges()
-
- def add_nodes(self):
- existing_nodes = {v["name"]: v for v in self.graph.vs if "name" in v.attributes()}
- entity_hash_id_to_text = self.entity_embedding_store.get_hash_id_to_text()
- passage_hash_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
- all_hash_id_to_text = {**entity_hash_id_to_text, **passage_hash_id_to_text}
-
- passage_hash_ids = set(passage_hash_id_to_text.keys())
-
- for hash_id, text in all_hash_id_to_text.items():
- if hash_id not in existing_nodes:
- self.graph.add_vertex(name=hash_id, content=text)
-
- self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
- self.passage_node_indices = [
- self.node_name_to_vertex_idx[passage_id]
- for passage_id in passage_hash_ids
- if passage_id in self.node_name_to_vertex_idx
- ]
-
- def add_edges(self):
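- """
- Deduplicate undirected edges via frozenset keys, then add all new edges in a
- single batch and assign their weights.
- """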
- existing_edges = set()
- for edge in self.graph.es:
- source_name = self.graph.vs[edge.source]["name"]
- target_name = self.graph.vs[edge.target]["name"]
- existing_edges.add(frozenset([source_name, target_name]))
-
- new_edges = []
- new_weights = []
-
- for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
- for neighbor_hash_id, weight in node_to_node_stats.items():
- if node_hash_id == neighbor_hash_id:
- continue
- edge_key = frozenset([node_hash_id, neighbor_hash_id])
- if edge_key not in existing_edges:
- new_edges.append((node_hash_id, neighbor_hash_id))
- new_weights.append(weight)
- existing_edges.add(edge_key)
-
- if new_edges:
- self.graph.add_edges(new_edges)
- start_idx = len(self.graph.es) - len(new_edges)
- for i, weight in enumerate(new_weights):
- self.graph.es[start_idx + i]['weight'] = weight
-
- def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
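- """
- Weight each passage-entity edge by the entity's occurrence count in the
- passage, normalized by the total entity occurrences in that passage.
- """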
- passage_to_entity_count = {}
- passage_to_all_score = defaultdict(int)
- for passage_hash_id, entities in passage_hash_id_to_entities.items():
- passage = self.passage_embedding_store.hash_id_to_text[passage_hash_id]
- for entity in entities:
- entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
- count = passage.count(entity)
- passage_to_entity_count[(passage_hash_id, entity_hash_id)] = count
- passage_to_all_score[passage_hash_id] += count
- for (passage_hash_id, entity_hash_id), count in passage_to_entity_count.items():
- total_count = passage_to_all_score[passage_hash_id]
- score = count / total_count if total_count > 0 else 0.0
- self.node_to_node_stats[passage_hash_id][entity_hash_id] = score
-
- def extract_nodes_and_edges(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
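- """
- Flatten the cached NER results into entity and sentence node sets, the
- passage-to-entities map, and the bidirectional entity<->sentence
- co-occurrence maps used for propagation.
- """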
- entity_nodes = set()
- sentence_nodes = set()
- passage_hash_id_to_entities = defaultdict(set)
- entity_to_sentence= defaultdict(set)
- sentence_to_entity = defaultdict(set)
- for passage_hash_id, entities in existing_passage_hash_id_to_entities.items():
- for entity in entities:
- entity_nodes.add(entity)
- passage_hash_id_to_entities[passage_hash_id].add(entity)
- for sentence,entities in existing_sentence_to_entities.items():
- sentence_nodes.add(sentence)
- for entity in entities:
- entity_to_sentence[entity].add(sentence)
- sentence_to_entity[sentence].add(entity)
- return entity_nodes, sentence_nodes, passage_hash_id_to_entities, entity_to_sentence, sentence_to_entity
-
- def merge_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities):
- existing_passage_hash_id_to_entities.update(new_passage_hash_id_to_entities)
- existing_sentence_to_entities.update(new_sentence_to_entities)
- return existing_passage_hash_id_to_entities, existing_sentence_to_entities
-
- def save_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
- with open(self.ner_results_path, "w") as f:
- json.dump({"passage_hash_id_to_entities": existing_passage_hash_id_to_entities, "sentence_to_entities": existing_sentence_to_entities}, f)