diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/LinearRAG.py | 15 | ||||
| -rw-r--r-- | src/config.py | 3 |
2 files changed, 7 insertions, 11 deletions
diff --git a/src/LinearRAG.py b/src/LinearRAG.py index fa6caaa..0f4f6f8 100644 --- a/src/LinearRAG.py +++ b/src/LinearRAG.py @@ -12,7 +12,6 @@ import igraph as ig import re import logging logger = logging.getLogger(__name__) - class LinearRAG: def __init__(self, global_config): self.config = global_config @@ -86,9 +85,9 @@ class LinearRAG: for question_info in tqdm(questions, desc="Retrieving"): question = question_info["question"] question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size) - seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question) + seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question) if len(seed_entities) != 0: - sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores) + sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores) final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k] final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k] final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids] @@ -108,8 +107,8 @@ class LinearRAG: - def graph_search_with_seed_entities(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores): - entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores) + def graph_search_with_seed_entities(self, question_embedding, seed_entity_indices, seed_entities, seed_entity_hash_ids, seed_entity_scores): + entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores) passage_weights = self.calculate_passage_scores(question_embedding,actived_entities) node_weights = entity_weights + passage_weights ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights) @@ -137,10 +136,10 @@ class LinearRAG: return sorted_passage_hash_ids, sorted_passage_scores.tolist() - def calculate_entity_scores(self,question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores): + def calculate_entity_scores(self,question_embedding,seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores): actived_entities = {} entity_weights = np.zeros(len(self.graph.vs["name"])) - for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores): + for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_indices,seed_entities,seed_entity_hash_ids,seed_entity_scores): actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1) seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id] entity_weights[seed_entity_node_idx] = seed_entity_score @@ -219,8 +218,6 @@ class LinearRAG: entity_scores = similarities[:, query_entity_idx] best_entity_idx = np.argmax(entity_scores) best_entity_score = entity_scores[best_entity_idx] - if best_entity_score < self.config.initial_threshold: - continue best_entity_hash_id = self.entity_hash_ids[best_entity_idx] best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id] seed_entity_indices.append(best_entity_idx) diff --git a/src/config.py b/src/config.py index 6258a65..3ef136c 100644 --- a/src/config.py +++ b/src/config.py @@ -13,9 +13,8 @@ class LinearRAGConfig: max_workers: int = 16 retrieval_top_k: int = 5 max_iterations: int = 3 - top_k_sentence: int = 2 + top_k_sentence: int = 1 passage_ratio: float = 1.5 passage_node_weight: float = 0.05 damping: float = 0.5 - initial_threshold: float = 0.5 iteration_threshold: float = 0.5
\ No newline at end of file |
