blob: e758fde9c9fb152cd967d2beadee36f6ca22462d (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
import spacy
from collections import defaultdict
import pdb
class SpacyNER:
    """Named-entity extraction helper built on a loaded spaCy pipeline.

    Entities labelled ``ORDINAL`` or ``CARDINAL`` are skipped by every
    method of this class.
    """

    # Entity labels that are never reported (bare numbers / ordinals).
    _IGNORED_LABELS = frozenset({"ORDINAL", "CARDINAL"})

    def __init__(self, spacy_model):
        """Load the spaCy pipeline named (or located) by *spacy_model*."""
        self.spacy_model = spacy.load(spacy_model)

    def batch_ner(self, hash_id_to_passage, max_workers):
        """Run NER over every passage and collect entities per passage and per sentence.

        Args:
            hash_id_to_passage: Mapping of passage hash id -> passage text.
            max_workers: Used only to derive the spaCy batch size
                (``len(passages) // max_workers``); this method does not
                start any workers itself. Values below 1 are treated as 1.

        Returns:
            A ``(passage_hash_id_to_entities, sentence_to_entities)`` tuple:
            the first maps each passage hash id to its list of unique entity
            texts, the second is a ``defaultdict(list)`` mapping sentence
            text to the unique entity texts found in that sentence, merged
            across all passages.
        """
        # Snapshot keys and values once so ids and docs stay aligned and we
        # avoid rebuilding the key list on every iteration.
        hash_ids = list(hash_id_to_passage)
        passage_list = [hash_id_to_passage[hash_id] for hash_id in hash_ids]

        # Clamp both operands: max_workers == 0 would divide by zero, and
        # fewer passages than workers would produce a batch size of 0,
        # which is not a usable batch size for the pipeline.
        batch_size = max(1, len(passage_list) // max(1, max_workers))
        docs = self.spacy_model.pipe(passage_list, batch_size=batch_size)

        passage_hash_id_to_entities = {}
        sentence_to_entities = defaultdict(list)
        for passage_hash_id, doc in zip(hash_ids, docs):
            passage_entities, doc_sentence_entities = self.extract_entities_sentences(
                doc, passage_hash_id
            )
            passage_hash_id_to_entities.update(passage_entities)
            # Merge per-document sentence entities, keeping first-seen order
            # and dropping duplicates.
            for sent, ents in doc_sentence_entities.items():
                for ent_text in ents:
                    if ent_text not in sentence_to_entities[sent]:
                        sentence_to_entities[sent].append(ent_text)
        return passage_hash_id_to_entities, sentence_to_entities

    def extract_entities_sentences(self, doc, passage_hash_id):
        """Collect the entities of one processed document.

        Args:
            doc: A processed document exposing ``ents``; each entity must
                provide ``label_``, ``text`` and ``sent.text``.
            passage_hash_id: Key under which the passage's entities are stored.

        Returns:
            A ``(passage_hash_id_to_entities, sentence_to_entities)`` tuple:
            a single-entry dict ``{passage_hash_id: [entity texts]}`` (unique,
            in order of first appearance) and a ``defaultdict(list)`` mapping
            sentence text to its unique entity texts.
        """
        sentence_to_entities = defaultdict(list)
        # A dict is used as an insertion-ordered set so the resulting list is
        # deterministic instead of depending on set iteration order.
        unique_entities = {}
        for ent in doc.ents:
            if ent.label_ in self._IGNORED_LABELS:
                continue
            sent_text = ent.sent.text
            ent_text = ent.text
            if ent_text not in sentence_to_entities[sent_text]:
                sentence_to_entities[sent_text].append(ent_text)
            unique_entities[ent_text] = None
        passage_hash_id_to_entities = {passage_hash_id: list(unique_entities)}
        return passage_hash_id_to_entities, sentence_to_entities

    def question_ner(self, question: str):
        """Return the set of lower-cased entity texts found in *question*.

        Unlike ``batch_ner``, entity texts are lower-cased here.
        """
        doc = self.spacy_model(question)
        return {
            ent.text.lower()
            for ent in doc.ents
            if ent.label_ not in self._IGNORED_LABELS
        }
|