summaryrefslogtreecommitdiff
path: root/src/ner.py
diff options
context:
space:
mode:
authorLuyaoZhuang <zhuangluyao523@gmail.com>2025-10-26 05:09:57 -0400
committerLuyaoZhuang <zhuangluyao523@gmail.com>2025-10-26 05:09:57 -0400
commitccff87c15263d1d63235643d54322b991366952e (patch)
tree71051f88bc8df1b3d40a65df6ca72e4586a1d1aa /src/ner.py
commit
Diffstat (limited to 'src/ner.py')
-rw-r--r--src/ner.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/src/ner.py b/src/ner.py
new file mode 100644
index 0000000..2ca4afb
--- /dev/null
+++ b/src/ner.py
@@ -0,0 +1,46 @@
+import spacy
+from collections import defaultdict
+
+class SpacyNER:
+ def __init__(self,spacy_model):
+ self.spacy_model = spacy.load(spacy_model)
+
+ def batch_ner(self, hash_id_to_passage, max_workers):
+ passage_list = list(hash_id_to_passage.values())
+ batch_size = len(passage_list) // max_workers
+ docs_list = self.spacy_model.pipe(passage_list,batch_size=batch_size)
+ passage_hash_id_to_entities = {}
+ sentence_to_entities = defaultdict(list)
+ for idx,doc in enumerate(docs_list):
+ passage_hash_id = list(hash_id_to_passage.keys())[idx]
+ single_passage_hash_id_to_entities,single_sentence_to_entities = self.extract_entities_sentences(doc,passage_hash_id)
+ passage_hash_id_to_entities.update(single_passage_hash_id_to_entities)
+ for sent, ents in single_sentence_to_entities.items():
+ for e in ents:
+ if e not in sentence_to_entities[sent]:
+ sentence_to_entities[sent].append(e)
+ return passage_hash_id_to_entities,sentence_to_entities
+
+ def extract_entities_sentences(self, doc,passage_hash_id):
+ sentence_to_entities = defaultdict(list)
+ unique_entities = set()
+ passage_hash_id_to_entities = {}
+ for ent in doc.ents:
+ if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+ continue
+ sent_text = ent.sent.text
+ ent_text = ent.text
+ if ent_text not in sentence_to_entities[sent_text]:
+ sentence_to_entities[sent_text].append(ent_text)
+ unique_entities.add(ent_text)
+ passage_hash_id_to_entities[passage_hash_id] = list(unique_entities)
+ return passage_hash_id_to_entities,sentence_to_entities
+
+ def question_ner(self, question: str):
+ doc = self.spacy_model(question)
+ question_entities = set()
+ for ent in doc.ents:
+ if ent.label_ == "ORDINAL" or ent.label_ == "CARDINAL":
+ continue
+ question_entities.add(ent.text.lower())
+ return question_entities \ No newline at end of file