From ccff87c15263d1d63235643d54322b991366952e Mon Sep 17 00:00:00 2001 From: LuyaoZhuang Date: Sun, 26 Oct 2025 05:09:57 -0400 Subject: commit --- src/ner.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/ner.py (limited to 'src/ner.py') diff --git a/src/ner.py b/src/ner.py new file mode 100644 index 0000000..2ca4afb --- /dev/null +++ b/src/ner.py @@ -0,0 +1,46 @@ +import spacy +from collections import defaultdict + +class SpacyNER: + def __init__(self,spacy_model): + self.spacy_model = spacy.load(spacy_model) + + def batch_ner(self, hash_id_to_passage, max_workers): + passage_list = list(hash_id_to_passage.values()) + batch_size = len(passage_list) // max_workers + docs_list = self.spacy_model.pipe(passage_list,batch_size=batch_size) + passage_hash_id_to_entities = {} + sentence_to_entities = defaultdict(list) + for idx,doc in enumerate(docs_list): + passage_hash_id = list(hash_id_to_passage.keys())[idx] + single_passage_hash_id_to_entities,single_sentence_to_entities = self.extract_entities_sentences(doc,passage_hash_id) + passage_hash_id_to_entities.update(single_passage_hash_id_to_entities) + for sent, ents in single_sentence_to_entities.items(): + for e in ents: + if e not in sentence_to_entities[sent]: + sentence_to_entities[sent].append(e) + return passage_hash_id_to_entities,sentence_to_entities + + def extract_entities_sentences(self, doc,passage_hash_id): + sentence_to_entities = defaultdict(list) + unique_entities = set() + passage_hash_id_to_entities = {} + for ent in doc.ents: + if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL": + continue + sent_text = ent.sent.text + ent_text = ent.text + if ent_text not in sentence_to_entities[sent_text]: + sentence_to_entities[sent_text].append(ent_text) + unique_entities.add(ent_text) + passage_hash_id_to_entities[passage_hash_id] = list(unique_entities) + return passage_hash_id_to_entities,sentence_to_entities + + def question_ner(self, question: str): + doc = self.spacy_model(question) + question_entities = set() + for ent in doc.ents: + if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL": + continue + question_entities.add(ent.text.lower()) + return question_entities \ No newline at end of file -- cgit v1.2.3