From ccff87c15263d1d63235643d54322b991366952e Mon Sep 17 00:00:00 2001
From: LuyaoZhuang
Date: Sun, 26 Oct 2025 05:09:57 -0400
Subject: commit

---
 src/LinearRAG.py       | 351 +++++++++++++++++++++++++++++++++++++++++++++++++
 src/config.py          |  21 +++
 src/embedding_store.py |  80 +++++++++++
 src/evaluate.py        | 102 ++++++++++++++
 src/ner.py             |  46 +++++++
 src/utils.py           |  77 +++++++++++
 6 files changed, 677 insertions(+)
 create mode 100644 src/LinearRAG.py
 create mode 100644 src/config.py
 create mode 100644 src/embedding_store.py
 create mode 100644 src/evaluate.py
 create mode 100644 src/ner.py
 create mode 100644 src/utils.py

diff --git a/src/LinearRAG.py b/src/LinearRAG.py
new file mode 100644
index 0000000..fa6caaa
--- /dev/null
+++ b/src/LinearRAG.py
@@ -0,0 +1,351 @@
+from src.embedding_store import EmbeddingStore
+from src.utils import min_max_normalize
+import os
+import json
+from collections import defaultdict
+import numpy as np
+import math
+from concurrent.futures import ThreadPoolExecutor
+from tqdm import tqdm
+from src.ner import SpacyNER
+import igraph as ig
+import re
+import logging
+logger = logging.getLogger(__name__)
+
+class LinearRAG:
+    def __init__(self, global_config):
+        self.config = global_config
+        logger.info(f"Initializing LinearRAG with config: {self.config}")
+        self.dataset_name = global_config.dataset_name
+        self.load_embedding_store()
+        self.llm_model = self.config.llm_model
+        self.spacy_ner = SpacyNER(self.config.spacy_model)
+        self.graph = ig.Graph(directed=False)
+
+    def load_embedding_store(self):
+        self.passage_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "passage_embedding.parquet"), batch_size=self.config.batch_size, namespace="passage")
+        self.entity_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "entity_embedding.parquet"), batch_size=self.config.batch_size, namespace="entity")
+        self.sentence_embedding_store = EmbeddingStore(self.config.embedding_model, db_filename=os.path.join(self.config.working_dir, self.dataset_name, "sentence_embedding.parquet"), batch_size=self.config.batch_size, namespace="sentence")
+
+    def load_existing_data(self, passage_hash_ids):
+        self.ner_results_path = os.path.join(self.config.working_dir, self.dataset_name, "ner_results.json")
+        if os.path.exists(self.ner_results_path):
+            existing_ner_results = json.load(open(self.ner_results_path))
+            existing_passage_hash_id_to_entities = existing_ner_results["passage_hash_id_to_entities"]
+            existing_sentence_to_entities = existing_ner_results["sentence_to_entities"]
+            existing_passage_hash_ids = set(existing_passage_hash_id_to_entities.keys())
+            new_passage_hash_ids = set(passage_hash_ids) - existing_passage_hash_ids
+            return existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids
+        else:
+            return {}, {}, passage_hash_ids
+
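For reference, the cache consulted here is the JSON file written by save_ner_results() further down: one object with two maps, keyed by passage hash ids and by raw sentence text respectively. A minimal sketch of its shape, with hypothetical ids and entity values:

# Hypothetical contents of <working_dir>/<dataset_name>/ner_results.json (shape only).
ner_cache = {
    "passage_hash_id_to_entities": {
        "passage-9a1b2c3d...": ["Marie Curie", "Warsaw"],
    },
    "sentence_to_entities": {
        "Marie Curie was born in Warsaw.": ["Marie Curie", "Warsaw"],
    },
}
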
+    def qa(self, questions):
+        retrieval_results = self.retrieve(questions)
+        system_prompt = f"""As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions.
+Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations."""
+        all_messages = []
+        for retrieval_result in retrieval_results:
+            question = retrieval_result["question"]
+            sorted_passage = retrieval_result["sorted_passage"]
+            prompt_user = """"""
+            for passage in sorted_passage:
+                prompt_user += f"{passage}\n"
+            prompt_user += f"Question: {question}\n Thought: "
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt_user}
+            ]
+            all_messages.append(messages)
+        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
+            all_qa_results = list(tqdm(
+                executor.map(self.llm_model.infer, all_messages),
+                total=len(all_messages),
+                desc="QA Reading (Parallel)"
+            ))
+
+        for qa_result, question_info in zip(all_qa_results, retrieval_results):
+            try:
+                pred_ans = qa_result.split('Answer:')[1].strip()
+            except Exception:
+                pred_ans = qa_result
+            question_info["pred_answer"] = pred_ans
+        return retrieval_results
+
+    def retrieve(self, questions):
+        self.entity_hash_ids = list(self.entity_embedding_store.hash_id_to_text.keys())
+        self.entity_embeddings = np.array(self.entity_embedding_store.embeddings)
+        self.passage_hash_ids = list(self.passage_embedding_store.hash_id_to_text.keys())
+        self.passage_embeddings = np.array(self.passage_embedding_store.embeddings)
+        self.sentence_hash_ids = list(self.sentence_embedding_store.hash_id_to_text.keys())
+        self.sentence_embeddings = np.array(self.sentence_embedding_store.embeddings)
+        self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+        self.vertex_idx_to_node_name = {v.index: v["name"] for v in self.graph.vs if "name" in v.attributes()}
+
+        retrieval_results = []
+        for question_info in tqdm(questions, desc="Retrieving"):
+            question = question_info["question"]
+            question_embedding = self.config.embedding_model.encode(question, normalize_embeddings=True, show_progress_bar=False, batch_size=self.config.batch_size)
+            seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores = self.get_seed_entities(question)
+            if len(seed_entities) != 0:
+                sorted_passage_hash_ids, sorted_passage_scores = self.graph_search_with_seed_entities(question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores)
+                final_passage_hash_ids = sorted_passage_hash_ids[:self.config.retrieval_top_k]
+                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+                final_passages = [self.passage_embedding_store.hash_id_to_text[passage_hash_id] for passage_hash_id in final_passage_hash_ids]
+            else:
+                sorted_passage_indices, sorted_passage_scores = self.dense_passage_retrieval(question_embedding)
+                final_passage_indices = sorted_passage_indices[:self.config.retrieval_top_k]
+                final_passage_scores = sorted_passage_scores[:self.config.retrieval_top_k]
+                final_passages = [self.passage_embedding_store.texts[idx] for idx in final_passage_indices]
+            result = {
+                "question": question,
+                "sorted_passage": final_passages,
+                "sorted_passage_scores": final_passage_scores,
+                "gold_answer": question_info["answer"]
+            }
+            retrieval_results.append(result)
+        return retrieval_results
+
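When no seed entity survives the similarity threshold, retrieve() falls back to plain dense retrieval. Because every text is encoded with normalize_embeddings=True, a dot product against the passage matrix is already the cosine similarity, and ranking is a single argsort. A self-contained numpy sketch of that step with toy vectors (not the real model):

import numpy as np

# Toy stand-ins for the stored passage matrix and an encoded question,
# both L2-normalized as in EmbeddingStore / retrieve().
passage_embeddings = np.array([[0.6, 0.8], [1.0, 0.0], [0.0, 1.0]])
question_embedding = np.array([0.8, 0.6])

scores = passage_embeddings @ question_embedding   # cosine similarity per passage
ranking = np.argsort(scores)[::-1]                 # best passage first
print(ranking)                                     # [0 1 2]
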
+    def graph_search_with_seed_entities(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+        entity_weights, actived_entities = self.calculate_entity_scores(question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores)
+        passage_weights = self.calculate_passage_scores(question_embedding, actived_entities)
+        node_weights = entity_weights + passage_weights
+        ppr_sorted_passage_indices, ppr_sorted_passage_scores = self.run_ppr(node_weights)
+        return ppr_sorted_passage_indices, ppr_sorted_passage_scores
+
+    def run_ppr(self, node_weights):
+        reset_prob = np.where(np.isnan(node_weights) | (node_weights < 0), 0, node_weights)
+        pagerank_scores = self.graph.personalized_pagerank(
+            vertices=range(len(self.node_name_to_vertex_idx)),
+            damping=self.config.damping,
+            directed=False,
+            weights='weight',
+            reset=reset_prob,
+            implementation='prpack'
+        )
+
+        doc_scores = np.array([pagerank_scores[idx] for idx in self.passage_node_indices])
+        sorted_indices_in_doc_scores = np.argsort(doc_scores)[::-1]
+        sorted_passage_scores = doc_scores[sorted_indices_in_doc_scores]
+
+        sorted_passage_hash_ids = [
+            self.vertex_idx_to_node_name[self.passage_node_indices[i]]
+            for i in sorted_indices_in_doc_scores
+        ]
+
+        return sorted_passage_hash_ids, sorted_passage_scores.tolist()
+
+    def calculate_entity_scores(self, question_embedding, seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+        actived_entities = {}
+        entity_weights = np.zeros(len(self.graph.vs["name"]))
+        for seed_entity_idx, seed_entity, seed_entity_hash_id, seed_entity_score in zip(seed_entity_idx, seed_entities, seed_entity_hash_ids, seed_entity_scores):
+            actived_entities[seed_entity_hash_id] = (seed_entity_idx, seed_entity_score, 1)
+            seed_entity_node_idx = self.node_name_to_vertex_idx[seed_entity_hash_id]
+            entity_weights[seed_entity_node_idx] = seed_entity_score
+        used_sentence_hash_ids = set()
+        current_entities = actived_entities.copy()
+        iteration = 1
+        while len(current_entities) > 0 and iteration < self.config.max_iterations:
+            new_entities = {}
+            for entity_hash_id, (entity_id, entity_score, tier) in current_entities.items():
+                if entity_score < self.config.iteration_threshold:
+                    continue
+                sentence_hash_ids = [sid for sid in list(self.entity_hash_id_to_sentence_hash_ids[entity_hash_id]) if sid not in used_sentence_hash_ids]
+                if not sentence_hash_ids:
+                    continue
+                sentence_indices = [self.sentence_embedding_store.hash_id_to_idx[sid] for sid in sentence_hash_ids]
+                sentence_embeddings = self.sentence_embeddings[sentence_indices]
+                question_emb = question_embedding.reshape(-1, 1) if len(question_embedding.shape) == 1 else question_embedding
+                sentence_similarities = np.dot(sentence_embeddings, question_emb).flatten()
+                top_sentence_indices = np.argsort(sentence_similarities)[::-1][:self.config.top_k_sentence]
+                for top_sentence_index in top_sentence_indices:
+                    top_sentence_hash_id = sentence_hash_ids[top_sentence_index]
+                    top_sentence_score = sentence_similarities[top_sentence_index]
+                    used_sentence_hash_ids.add(top_sentence_hash_id)
+                    entity_hash_ids_in_sentence = self.sentence_hash_id_to_entity_hash_ids[top_sentence_hash_id]
+                    for next_entity_hash_id in entity_hash_ids_in_sentence:
+                        next_entity_score = entity_score * top_sentence_score
+                        if next_entity_score < self.config.iteration_threshold:
+                            continue
+                        next_entity_node_idx = self.node_name_to_vertex_idx[next_entity_hash_id]
+                        entity_weights[next_entity_node_idx] += next_entity_score
+                        new_entities[next_entity_hash_id] = (next_entity_node_idx, next_entity_score, iteration + 1)
+            actived_entities.update(new_entities)
+            current_entities = new_entities.copy()
+            iteration += 1
+        return entity_weights, actived_entities
+
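The loop above spreads activation multiplicatively: an entity with score s reaches a co-mentioned entity with score s * sim(sentence, question), and anything below iteration_threshold (0.5 by default) is pruned, so chains fade out within max_iterations hops. A small numeric illustration under those defaults, with hypothetical similarities:

iteration_threshold = 0.5     # default in LinearRAGConfig

seed_score = 0.82             # hypothetical seed-entity similarity
hop1 = seed_score * 0.70      # bridging sentence similarity 0.70 -> 0.574, stays active
hop2 = hop1 * 0.70            # about 0.40, falls below the threshold and is dropped
print(hop1 >= iteration_threshold, hop2 >= iteration_threshold)   # True False
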
+    def calculate_passage_scores(self, question_embedding, actived_entities):
+        passage_weights = np.zeros(len(self.graph.vs["name"]))
+        dpr_passage_indices, dpr_passage_scores = self.dense_passage_retrieval(question_embedding)
+        dpr_passage_scores = min_max_normalize(dpr_passage_scores)
+        for i, dpr_passage_index in enumerate(dpr_passage_indices):
+            total_entity_bonus = 0
+            passage_hash_id = self.passage_embedding_store.hash_ids[dpr_passage_index]
+            dpr_passage_score = dpr_passage_scores[i]
+            passage_text_lower = self.passage_embedding_store.hash_id_to_text[passage_hash_id].lower()
+            for entity_hash_id, (entity_id, entity_score, tier) in actived_entities.items():
+                entity_lower = self.entity_embedding_store.hash_id_to_text[entity_hash_id].lower()
+                entity_occurrences = passage_text_lower.count(entity_lower)
+                if entity_occurrences > 0:
+                    denom = tier if tier >= 1 else 1
+                    entity_bonus = entity_score * math.log(1 + entity_occurrences) / denom
+                    total_entity_bonus += entity_bonus
+            passage_score = self.config.passage_ratio * dpr_passage_score + math.log(1 + total_entity_bonus)
+            passage_node_idx = self.node_name_to_vertex_idx[passage_hash_id]
+            passage_weights[passage_node_idx] = passage_score * self.config.passage_node_weight
+        return passage_weights
+
+    def dense_passage_retrieval(self, question_embedding):
+        question_emb = question_embedding.reshape(1, -1)
+        question_passage_similarities = np.dot(self.passage_embeddings, question_emb.T).flatten()
+        sorted_passage_indices = np.argsort(question_passage_similarities)[::-1]
+        sorted_passage_scores = question_passage_similarities[sorted_passage_indices].tolist()
+        return sorted_passage_indices, sorted_passage_scores
+
+    def get_seed_entities(self, question):
+        question_entities = list(self.spacy_ner.question_ner(question))
+        if len(question_entities) == 0:
+            return [], [], [], []
+        question_entity_embeddings = self.config.embedding_model.encode(question_entities, normalize_embeddings=True, show_progress_bar=False, batch_size=self.config.batch_size)
+        similarities = np.dot(self.entity_embeddings, question_entity_embeddings.T)
+        seed_entity_indices = []
+        seed_entity_texts = []
+        seed_entity_hash_ids = []
+        seed_entity_scores = []
+        for query_entity_idx in range(len(question_entities)):
+            entity_scores = similarities[:, query_entity_idx]
+            best_entity_idx = np.argmax(entity_scores)
+            best_entity_score = entity_scores[best_entity_idx]
+            if best_entity_score < self.config.initial_threshold:
+                continue
+            best_entity_hash_id = self.entity_hash_ids[best_entity_idx]
+            best_entity_text = self.entity_embedding_store.hash_id_to_text[best_entity_hash_id]
+            seed_entity_indices.append(best_entity_idx)
+            seed_entity_texts.append(best_entity_text)
+            seed_entity_hash_ids.append(best_entity_hash_id)
+            seed_entity_scores.append(best_entity_score)
+        return seed_entity_indices, seed_entity_texts, seed_entity_hash_ids, seed_entity_scores
+
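calculate_passage_scores() above mixes the min-max-normalized dense score with a log-damped bonus for every activated entity a passage mentions, divided by the tier (hop) at which that entity was reached. A worked single-passage example using the defaults passage_ratio=1.5 and passage_node_weight=0.05 and made-up inputs:

import math

passage_ratio = 1.5
passage_node_weight = 0.05

dpr_score = 0.9                            # normalized dense score (hypothetical)
entity_bonus = 0.8 * math.log(1 + 2) / 1   # entity score 0.8, 2 mentions, found at tier 1
passage_score = passage_ratio * dpr_score + math.log(1 + entity_bonus)
reset_weight = passage_score * passage_node_weight
print(round(entity_bonus, 3), round(passage_score, 3), round(reset_weight, 3))
# 0.879 1.981 0.099
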
+    def index(self, passages):
+        self.node_to_node_stats = defaultdict(dict)
+        self.entity_to_sentence_stats = defaultdict(dict)
+        self.passage_embedding_store.insert_text(passages)
+        hash_id_to_passage = self.passage_embedding_store.get_hash_id_to_text()
+        existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_ids = self.load_existing_data(hash_id_to_passage.keys())
+        if len(new_passage_hash_ids) > 0:
+            new_hash_id_to_passage = {k: hash_id_to_passage[k] for k in new_passage_hash_ids}
+            new_passage_hash_id_to_entities, new_sentence_to_entities = self.spacy_ner.batch_ner(new_hash_id_to_passage, self.config.max_workers)
+            self.merge_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities)
+            self.save_ner_results(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+        entity_nodes, sentence_nodes, passage_hash_id_to_entities, self.entity_to_sentence, self.sentence_to_entity = self.extract_nodes_and_edges(existing_passage_hash_id_to_entities, existing_sentence_to_entities)
+        self.sentence_embedding_store.insert_text(list(sentence_nodes))
+        self.entity_embedding_store.insert_text(list(entity_nodes))
+        self.entity_hash_id_to_sentence_hash_ids = {}
+        for entity, sentence in self.entity_to_sentence.items():
+            entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+            self.entity_hash_id_to_sentence_hash_ids[entity_hash_id] = [self.sentence_embedding_store.text_to_hash_id[s] for s in sentence]
+        self.sentence_hash_id_to_entity_hash_ids = {}
+        for sentence, entities in self.sentence_to_entity.items():
+            sentence_hash_id = self.sentence_embedding_store.text_to_hash_id[sentence]
+            self.sentence_hash_id_to_entity_hash_ids[sentence_hash_id] = [self.entity_embedding_store.text_to_hash_id[e] for e in entities]
+        self.add_entity_to_passage_edges(passage_hash_id_to_entities)
+        self.add_adjacent_passage_edges()
+        self.augment_graph()
+        output_graphml_path = os.path.join(self.config.working_dir, self.dataset_name, "LinearRAG.graphml")
+        os.makedirs(os.path.dirname(output_graphml_path), exist_ok=True)
+        self.graph.write_graphml(output_graphml_path)
+
+    def add_adjacent_passage_edges(self):
+        passage_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+        index_pattern = re.compile(r'^(\d+):')
+        indexed_items = [
+            (int(match.group(1)), node_key)
+            for node_key, text in passage_id_to_text.items()
+            if (match := index_pattern.match(text.strip()))
+        ]
+        indexed_items.sort(key=lambda x: x[0])
+        for i in range(len(indexed_items) - 1):
+            current_node = indexed_items[i][1]
+            next_node = indexed_items[i + 1][1]
+            self.node_to_node_stats[current_node][next_node] = 1.0
+
+    def augment_graph(self):
+        self.add_nodes()
+        self.add_edges()
+
+    def add_nodes(self):
+        existing_nodes = {v["name"]: v for v in self.graph.vs if "name" in v.attributes()}
+        entity_hash_id_to_text = self.entity_embedding_store.get_hash_id_to_text()
+        passage_hash_id_to_text = self.passage_embedding_store.get_hash_id_to_text()
+        all_hash_id_to_text = {**entity_hash_id_to_text, **passage_hash_id_to_text}
+
+        passage_hash_ids = set(passage_hash_id_to_text.keys())
+
+        for hash_id, text in all_hash_id_to_text.items():
+            if hash_id not in existing_nodes:
+                self.graph.add_vertex(name=hash_id, content=text)
+
+        self.node_name_to_vertex_idx = {v["name"]: v.index for v in self.graph.vs if "name" in v.attributes()}
+        self.passage_node_indices = [
+            self.node_name_to_vertex_idx[passage_id]
+            for passage_id in passage_hash_ids
+            if passage_id in self.node_name_to_vertex_idx
+        ]
+
+    def add_edges(self):
+        edges = []
+        weights = []
+
+        for node_hash_id, node_to_node_stats in self.node_to_node_stats.items():
+            for neighbor_hash_id, weight in node_to_node_stats.items():
+                if node_hash_id == neighbor_hash_id:
+                    continue
+                edges.append((node_hash_id, neighbor_hash_id))
+                weights.append(weight)
+        self.graph.add_edges(edges)
+        self.graph.es['weight'] = weights
+
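The accumulated node weights become the reset (teleport) vector of a personalized PageRank over this undirected, weighted graph. Note also that add_adjacent_passage_edges() assumes passage texts start with an integer prefix such as "0:"; passages without it simply get no adjacency edges. A minimal igraph sketch of the PageRank call on a toy three-node graph, with made-up weights:

import numpy as np
import igraph as ig

g = ig.Graph(directed=False)
g.add_vertices(["entity-a", "passage-1", "passage-2"])
g.add_edges([("entity-a", "passage-1"), ("passage-1", "passage-2")])
g.es["weight"] = [0.7, 1.0]

reset = np.array([0.9, 0.2, 0.0])   # hypothetical node weights, biased toward the entity
scores = g.personalized_pagerank(damping=0.5, directed=False, weights="weight",
                                 reset=reset, implementation="prpack")
print(scores)   # passages near the activated entity receive more probability mass
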
+    def add_entity_to_passage_edges(self, passage_hash_id_to_entities):
+        passage_to_entity_count = {}
+        passage_to_all_score = defaultdict(int)
+        for passage_hash_id, entities in passage_hash_id_to_entities.items():
+            passage = self.passage_embedding_store.hash_id_to_text[passage_hash_id]
+            for entity in entities:
+                entity_hash_id = self.entity_embedding_store.text_to_hash_id[entity]
+                count = passage.count(entity)
+                passage_to_entity_count[(passage_hash_id, entity_hash_id)] = count
+                passage_to_all_score[passage_hash_id] += count
+        for (passage_hash_id, entity_hash_id), count in passage_to_entity_count.items():
+            score = count / passage_to_all_score[passage_hash_id]
+            self.node_to_node_stats[passage_hash_id][entity_hash_id] = score
+
+    def extract_nodes_and_edges(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
+        entity_nodes = set()
+        sentence_nodes = set()
+        passage_hash_id_to_entities = defaultdict(set)
+        entity_to_sentence = defaultdict(set)
+        sentence_to_entity = defaultdict(set)
+        for passage_hash_id, entities in existing_passage_hash_id_to_entities.items():
+            for entity in entities:
+                entity_nodes.add(entity)
+                passage_hash_id_to_entities[passage_hash_id].add(entity)
+        for sentence, entities in existing_sentence_to_entities.items():
+            sentence_nodes.add(sentence)
+            for entity in entities:
+                entity_to_sentence[entity].add(sentence)
+                sentence_to_entity[sentence].add(entity)
+        return entity_nodes, sentence_nodes, passage_hash_id_to_entities, entity_to_sentence, sentence_to_entity
+
+    def merge_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities, new_passage_hash_id_to_entities, new_sentence_to_entities):
+        existing_passage_hash_id_to_entities.update(new_passage_hash_id_to_entities)
+        existing_sentence_to_entities.update(new_sentence_to_entities)
+        return existing_passage_hash_id_to_entities, existing_sentence_to_entities
+
+    def save_ner_results(self, existing_passage_hash_id_to_entities, existing_sentence_to_entities):
+        with open(self.ner_results_path, "w") as f:
+            json.dump({"passage_hash_id_to_entities": existing_passage_hash_id_to_entities, "sentence_to_entities": existing_sentence_to_entities}, f)
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..6258a65
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from src.utils import LLM_Model
+@dataclass
+class LinearRAGConfig:
+    dataset_name: str
+    embedding_model: str = "all-mpnet-base-v2"
+    llm_model: LLM_Model = None
+    chunk_token_size: int = 1000
+    chunk_overlap_token_size: int = 100
+    spacy_model: str = "en_core_web_trf"
+    working_dir: str = "./import"
+    batch_size: int = 128
+    max_workers: int = 16
+    retrieval_top_k: int = 5
+    max_iterations: int = 3
+    top_k_sentence: int = 2
+    passage_ratio: float = 1.5
+    passage_node_weight: float = 0.05
+    damping: float = 0.5
+    initial_threshold: float = 0.5
+    iteration_threshold: float = 0.5
\ No newline at end of file
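Although embedding_model is annotated as str here, LinearRAG and EmbeddingStore call .encode(...) on it, so in practice an already-loaded sentence-transformers model appears to be what gets passed. A hedged wiring sketch under that assumption; the dataset name and LLM id are placeholders, and LLM_Model reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment:

from sentence_transformers import SentenceTransformer

from src.config import LinearRAGConfig
from src.utils import LLM_Model

config = LinearRAGConfig(
    dataset_name="demo",                                        # placeholder
    embedding_model=SentenceTransformer("all-mpnet-base-v2"),   # object despite the str hint
    llm_model=LLM_Model("gpt-4o-mini"),                         # placeholder model id
)
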
diff --git a/src/embedding_store.py b/src/embedding_store.py
new file mode 100644
index 0000000..2aff57e
--- /dev/null
+++ b/src/embedding_store.py
@@ -0,0 +1,80 @@
+from copy import deepcopy
+from src.utils import compute_mdhash_id
+import numpy as np
+import pandas as pd
+import os
+
+class EmbeddingStore:
+    def __init__(self, embedding_model, db_filename, batch_size, namespace):
+        self.embedding_model = embedding_model
+        self.db_filename = db_filename
+        self.batch_size = batch_size
+        self.namespace = namespace
+
+        self.hash_ids = []
+        self.texts = []
+        self.embeddings = []
+        self.hash_id_to_text = {}
+        self.hash_id_to_idx = {}
+        self.text_to_hash_id = {}
+
+        self._load_data()
+
+    def _load_data(self):
+        if os.path.exists(self.db_filename):
+            df = pd.read_parquet(self.db_filename)
+            self.hash_ids = df["hash_id"].values.tolist()
+            self.texts = df["text"].values.tolist()
+            self.embeddings = df["embedding"].values.tolist()
+
+            self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)}
+            self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)}
+            self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)}
+            print(f"[{self.namespace}] Loaded {len(self.hash_ids)} records from {self.db_filename}")
+
+    def insert_text(self, text_list):
+        nodes_dict = {}
+        for text in text_list:
+            nodes_dict[compute_mdhash_id(text, prefix=self.namespace + "-")] = {'content': text}
+
+        all_hash_ids = list(nodes_dict.keys())
+
+        existing = set(self.hash_ids)
+        missing_ids = [h for h in all_hash_ids if h not in existing]
+        texts_to_encode = [nodes_dict[hash_id]["content"] for hash_id in missing_ids]
+        all_embeddings = self.embedding_model.encode(texts_to_encode, normalize_embeddings=True, show_progress_bar=False, batch_size=self.batch_size)
+
+        self._upsert(missing_ids, texts_to_encode, all_embeddings)
+
+    def _upsert(self, hash_ids, texts, embeddings):
+        self.hash_ids.extend(hash_ids)
+        self.texts.extend(texts)
+        self.embeddings.extend(embeddings)
+
+        self.hash_id_to_idx = {h: idx for idx, h in enumerate(self.hash_ids)}
+        self.hash_id_to_text = {h: t for h, t in zip(self.hash_ids, self.texts)}
+        self.text_to_hash_id = {t: h for t, h in zip(self.texts, self.hash_ids)}
+
+        self._save_data()
+
+    def _save_data(self):
+        data_to_save = pd.DataFrame({
+            "hash_id": self.hash_ids,
+            "text": self.texts,
+            "embedding": self.embeddings
+        })
+        os.makedirs(os.path.dirname(self.db_filename), exist_ok=True)
+        data_to_save.to_parquet(self.db_filename, index=False)
+
+    def get_hash_id_to_text(self):
+        return deepcopy(self.hash_id_to_text)
+
+    def encode_texts(self, texts):
+        return self.embedding_model.encode(texts, normalize_embeddings=True, show_progress_bar=False, batch_size=self.batch_size)
+
+    def get_embeddings(self, hash_ids):
+        if not hash_ids:
+            return np.array([])
+        indices = np.array([self.hash_id_to_idx[h] for h in hash_ids], dtype=np.intp)
+        embeddings = np.array(self.embeddings)[indices]
+        return embeddings
\ No newline at end of file
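Texts are keyed by a namespaced md5 hash, so re-inserting the same passages re-encodes nothing, and the parquet file doubles as a resumable cache across runs. A small usage sketch; the path and model are placeholders:

from sentence_transformers import SentenceTransformer
from src.embedding_store import EmbeddingStore

model = SentenceTransformer("all-mpnet-base-v2")
store = EmbeddingStore(model, db_filename="./import/demo/passage_embedding.parquet",
                       batch_size=32, namespace="passage")

store.insert_text(["0: Marie Curie was born in Warsaw."])   # encoded and persisted
store.insert_text(["0: Marie Curie was born in Warsaw."])   # same hash id, nothing re-encoded
print(len(store.hash_ids))                                  # 1
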
diff --git a/src/evaluate.py b/src/evaluate.py
new file mode 100644
index 0000000..32e652c
--- /dev/null
+++ b/src/evaluate.py
@@ -0,0 +1,102 @@
+import json
+import os
+from src.utils import normalize_answer
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+import logging
+logger = logging.getLogger(__name__)
+
+class Evaluator:
+    def __init__(self, llm_model, predictions_path):
+        self.llm_model = llm_model
+        self.predictions_path = predictions_path
+        self.prediction_results = self.load_predictions()
+
+    def load_predictions(self):
+        prediction_results = json.load(open(self.predictions_path))
+        return prediction_results
+
+    def calculate_llm_accuracy(self, pre_answer, gold_ans):
+        system_prompt = """You are an expert evaluator.
+        """
+        user_prompt = f"""Please evaluate if the generated answer is correct by comparing it with the gold answer.
+        Generated answer: {pre_answer}
+        Gold answer: {gold_ans}
+
+        The generated answer should be considered correct if it:
+        1. Contains the key information from the gold answer
+        2. Is factually accurate and consistent with the gold answer
+        3. Does not contain any contradicting information
+
+        Respond with ONLY 'correct' or 'incorrect'.
+        Response:
+        """
+        response = self.llm_model.infer([{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}])
+        if response.strip().lower() == "correct":
+            return 1.0
+        else:
+            return 0.0
+
+    def calculate_contain(self, pre_answers, gold_ans):
+        if pre_answers is None or pre_answers == "" or (isinstance(pre_answers, str) and pre_answers.strip() == ""):
+            return 0
+        if gold_ans is None or gold_ans == "" or (isinstance(gold_ans, str) and gold_ans.strip() == ""):
+            return 0
+        s1 = normalize_answer(pre_answers)
+        s2 = normalize_answer(gold_ans)
+        if s2 in s1:
+            return 1
+        else:
+            return 0
+    def evaluate_sig_sample(self, idx, prediction):
+        pre_answer = prediction["pred_answer"]
+        gold_ans = prediction["gold_answer"]
+        # llm_acc = 0.0
+        llm_acc = self.calculate_llm_accuracy(pre_answer, gold_ans)
+        contain_acc = self.calculate_contain(pre_answer, gold_ans)
+        return idx, llm_acc, contain_acc
+
+    def evaluate(self, max_workers):
+        llm_scores = [0.0] * len(self.prediction_results)
+        contain_scores = [0.0] * len(self.prediction_results)
+
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            futures = {
+                executor.submit(self.evaluate_sig_sample, idx, pred): idx
+                for idx, pred in enumerate(self.prediction_results)
+            }
+
+            completed = 0
+            total_llm_score = 0.0
+            total_contain_score = 0.0
+            pbar = tqdm(total=len(futures), desc="Evaluating samples", unit="sample")
+            for future in as_completed(futures):
+                idx, llm_acc, contain_acc = future.result()
+                llm_scores[idx] = llm_acc
+                contain_scores[idx] = contain_acc
+                self.prediction_results[idx]["llm_accuracy"] = llm_acc
+                self.prediction_results[idx]["contain_accuracy"] = contain_acc
+                total_llm_score += llm_acc
+                total_contain_score += contain_acc
+                completed += 1
+                current_llm_acc = total_llm_score / completed
+                current_contain_acc = total_contain_score / completed
+                pbar.set_postfix({
+                    'LLM_Acc': f'{current_llm_acc:.3f}',
+                    'Contain_Acc': f'{current_contain_acc:.3f}'
+                })
+                pbar.update(1)
+            pbar.close()
+
+        llm_accuracy = sum(llm_scores) / len(llm_scores)
+        contain_accuracy = sum(contain_scores) / len(contain_scores)
+
+        logger.info("Evaluation Results:")
+        logger.info(f"  LLM Accuracy: {llm_accuracy:.4f} ({sum(llm_scores)}/{len(llm_scores)})")
+        logger.info(f"  Contain Accuracy: {contain_accuracy:.4f} ({sum(contain_scores)}/{len(contain_scores)})")
+        with open(self.predictions_path, "w", encoding="utf-8") as f:
+            json.dump(self.prediction_results, f, ensure_ascii=False, indent=4)
+
+        with open(os.path.join(os.path.dirname(self.predictions_path), "evaluation_results.json"), "w", encoding="utf-8") as f:
+            json.dump({"llm_accuracy": llm_accuracy, "contain_accuracy": contain_accuracy}, f, ensure_ascii=False, indent=4)
+        return llm_accuracy, contain_accuracy
\ No newline at end of file
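The containment metric is a substring check after normalize_answer() lowercases, strips punctuation and articles, and collapses whitespace, so a verbose prediction still scores 1 as long as the gold string survives inside it. For instance:

from src.utils import normalize_answer

pred = "Thought: the records point to Marie Curie, the physicist."
gold = "Marie Curie"
print(normalize_answer(gold) in normalize_answer(pred))   # True -> contain accuracy 1
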
diff --git a/src/ner.py b/src/ner.py
new file mode 100644
index 0000000..2ca4afb
--- /dev/null
+++ b/src/ner.py
@@ -0,0 +1,46 @@
+import spacy
+from collections import defaultdict
+
+class SpacyNER:
+    def __init__(self, spacy_model):
+        self.spacy_model = spacy.load(spacy_model)
+
+    def batch_ner(self, hash_id_to_passage, max_workers):
+        passage_list = list(hash_id_to_passage.values())
+        batch_size = max(1, len(passage_list) // max_workers)
+        docs_list = self.spacy_model.pipe(passage_list, batch_size=batch_size)
+        passage_hash_id_to_entities = {}
+        sentence_to_entities = defaultdict(list)
+        for idx, doc in enumerate(docs_list):
+            passage_hash_id = list(hash_id_to_passage.keys())[idx]
+            single_passage_hash_id_to_entities, single_sentence_to_entities = self.extract_entities_sentences(doc, passage_hash_id)
+            passage_hash_id_to_entities.update(single_passage_hash_id_to_entities)
+            for sent, ents in single_sentence_to_entities.items():
+                for e in ents:
+                    if e not in sentence_to_entities[sent]:
+                        sentence_to_entities[sent].append(e)
+        return passage_hash_id_to_entities, sentence_to_entities
+
+    def extract_entities_sentences(self, doc, passage_hash_id):
+        sentence_to_entities = defaultdict(list)
+        unique_entities = set()
+        passage_hash_id_to_entities = {}
+        for ent in doc.ents:
+            if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+                continue
+            sent_text = ent.sent.text
+            ent_text = ent.text
+            if ent_text not in sentence_to_entities[sent_text]:
+                sentence_to_entities[sent_text].append(ent_text)
+            unique_entities.add(ent_text)
+        passage_hash_id_to_entities[passage_hash_id] = list(unique_entities)
+        return passage_hash_id_to_entities, sentence_to_entities
+
+    def question_ner(self, question: str):
+        doc = self.spacy_model(question)
+        question_entities = set()
+        for ent in doc.ents:
+            if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+                continue
+            question_entities.add(ent.text.lower())
+        return question_entities
\ No newline at end of file
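question_ner() keeps only the lowercased surface forms and drops ORDINAL and CARDINAL spans, so bare numbers never become seed entities. A small sketch, assuming en_core_web_trf is installed; the exact entity set depends on the spaCy model:

from src.ner import SpacyNER

ner = SpacyNER("en_core_web_trf")
ents = ner.question_ner("Which of the two institutes did Marie Curie found first?")
print(ents)   # likely {'marie curie'}; 'two' (CARDINAL) and 'first' (ORDINAL) are filtered
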
diff --git a/src/utils.py b/src/utils.py
new file mode 100644
index 0000000..50ecc34
--- /dev/null
+++ b/src/utils.py
@@ -0,0 +1,77 @@
+from hashlib import md5
+from dataclasses import dataclass, field
+from typing import List, Dict
+import httpx
+from openai import OpenAI
+from collections import defaultdict
+import multiprocessing as mp
+import re
+import string
+import logging
+import numpy as np
+import os
+
+def compute_mdhash_id(content: str, prefix: str = "") -> str:
+    return prefix + md5(content.encode()).hexdigest()
+
+class LLM_Model:
+    def __init__(self, llm_model):
+        http_client = httpx.Client(timeout=60.0, trust_env=False)
+        self.openai_client = OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY"),
+            base_url=os.getenv("OPENAI_BASE_URL"),
+            http_client=http_client
+        )
+        self.llm_config = {
+            "model": llm_model,
+            "max_tokens": 2000,
+            "temperature": 0,
+        }
+    def infer(self, messages):
+        response = self.openai_client.chat.completions.create(**self.llm_config, messages=messages)
+        return response.choices[0].message.content
+
+
+
+def normalize_answer(s):
+    if s is None:
+        return ""
+    if not isinstance(s, str):
+        s = str(s)
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+    def white_space_fix(text):
+        return " ".join(text.split())
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+    def lower(text):
+        return text.lower()
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+def setup_logging(log_file):
+    log_format = '%(asctime)s - %(levelname)s - %(message)s'
+    handlers = [logging.StreamHandler()]
+    os.makedirs(os.path.dirname(log_file), exist_ok=True)
+    handlers.append(logging.FileHandler(log_file, mode='a', encoding='utf-8'))
+    logging.basicConfig(
+        level=logging.INFO,
+        format=log_format,
+        handlers=handlers,
+        force=True
+    )
+    # Suppress noisy HTTP request logs (e.g., 401 Unauthorized) from httpx/openai
+    logging.getLogger("httpx").setLevel(logging.WARNING)
+    logging.getLogger("httpcore").setLevel(logging.WARNING)
+    logging.getLogger("openai").setLevel(logging.WARNING)
+
+def min_max_normalize(x):
+    min_val = np.min(x)
+    max_val = np.max(x)
+    range_val = max_val - min_val
+
+    # Handle the case where all values are the same (range is zero)
+    if range_val == 0:
+        return np.ones_like(x)  # Return an array of ones with the same shape as x
+
+    return (x - min_val) / range_val

-- 
cgit v1.2.3
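Putting the pieces together, a hedged end-to-end sketch of how these modules appear intended to be driven: an OpenAI-compatible endpoint configured through OPENAI_API_KEY / OPENAI_BASE_URL, a sentence-transformers embedder, and passages carrying the "N:" prefix used for adjacency edges. Paths, the dataset name, and the LLM id are placeholders, not part of this commit:

import json

from sentence_transformers import SentenceTransformer

from src.config import LinearRAGConfig
from src.evaluate import Evaluator
from src.LinearRAG import LinearRAG
from src.utils import LLM_Model, setup_logging

setup_logging("./logs/demo.log")
config = LinearRAGConfig(
    dataset_name="demo",
    embedding_model=SentenceTransformer("all-mpnet-base-v2"),
    llm_model=LLM_Model("gpt-4o-mini"),     # placeholder model id
)

rag = LinearRAG(config)
rag.index(["0: Marie Curie was born in Warsaw.",
           "1: She later moved to Paris to study physics."])
results = rag.qa([{"question": "Where was Marie Curie born?", "answer": "Warsaw"}])

predictions_path = "./import/demo/predictions.json"
with open(predictions_path, "w") as f:
    json.dump(results, f)
Evaluator(config.llm_model, predictions_path).evaluate(max_workers=4)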