summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/LinearRAG.py351
-rw-r--r--src/config.py21
-rw-r--r--src/embedding_store.py80
-rw-r--r--src/evaluate.py102
-rw-r--r--src/ner.py46
-rw-r--r--src/utils.py77
6 files changed, 677 insertions, 0 deletions
diff --git a/src/LinearRAG.py b/src/LinearRAG.py
new file mode 100644
index 0000000..fa6caaa
--- /dev/null
+++ b/src/LinearRAG.py
@@ -0,0 +1,351 @@
+from src.embedding_store import EmbeddingStore
+from src.utils import min_max_normalize
+import os
+import json
+from collections import defaultdict
+import numpy as np
+import math
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+from src.ner import SpacyNER
+import igraph as ig
+import re
+import logging
+logger = logging.getLogger(__name__)
+
+class LinearRAG:
+ def __init__(self, global_config):
+ self.config = global_config
+ logger.info(f"Initializing LinearRAG with config: {self.config}")
+ self.dataset_name = global_config.dataset_name
+ self.load_embedding_store()
+ self.llm_model = self.config.llm_model
+ self.spacy_ner = SpacyNER(self.config.spacy_model)
+ self.graph = ig.Graph(directed=False)
+
+ def load_embedding_store(self):
+ self.passage_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "passage_embedding.parquet"), batch_size=self.config.batch_size, namespace="passage")
+ self.entity_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "entity_embedding.parquet"), batch_size=self.config.batch_size, namespace="entity")
+ self.sentence_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir,self.dataset_name, "sentence_embedding.parquet"), batch_size=self.config.batch_size, namespace="sentence")
+
+ def load_existing_data(self,passage_hash_ids):
+ self.ner_results_path = os.path.join(self.config.working_dir,self.dataset_name, "ner_results.json")
+ if os.path.exists(self.ner_results_path):
+ existing_ner_reuslts = json.load(open(self.ner_results_path))
+ existing_passage_hash_id_to_entities = existing_ner_reuslts["passage_hash_id_to_entities"]
+ existing_sentence_to_entities = existing_ner_reuslts["sentence_to_entities"]
+ existing_passage_hash_ids = set(existing_passage_hash_id_to_entities.keys())
+ new_passage_hash_ids = set(passage_hash_ids) - existing_passage_hash_ids
+ return existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids
+ else:
+ return {}, {}, passage_hash_ids
+
+ def qa(self, questions):
+ retrieval_results = self.retrieve(questions)
+ system_prompt = f"""As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. Your response start after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations."""
+ all_messages = []
+ for retrieval_result in retrieval_results:
+ question = retrieval_result["question"]
+ sorted_passage = retrieval_result["sorted_passage"]
+ prompt_user = """"""
+ for passage in sorted_passage:
+ prompt_user += f"{passage}\n"
+ prompt_user += f"Question: {question}\n Thought: "
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": prompt_user}
+ ]
+ all_messages.append(messages)
+ with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
+ all_qa_results = list(tqdm(
+ executor.map(self.llm_model.infer, all_messages),
+ total=len(all_messages),
+ desc="QA Reading (Parallel)"
+ ))
+
+ for qa_result,question_info in zip(all_qa_results,retrieval_results):
+ try:
+ pred_ans = qa_result.split('Answer:')[1].strip()
+ except:
+ pred_ans = qa_result
+ question_info["pred_answer"] = pred_ans
+ return retrieval_results
+
+
+ def retrieve(self, questions):
+ self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
+ self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
+ self.passage_hash_ids = list(self.passage_embedding_store.hash_id_to_text.keys())
+ self.passage_embeddings = np.array(self.passage_embedding_store.embeddings)
+ self.sentence_hash_ids = list(self.sentence_embedding_store.hash_id_to_text.keys())
+ self.sentence_embeddings = np.array(self.sentence_embedding_store.embeddings)
+ self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+ self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
+
+ retrieval_results = []
+ for question_info in tqdm(questions, desc="Retrieving"):
+ question = question_info["question"]
+ question_embedding = self.config.embedding_model.encode(question,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
+ seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores = self.get_seed_entities(question)
+ if len(seed_entities) != 0:
+ sorted_passage_hash_ids,sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
+ final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+ final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
+ else:
+ sorted_passage_indices,sorted_passage_scores = self.dense_passage_retrieval(question_embedding)
+ final_passage_indices = sorted_passage_indices[:self.config.retrieval_top_k]
+ final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+ final_passages = [self.passage_embedding_store.texts[idx] for idx in final_passage_indices]
+ result = {
+ "question": question,
+ "sorted_passage": final_passages,
+ "sorted_passage_scores": final_passage_scores,
+ "gold_answer": question_info["answer"]
+ }
+ retrieval_results.append(result)
+ return retrieval_results
+
+
+
+ def graph_search_with_seed_entities(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+ entity_weights, actived_entities = self.calculate_entity_scores(question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores)
+ passage_weights = self.calculate_passage_scores(question_embedding,actived_entities)
+ node_weights = entity_weights + passage_weights
+ ppr_sorted_passage_indices,ppr_sorted_passage_scores = self.run_ppr(node_weights)
+ return ppr_sorted_passage_indices,ppr_sorted_passage_scores
+
+ def run_ppr(self, node_weights):
+ reset_prob = np.where(np.isnan(node_weights) | (node_weights < 0), 0, node_weights)
+ pagerank_scores = self.graph.personalized_pagerank(
+ vertices=range(len(self.node_name_to_vertex_idx)),
+ damping=self.config.damping,
+ directed=False,
+ weights='weight',
+ reset=reset_prob,
+ implementation='prpack'
+ )
+
+ doc_scores = np.array([pagerank_scores[idx] for idx in self.passage_node_indices])
+ sorted_indices_in_doc_scores = np.argsort(doc_scores)[::-1]
+ sorted_passage_scores = doc_scores[sorted_indices_in_doc_scores]
+
+ sorted_passage_hash_ids = [
+ self.vertex_idx_to_node_name[self.passage_node_indices[i]]
+ for i in sorted_indices_in_doc_scores
+ ]
+
+ return sorted_passage_hash_ids, sorted_passage_scores.tolist()
+
+ def calculate_entity_scores(self,question_embedding,seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores):
+ actived_entities = {}
+ entity_weights = np.zeros(len(self.graph.vs["name"]))
+ for seed_entity_idx,seed_entity,seed_entity_hash_id,seed_entity_score in zip(seed_entity_idx,seed_entities,seed_entity_hash_ids,seed_entity_scores):
+ actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
+ seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
+ entity_weights[seed_entity_node_idx] = seed_entity_score
+ used_sentence_hash_ids = set()
+ current_entities = actived_entities.copy()
+ iteration = 1
+ while len(current_entities) > 0 and iteration < self.config.max_iterations:
+ new_entities = {}
+ for entity_hash_id, (entity_id, entity_score, tier) in current_entities.items():
+ if entity_score < self.config.iteration_threshold:
+ continue
+ sentence_hash_ids = [sid for sid in list(self.entity_hash_id_to_sentence_hash_ids[entity_hash_id]) if sid not in used_sentence_hash_ids]
+ if not sentence_hash_ids:
+ continue
+ sentence_indices = [self.sentence_embedding_store.hash_id_to_idx[sid] for sid in sentence_hash_ids]
+ sentence_embeddings = self.sentence_embeddings[sentence_indices]
+ question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
+ sentence_similarities = np.dot(sentence_embeddings, question_emb).flatten()
+ top_sentence_indices = np.argsort(sentence_similarities)[::-1][:self.config.top_k_sentence]
+ for top_sentence_index in top_sentence_indices:
+ top_sentence_hash_id = sentence_hash_ids[top_sentence_index]
+ top_sentence_score = sentence_similarities[top_sentence_index]
+ used_sentence_hash_ids.add(top_sentence_hash_id)
+ entity_hash_ids_in_sentence = self.sentence_hash_id_to_entity_hash_ids[top_sentence_hash_id]
+ for next_entity_hash_id in entity_hash_ids_in_sentence:
+ next_entity_score = entity_score * top_sentence_score
+ if next_entity_score < self.config.iteration_threshold:
+ continue
+ next_enitity_node_idx = self.node_name_to_vertex_idx[next_entity_hash_id]
+ entity_weights[next_enitity_node_idx] += next_entity_score
+ new_entities[next_entity_hash_id] = (next_enitity_node_idx, next_entity_score, iteration+1)
+ actived_entities.update(new_entities)
+ current_entities = new_entities.copy()
+ iteration += 1
+ return entity_weights, actived_entities
+
+ def calculate_passage_scores(self, question_embedding, actived_entities):
+ passage_weights = np.zeros(len(self.graph.vs["name"]))
+ dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
+ dpr_passage_scores = min_max_normalize(dpr_passage_scores)
+ for i, dpr_passage_index in enumerate(dpr_passage_indices):
+ total_entity_bonus = 0
+ passage_hash_id = self.passage_embedding_store.hash_ids[dpr_passage_index]
+ dpr_passage_score = dpr_passage_scores[i]
+ passage_text_lower = self.passage_embedding_store.hash_id_to_text[passage_hash_id].lower()
+ for entity_hash_id, (entity_id, entity_score, tier) in actived_entities.items():
+ entity_lower = self.entity_embedding_store.hash_id_to_text[entity_hash_id].lower()
+ entity_occurrences = passage_text_lower.count(entity_lower)
+ if entity_occurrences > 0:
+ denom = tier if tier >= 1 else 1
+ entity_bonus = entity_score * math.log(1 + entity_occurrences) / denom
+ total_entity_bonus += entity_bonus
+ passage_score = self.config.passage_ratio * dpr_passage_score + math.log(1 + total_entity_bonus)
+ passage_node_idx = self.node_name_to_vertex_idx[passage_hash_id]
+ passage_weights[passage_node_idx] = passage_score * self.config.passage_node_weight
+ return passage_weights
+
+ def dense_passage_retrieval(self, question_embedding):
+ question_emb = question_embedding.reshape(1, -1)
+ question_passage_similarities = np.dot(self.passage_embeddings, question_emb.T).flatten()
+ sorted_passage_indices = np.argsort(question_passage_similarities)[::-1]
+ sorted_passage_scores = question_passage_similarities[sorted_passage_indices].tolist()
+ return sorted_passage_indices, sorted_passage_scores
+
+ def get_seed_entities(self, question):
+ question_entities = list(self.spacy_ner.question_ner(question))
+ if len(question_entities) == 0:
+ return [],[],[],[]
+ question_entity_embeddings = self.config.embedding_model.encode(question_entities,normalize_embeddings=True,show_progress_bar=False,batch_size=self.config.batch_size)
+ similarities = np.dot(self.entity_embeddings, question_entity_embeddings.T)
+ seed_entity_indices = []
+ seed_entity_texts = []
+ seed_entity_hash_ids = []
+ seed_entity_scores = []
+ for query_entity_idx in range(len(question_entities)):
+ entity_scores = similarities[:, query_entity_idx]
+ best_entity_idx = np.argmax(entity_scores)
+ best_entity_score = entity_scores[best_entity_idx]
+ if best_entity_score < self.config.initial_threshold:
+ continue
+ best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
+ best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
+ seed_entity_indices.append(best_entity_idx)
+ seed_entity_texts.append(best_entity_text)
+ seed_entity_hash_ids.append(best_entity_hash_id)
+ seed_entity_scores.append(best_entity_score)
+ return seed_entity_indices, seed_entity_texts, seed_entity_hash_ids, seed_entity_scores
+
+ def index(self, passages):
+ self.node_to_node_stats = defaultdict(dict)
+ self.entity_to_sentence_stats = defaultdict(dict)
+ self.passage_embedding_store.insert_text(passages)
+ hash_id_to_passage = self.passage_embedding_store.get_hash_id_to_text()
+ existing_passage_hash_id_to_entities,existing_sentence_to_entities, new_passage_hash_ids = self.load_existing_data(hash_id_to_passage.keys())
+ if len(new_passage_hash_ids) > 0:
+ new_hash_id_to_passage = {k : hash_id_to_passage[k] for k in new_passage_hash_ids}
+ new_passage_hash_id_to_entities,new_sentence_to_entities = self.spacy_ner.batch_ner(new_hash_id_to_passage, self.config.max_workers)
+ self.merge_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities)
+ self.save_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+ entity_nodes, sentence_nodes,passage_hash_id_to_entities,self.entity_to_sentence,self.sentence_to_entity = self.extract_nodes_and_edges(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+ self.sentence_embedding_store.insert_text(list(sentence_nodes))
+ self.entity_embedding_store.insert_text(list(entity_nodes))
+ self.entity_hash_id_to_sentence_hash_ids = {}
+ for entity, sentence in self.entity_to_sentence.items():
+ entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+ self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] = [self.sentence_embedding_store.text_to_hash_id[s] for s in sentence]
+ self.sentence_hash_id_to_entity_hash_ids = {}
+ for sentence, entities in self.sentence_to_entity.items():
+ sentence_hash_id = self.sentence_embedding_store.text_to_hash_id[sentence]
+ self.sentence_hash_id_to_entity_hash_ids[sentence_hash_id] = [self.entity_embedding_store.text_to_hash_id[e] for e in entities]
+ self.add_entity_to_passage_edges(passage_hash_id_to_entities)
+ self.add_adjacent_passage_edges()
+ self.augment_graph()
+ output_graphml_path = os.path.join(self.config.working_dir,self.dataset_name, "LinearRAG.graphml")
+ os.makedirs(os.path.dirname(output_graphml_path), exist_ok=True)
+ self.graph.write_graphml(output_graphml_path)
+
+ def add_adjacent_passage_edges(self):
+ passage_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+ index_pattern = re.compile(r'^(\d+):')
+ indexed_items = [
+ (int(match.group(1)), node_key)
+ for node_key, text in passage_id_to_text.items()
+ if (match := index_pattern.match(text.strip()))
+ ]
+ indexed_items.sort(key=lambda x: x[0])
+ for i in range(len(indexed_items) - 1):
+ current_node = indexed_items[i][1]
+ next_node = indexed_items[i + 1][1]
+ self.node_to_node_stats[current_node][next_node] = 1.0
+
+ def augment_graph(self):
+ self.add_nodes()
+ self.add_edges()
+
+ def add_nodes(self):
+ existing_nodes = {v["name"]: v for v in self.graph.vs if "name" in v.attributes()}
+ entity_hash_id_to_text = self.entity_embedding_store.get_hash_id_to_text()
+ passage_hash_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+ all_hash_id_to_text = {**entity_hash_id_to_text, **passage_hash_id_to_text}
+
+ passage_hash_ids = set(passage_hash_id_to_text.keys())
+
+ for hash_id, text in all_hash_id_to_text.items():
+ if hash_id not in existing_nodes:
+ self.graph.add_vertex(name=hash_id, content=text)
+
+ self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+ self.passage_node_indices = [
+ self.node_name_to_vertex_idx[passage_id]
+ for passage_id in passage_hash_ids
+ if passage_id in self.node_name_to_vertex_idx
+ ]
+
+ def add_edges(self):
+ edges = []
+ weights = []
+
+ for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
+ for neighbor_hash_id, weight in node_to_node_stats.items():
+ if node_hash_id == neighbor_hash_id:
+ continue
+ edges.append((node_hash_id, neighbor_hash_id))
+ weights.append(weight)
+ self.graph.add_edges(edges)
+ self.graph.es['weight'] = weights
+
+
+
+ def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
+ passage_to_entity_count ={}
+ passage_to_all_score = defaultdict(int)
+ for passage_hash_id, entities in passage_hash_id_to_entities.items():
+ passage = self.passage_embedding_store.hash_id_to_text[passage_hash_id]
+ for entity in entities:
+ entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+ count = passage.count(entity)
+ passage_to_entity_count[(passage_hash_id, entity_hash_id)] = count
+ passage_to_all_score[passage_hash_id] += count
+ for (passage_hash_id, entity_hash_id), count in passage_to_entity_count.items():
+ score = count / passage_to_all_score[passage_hash_id]
+ self.node_to_node_stats[passage_hash_id][entity_hash_id] = score
+
+ def extract_nodes_and_edges(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
+ entity_nodes = set()
+ sentence_nodes = set()
+ passage_hash_id_to_entities = defaultdict(set)
+ entity_to_sentence= defaultdict(set)
+ sentence_to_entity = defaultdict(set)
+ for passage_hash_id, entities in existing_passage_hash_id_to_entities.items():
+ for entity in entities:
+ entity_nodes.add(entity)
+ passage_hash_id_to_entities[passage_hash_id].add(entity)
+ for sentence,entities in existing_sentence_to_entities.items():
+ sentence_nodes.add(sentence)
+ for entity in entities:
+ entity_to_sentence[entity].add(sentence)
+ sentence_to_entity[sentence].add(entity)
+ return entity_nodes, sentence_nodes, passage_hash_id_to_entities, entity_to_sentence, sentence_to_entity
+
+ def merge_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities):
+ existing_passage_hash_id_to_entities.update(new_passage_hash_id_to_entities)
+ existing_sentence_to_entities.update(new_sentence_to_entities)
+ return existing_passage_hash_id_to_entities, existing_sentence_to_entities
+
+ def save_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
+ with open(self.ner_results_path, "w") as f:
+ json.dump({"passage_hash_id_to_entities": existing_passage_hash_id_to_entities, "sentence_to_entities": existing_sentence_to_entities}, f)
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..6258a65
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from src.utils import LLM_Model
+@dataclass
+class LinearRAGConfig:
+ dataset_name: str
+ embedding_model: str = "all-mpnet-base-v2"
+ llm_model: LLM_Model = None
+ chunk_token_size: int = 1000
+ chunk_overlap_token_size: int = 100
+ spacy_model: str = "en_core_web_trf"
+ working_dir: str = "./import"
+ batch_size: int = 128
+ max_workers: int = 16
+ retrieval_top_k: int = 5
+ max_iterations: int = 3
+ top_k_sentence: int = 2
+ passage_ratio: float = 1.5
+ passage_node_weight: float = 0.05
+ damping: float = 0.5
+ initial_threshold: float = 0.5
+ iteration_threshold: float = 0.5 \ No newline at end of file
diff --git a/src/embedding_store.py b/src/embedding_store.py
new file mode 100644
index 0000000..2aff57e
--- /dev/null
+++ b/src/embedding_store.py
@@ -0,0 +1,80 @@
+from copy import deepcopy
+from src.utils import compute_mdhash_id
+import numpy as np
+import pandas as pd
+import os
+
+class EmbeddingStore:
+ def __init__(self, embedding_model, db_filename, batch_size, namespace):
+ self.embedding_model = embedding_model
+ self.db_filename = db_filename
+ self.batch_size = batch_size
+ self.namespace = namespace
+
+ self.hash_ids = []
+ self.texts = []
+ self.embeddings = []
+ self.hash_id_to_text = {}
+ self.hash_id_to_idx = {}
+ self.text_to_hash_id = {}
+
+ self._load_data()
+
+ def _load_data(self):
+ if os.path.exists(self.db_filename):
+ df = pd.read_parquet(self.db_filename)
+ self.hash_ids = df["hash_id"].values.tolist()
+ self.texts = df["text"].values.tolist()
+ self.embeddings = df["embedding"].values.tolist()
+
+ self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)}
+ self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)}
+ self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)}
+ print(f"[{self.namespace}] Loaded {len(self.hash_ids)} records from {self.db_filename}")
+
+ def insert_text(self, text_list):
+ nodes_dict = {}
+ for text in text_list:
+ nodes_dict[compute_mdhash_id(text, prefix=self.namespace + "-")] = {'content': text}
+
+ all_hash_ids = list(nodes_dict.keys())
+
+ existing = set(self.hash_ids)
+ missing_ids = [h for h in all_hash_ids if h not in existing]
+ texts_to_encode = [nodes_dict[hash_id]["content"] for hash_id in missing_ids]
+ all_embeddings = self.embedding_model.encode(texts_to_encode,normalize_embeddings=True, show_progress_bar=False,batch_size=self.batch_size)
+
+ self._upsert(missing_ids, texts_to_encode, all_embeddings)
+
+ def _upsert(self, hash_ids, texts, embeddings):
+ self.hash_ids.extend(hash_ids)
+ self.texts.extend(texts)
+ self.embeddings.extend(embeddings)
+
+ self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)}
+ self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)}
+ self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)}
+
+ self._save_data()
+
+ def _save_data(self):
+ data_to_save = pd.DataFrame({
+ "hash_id": self.hash_ids,
+ "text": self.texts,
+ "embedding": self.embeddings
+ })
+ os.makedirs(os.path.dirname(self.db_filename), exist_ok=True)
+ data_to_save.to_parquet(self.db_filename, index=False)
+
+ def get_hash_id_to_text(self):
+ return deepcopy(self.hash_id_to_text)
+
+ def encode_texts(self, texts):
+ return self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False, batch_size=self.batch_size)
+
+ def get_embeddings(self, hash_ids):
+ if not hash_ids:
+ return np.array([])
+ indices = np.array([self.hash_id_to_idx[h] for h in hash_ids], dtype=np.intp)
+ embeddings = np.array(self.embeddings)[indices]
+ return embeddings \ No newline at end of file
diff --git a/src/evaluate.py b/src/evaluate.py
new file mode 100644
index 0000000..32e652c
--- /dev/null
+++ b/src/evaluate.py
@@ -0,0 +1,102 @@
+import json
+import os
+from src.utils import normalize_answer
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+import logging
+logger = logging.getLogger(__name__)
+
+class Evaluator:
+ def __init__(self, llm_model, predictions_path):
+ self.llm_model = llm_model
+ self.predictions_path = predictions_path
+ self.prediction_results = self.load_predictions()
+
+ def load_predictions(self):
+ prediction_results = json.load(open(self.predictions_path))
+ return prediction_results
+
+ def calculate_llm_accuracy(self,pre_answer,gold_ans):
+ system_prompt = """You are an expert evaluator.
+ """
+ user_prompt = f"""Please evaluate if the generated answer is correct by comparing it with the gold answer.
+ Generated answer: {pre_answer}
+ Gold answer: {gold_ans}
+
+ The generated answer should be considered correct if it:
+ 1. Contains the key information from the gold answer
+ 2. Is factually accurate and consistent with the gold answer
+ 3. Does not contain any contradicting information
+
+ Respond with ONLY 'correct' or 'incorrect'.
+ Response:
+ """
+ response = self.llm_model.infer([{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}])
+ if response.strip().lower() == "correct":
+ return 1.0
+ else:
+ return 0.0
+
+ def calculate_contain(self,pre_answers,gold_ans):
+ if pre_answers is None or pre_answers == "" or (isinstance(pre_answers, str) and pre_answers.strip() == ""):
+ return 0
+ if gold_ans is None or gold_ans == "" or (isinstance(gold_ans, str) and gold_ans.strip() == ""):
+ return 0
+ s1 = normalize_answer(pre_answers)
+ s2 = normalize_answer(gold_ans)
+ if s2 in s1:
+ return 1
+ else:
+ return 0
+ def evaluate_sig_sample(self,idx,prediction):
+ pre_answer = prediction["pred_answer"]
+ gold_ans = prediction["gold_answer"]
+ # llm_acc = 0.0
+ llm_acc = self.calculate_llm_accuracy(pre_answer, gold_ans)
+ contain_acc = self.calculate_contain(pre_answer, gold_ans)
+ return idx, llm_acc, contain_acc
+
+ def evaluate(self,max_workers):
+ llm_scores = [0.0] * len(self.prediction_results)
+ contain_scores = [0.0] * len(self.prediction_results)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {
+ executor.submit(self.evaluate_sig_sample, idx, pred): idx
+ for idx, pred in enumerate(self.prediction_results)
+ }
+
+ completed = 0
+ total_llm_score = 0.0
+ total_contain_score = 0.0
+ pbar = tqdm(total=len(futures), desc="Evaluating samples", unit="sample")
+ for future in as_completed(futures):
+ idx, llm_acc, contain_acc = future.result()
+ llm_scores[idx] = llm_acc
+ contain_scores[idx] = contain_acc
+ self.prediction_results[idx]["llm_accuracy"] = llm_acc
+ self.prediction_results[idx]["contain_accuracy"] = contain_acc
+ total_llm_score += llm_acc
+ total_contain_score += contain_acc
+ completed += 1
+ current_llm_acc = total_llm_score / completed
+ current_contain_acc = total_contain_score / completed
+ pbar.set_postfix({
+ 'LLM_Acc': f'{current_llm_acc:.3f}',
+ 'Contain_Acc': f'{current_contain_acc:.3f}'
+ })
+ pbar.update(1)
+ pbar.close()
+
+ llm_accuracy = sum(llm_scores) / len(llm_scores)
+ contain_accuracy = sum(contain_scores) / len(contain_scores)
+
+ logger.info(f"Evaluation Results:")
+ logger.info(f" LLM Accuracy: {llm_accuracy:.4f} ({sum(llm_scores)}/{len(llm_scores)})")
+ logger.info(f" Contain Accuracy: {contain_accuracy:.4f} ({sum(contain_scores)}/{len(contain_scores)})")
+ with open(self.predictions_path, "w", encoding="utf-8") as f:
+ json.dump(self.prediction_results, f, ensure_ascii=False, indent=4)
+
+ with open(os.path.join(os.path.dirname(self.predictions_path), "evaluation_results.json"), "w", encoding="utf-8") as f:
+ json.dump({"llm_accuracy": llm_accuracy, "contain_accuracy": contain_accuracy}, f, ensure_ascii=False, indent=4)
+ return llm_accuracy, contain_accuracy \ No newline at end of file
diff --git a/src/ner.py b/src/ner.py
new file mode 100644
index 0000000..2ca4afb
--- /dev/null
+++ b/src/ner.py
@@ -0,0 +1,46 @@
+import spacy
+from collections import defaultdict
+
+class SpacyNER:
+ def __init__(self,spacy_model):
+ self.spacy_model = spacy.load(spacy_model)
+
+ def batch_ner(self, hash_id_to_passage, max_workers):
+ passage_list = list(hash_id_to_passage.values())
+ batch_size = len(passage_list) // max_workers
+ docs_list = self.spacy_model.pipe(passage_list,batch_size=batch_size)
+ passage_hash_id_to_entities = {}
+ sentence_to_entities = defaultdict(list)
+ for idx,doc in enumerate(docs_list):
+ passage_hash_id = list(hash_id_to_passage.keys())[idx]
+ single_passage_hash_id_to_entities,single_sentence_to_entities = self.extract_entities_sentences(doc,passage_hash_id)
+ passage_hash_id_to_entities.update(single_passage_hash_id_to_entities)
+ for sent, ents in single_sentence_to_entities.items():
+ for e in ents:
+ if e not in sentence_to_entities[sent]:
+ sentence_to_entities[sent].append(e)
+ return passage_hash_id_to_entities,sentence_to_entities
+
+ def extract_entities_sentences(self, doc,passage_hash_id):
+ sentence_to_entities = defaultdict(list)
+ unique_entities = set()
+ passage_hash_id_to_entities = {}
+ for ent in doc.ents:
+ if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+ continue
+ sent_text = ent.sent.text
+ ent_text = ent.text
+ if ent_text not in sentence_to_entities[sent_text]:
+ sentence_to_entities[sent_text].append(ent_text)
+ unique_entities.add(ent_text)
+ passage_hash_id_to_entities[passage_hash_id] = list(unique_entities)
+ return passage_hash_id_to_entities,sentence_to_entities
+
+ def question_ner(self, question: str):
+ doc = self.spacy_model(question)
+ question_entities = set()
+ for ent in doc.ents:
+ if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+ continue
+ question_entities.add(ent.text.lower())
+ return question_entities \ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..50ecc34
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,77 @@
+from hashlib import md5
+from dataclasses import dataclass, field
+from typing import List, Dict
+import httpx
+from openai import OpenAI
+from collections import defaultdict
+import multiprocessing as mp
+import re
+import string
+import logging
+import numpy as np
+import os
+
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+ return prefix + md5(content.encode()).hexdigest()
+
+class LLM_Model:
+ def __init__(self, llm_model):
+ http_client = httpx.Client(timeout=60.0, trust_env=False)
+ self.openai_client = OpenAI(
+ api_key=os.getenv("OPENAI_API_KEY"),
+ base_url=os.getenv("OPENAI_BASE_URL"),
+ http_client=http_client
+ )
+ self.llm_config = {
+ "model": llm_model,
+ "max_tokens": 2000,
+ "temperature": 0,
+ }
+ def infer(self, messages):
+ response = self.openai_client.chat.completions.create(**self.llm_config,messages=messages)
+ return response.choices[0].message.content
+
+
+
+def normalize_answer(s):
+ if s is None:
+ return ""
+ if not isinstance(s, str):
+ s = str(s)
+ def remove_articles(text):
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+ def white_space_fix(text):
+ return " ".join(text.split())
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+ def lower(text):
+ return text.lower()
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+def setup_logging(log_file):
+ log_format = '%(asctime)s - %(levelname)s - %(message)s'
+ handlers = [logging.StreamHandler()]
+ os.makedirs(os.path.dirname(log_file), exist_ok=True)
+ handlers.append(logging.FileHandler(log_file, mode='a', encoding='utf-8'))
+ logging.basicConfig(
+ level=logging.INFO,
+ format=log_format,
+ handlers=handlers,
+ force=True
+ )
+ # Suppress noisy HTTP request logs (e.g., 401 Unauthorized) from httpx/openai
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
+ logging.getLogger("openai").setLevel(logging.WARNING)
+
+def min_max_normalize(x):
+ min_val = np.min(x)
+ max_val = np.max(x)
+ range_val = max_val - min_val
+
+ # Handle the case where all values are the same (range is zero)
+ if range_val == 0:
+ return np.ones_like(x) # Return an array of ones with the same shape as x
+
+ return (x - min_val) / range_val