From 525bfcf021d024fae6f96ea6775074fa8b9e6c43 Mon Sep 17 00:00:00 2001 From: LuyaoZhuang Date: Mon, 27 Oct 2025 04:54:31 -0400 Subject: commit --- src/LinearRAG.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) (limited to 'src/LinearRAG.py') diff --git a/src/LinearRAG.py b/src/LinearRAG.py index fa6caaa..2978542 100644 --- a/src/LinearRAG.py +++ b/src/LinearRAG.py @@ -86,9 +86,9 @@ class LinearRAG: for question_info in tqdm(questions, desc="Retrieving"): question = question_info["question"] question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size) - seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question) + seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question) if len(seed_entities) != 0: - sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores) + sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores) final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k] final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k] final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids] @@ -108,8 +108,8 @@ class LinearRAG: - def graph_search_with_seed_entities(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores): - entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores) + def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores): + entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores) passage_weights = self.calculate_passage_scores(question_embedding,actived_entities) node_weights = entity_weights + passage_weights ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights) @@ -137,10 +137,10 @@ class LinearRAG: return sorted_passage_hash_ids, sorted_passage_scores.tolist() - def calculate_entity_scores(self,question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores): + def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores): actived_entities = {} entity_weights = np.zeros(len(self.graph.vs["name"])) - for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores): + for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores): actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1) seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id] entity_weights[seed_entity_node_idx] = seed_entity_score @@ -219,8 +219,6 @@ class LinearRAG: entity_scores = similarities[:, query_entity_idx] best_entity_idx = np.argmax(entity_scores) best_entity_score = entity_scores[best_entity_idx] - if best_entity_score < self.config.initial_threshold: - continue best_entity_hash_id = self.entity_hash_ids[best_entity_idx] best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id] seed_entity_indices.append(best_entity_idx) -- cgit v1.2.3