summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/LinearRAG.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/src/LinearRAG.py b/src/LinearRAG.py
index fa6caaa..2978542 100644
--- a/src/LinearRAG.py
+++ b/src/LinearRAG.py
@@ -86,9 +86,9 @@ class LinearRAG:
for question_info in tqdm(questions, desc="Retrieving"):
question = question_info["question"]
question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
- seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question)
+ seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question)
if len(seed_entities) != 0:
- sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
@@ -108,8 +108,8 @@ class LinearRAG:
- def graph_search_with_seed_entities(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores):
- entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+ entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores)
passage_weights = self.calculate_passage_scores(question_embedding,actived_entities)
node_weights = entity_weights + passage_weights
ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights)
@@ -137,10 +137,10 @@ class LinearRAG:
return sorted_passage_hash_ids, sorted_passage_scores.tolist()
- def calculate_entity_scores(self,question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores):
+ def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
actived_entities = {}
entity_weights = np.zeros(len(self.graph.vs["name"]))
- for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores):
+ for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores):
actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
entity_weights[seed_entity_node_idx] = seed_entity_score
@@ -219,8 +219,6 @@ class LinearRAG:
entity_scores = similarities[:, query_entity_idx]
best_entity_idx = np.argmax(entity_scores)
best_entity_score = entity_scores[best_entity_idx]
- if best_entity_score < self.config.initial_threshold:
- continue
best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
seed_entity_indices.append(best_entity_idx)