# dream
# 177 lines · 7.0 KB
1import importlib2import re3from logging import getLogger4
5import pkg_resources6import spacy7
log = getLogger(__name__)

# During tests, en_core_web_sm is installed and then used by test_inferring_pretrained_model within the same
# interpreter session. spaCy checks for the en_core_web_sm package through pkg_resources, but pkg_resources is
# initialized together with the interpreter, so it does not see a package installed afterwards. Reloading
# pkg_resources re-initializes it so that the freshly installed model becomes visible.
if "en-core-web-sm" not in pkg_resources.working_set.by_key:
    importlib.reload(pkg_resources)

# TODO: move nlp to sentence_answer, sentence_answer to rel_ranking_infer and revise en_core_web_sm requirement,
# TODO: make proper downloading with spacy.cli.download
nlp = spacy.load("en_core_web_sm")

# Question words; a token counts as a wh-word when its lowercased text is in this list.
pronouns = ["who", "what", "when", "where", "how"]
def find_tokens(tokens, node, not_inc_node):
    """Append, in pre-order, the text of *node* and all its descendants to *tokens*.

    The node equal to *not_inc_node* is skipped together with its whole subtree.

    Args:
        tokens: list that receives the token texts; it is mutated in place.
        node: root of the traversal; must expose ``.text`` and ``.children``.
        not_inc_node: node whose subtree is excluded from the result.

    Returns:
        The same *tokens* list object, extended with the collected texts.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current != not_inc_node:
            tokens.append(current.text)
            # Push children in reverse so they are popped (visited) left to right.
            pending.extend(reversed(list(current.children)))
    return tokens
def find_inflect_dict(sent_nodes):
    """Map auxiliary tokens that should be removed from the answer to their replacement ("").

    A token is included when its dependency label is "aux" and either
      * its tag is "VBD" and its head is tagged "VBP" or "VB", or
      * its tag is "VBZ" and its head is tagged "VB".

    Returns:
        dict keyed by token text, in order of first appearance; every value is "".
    """

    def _is_dropped_aux(token):
        if token.dep_ != "aux":
            return False
        if token.tag_ == "VBD":
            return token.head.tag_ in ("VBP", "VB")
        if token.tag_ == "VBZ":
            return token.head.tag_ == "VB"
        return False

    return {token.text: "" for token in sent_nodes if _is_dropped_aux(token)}
def find_wh_node(sent_nodes):
    """Locate the question word of a parsed question and its syntactic heads.

    Returns:
        A triple ``(wh_node, wh_node_head, main_head)``:
          * ``wh_node`` - the first token whose lowercased text is in ``pronouns``;
          * ``wh_node_head`` - the syntactic head of ``wh_node``;
          * ``main_head`` - the head of ``wh_node_head`` when ``wh_node_head``
            has dependency label "ccomp".
        Each element that cannot be determined is the empty string "".
    """
    wh_node = next((token for token in sent_nodes if token.text.lower() in pronouns), "")
    if not wh_node:
        return wh_node, "", ""

    wh_node_head = wh_node.head
    main_head = wh_node_head.head if wh_node_head.dep_ == "ccomp" else ""
    return wh_node, wh_node_head, main_head
def _first_contiguous_run(question_tokens, vocabulary):
    """Return the first unbroken run of *question_tokens* found in *vocabulary*, joined with spaces.

    Scanning stops at the first non-matching token once a run has started;
    "" is returned when no token matches.
    """
    run = []
    for token in question_tokens:
        if token in vocabulary:
            run.append(token)
        elif run:
            break
    return " ".join(run)


def find_tokens_to_replace(wh_node_head, main_head, question_tokens, question):
    """Find the question substrings to drop or replace when building the answer sentence.

    Args:
        wh_node_head: syntactic head of the question word, or "" if there is none.
        main_head: token whose subtree (minus *wh_node_head*'s subtree) supplies
            redundant tokens, or "" to skip that step.
        question_tokens: token texts of *question*, in order.
        question: the raw question string.

    Returns:
        A pair ``(redundant_replace_substr, question_replace_substr)``. Each is the
        space-joined first contiguous run of *question_tokens* that belongs to the
        corresponding token set, or "" when there is none:
          * redundant tokens - *main_head*'s subtree excluding *wh_node_head*'s
            subtree, plus the (at most two) words between "what" and an
            is/was/does/did verb;
          * question tokens - the subtrees of *wh_node_head*'s children (other than
            "?") that are a wh-word, an "aux", or an "nsubj" when the head has more
            than one other content dependent.
    """
    redundant = find_tokens([], main_head, wh_node_head) if main_head else []

    # "what <at most two words> is/was/does/did ..." - those words are dropped from the answer.
    what_match = re.findall("what (.*) (is|was|does|did) (.*)", question, re.IGNORECASE)
    if what_match:
        what_words = what_match[0][0].split()
        if len(what_words) <= 2:
            redundant.extend(what_words)

    to_replace = set()
    if wh_node_head:
        children = [child for child in wh_node_head.children if child.text != "?"]
        # Dependents other than auxiliaries, prepositions and wh-words.
        content_dep_count = sum(
            1
            for child in children
            if child.dep_ not in ["aux", "prep"] and child.text.lower() not in pronouns
        )
        for child in children:
            is_subject = child.dep_ == "nsubj" and content_dep_count > 1
            if is_subject or child.text.lower() in pronouns or child.dep_ == "aux":
                to_replace.add(child.text)
                to_replace.update(elem.text for elem in child.subtree)

    return _first_contiguous_run(question_tokens, redundant), _first_contiguous_run(question_tokens, to_replace)
def _replace_whole_word(text, old, new):
    """Replace every occurrence of *old* in *text* that is not embedded in a longer word.

    Word-character edges of *old* must sit on word boundaries, so dropping the
    auxiliary "did" leaves "candidate" intact. A non-word edge (e.g. the apostrophe
    of "'s") is matched as-is, just as plain ``str.replace`` would.
    """
    if not old:
        return text
    left = r"(?<!\w)" if re.match(r"\w", old) else ""
    right = r"(?!\w)" if re.search(r"\w\Z", old) else ""
    # A callable replacement keeps backslashes in *new* from being read as group references.
    return re.sub(left + re.escape(old) + right, lambda _match: new, text)


def _answer_what_who_how(question, answer, question_replace_substr, entity_title, entities, reverse):
    """Build the answer for a question whose wh-word is "what", "who" or "how".

    Three patterns get dedicated templates:
      * date questions ("what day/year ...");
      * "what is/was the name of X which/that ...";
      * "what is/was the name ...".
    Otherwise the wh-phrase *question_replace_substr* in *answer* is replaced by
    *entity_title*. When *reverse* is set, the wh-phrase is removed instead and
    *entity_title* is appended at the end.
    """
    fnd_date = re.findall(r"what (day|year) (.*)\?", question, re.IGNORECASE)
    fnd_wh = re.findall(r"what (is|was) the name of (.*) (which|that) (.*)\?", question, re.IGNORECASE)
    fnd_name = re.findall(r"what (is|was) the name (.*)\?", question, re.IGNORECASE)

    if fnd_date:
        fnd_date_aux = []
        if entities:
            # Escape the entity: titles like "C++" or "Up (film)" contain regex metacharacters
            # that would otherwise raise re.error or shift the capture-group indices.
            entity_pattern = re.escape(entities[0])
            fnd_date_aux = re.findall(
                rf"what (day|year) (is|was) ({entity_pattern}) (.*)\?", question, re.IGNORECASE
            )
        if fnd_date_aux:
            _, aux_verb, _, sent_cut = fnd_date_aux[0]
            return f"{entities[0]} {aux_verb} {sent_cut} on {entity_title}"
        return f"{fnd_date[0][1]} on {entity_title}"

    if fnd_wh:
        return f"{entity_title} {fnd_wh[0][3]}"

    if fnd_name:
        aux_verb, sent_cut = fnd_name[0]
        if sent_cut.startswith("of "):
            sent_cut = sent_cut[3:]
        return f"{entity_title} {aux_verb} {sent_cut}"

    if reverse:
        # The question ends in a preposition: drop the wh-phrase and append the entity after it.
        remainder = answer.replace(question_replace_substr, "")
        return f"{remainder} {entity_title}"
    return answer.replace(question_replace_substr, entity_title)


def sentence_answer(question, entity_title, entities=None, template_answer=None):
    """Turn a wh-question and its answer into a declarative answer sentence.

    Processing steps:
      1. Parse the question with the module-level spaCy pipeline ``nlp``.
      2. Remove the redundant substring found by ``find_tokens_to_replace``, then
         remove "?" characters.
      3. Replace the wh-phrase with *entity_title*, or complete the sentence with
         it, according to the question word.
      4. Remove, as whole words only, the auxiliaries collected by
         ``find_inflect_dict``.
      5. Collapse whitespace and append a final period.

    Args:
        question: the natural-language question.
        entity_title: the answer text to insert into the sentence.
        entities: optional entities mentioned in the question. Only the first one
            is used; assumed to be a string (TODO confirm with callers).
        template_answer: optional template with "[ent]" and "[ans]" placeholders.
            It is used instead of the rule-based rewriting when a wh-phrase was
            found and *entities* is non-empty.

    Returns:
        The answer sentence, terminated with ".".
    """
    log.debug(
        "question %s entity_title %s entities %s template_answer %s",
        question,
        entity_title,
        entities,
        template_answer,
    )
    sent_nodes = nlp(question)
    # A question whose second-to-last token is a preposition (e.g. "... written by?") keeps the
    # preposition and gets the answer appended. The length guard avoids IndexError on one-token input.
    reverse = len(sent_nodes) > 1 and sent_nodes[-2].tag_ == "IN"
    question_tokens = [elem.text for elem in sent_nodes]
    log.debug("spacy tags: %s", [(elem.text, elem.tag_, elem.dep_, elem.head.text) for elem in sent_nodes])

    inflect_dict = find_inflect_dict(sent_nodes)
    wh_node, wh_node_head, main_head = find_wh_node(sent_nodes)
    redundant_replace_substr, question_replace_substr = find_tokens_to_replace(
        wh_node_head, main_head, question_tokens, question
    )
    log.debug(
        "redundant_replace_substr %s question_replace_substr %s",
        redundant_replace_substr,
        question_replace_substr,
    )
    answer = question.replace(redundant_replace_substr, "") if redundant_replace_substr else question

    if answer.endswith("?"):
        answer = answer.replace("?", "").strip()

    if question_replace_substr:
        wh_word = wh_node.text.lower() if wh_node else ""
        if template_answer and entities:
            answer = template_answer.replace("[ent]", entities[0]).replace("[ans]", entity_title)
        elif wh_word in ("what", "who", "how"):
            answer = _answer_what_who_how(
                question, answer, question_replace_substr, entity_title, entities, reverse
            )
        elif wh_word in ("when", "where") and entities:
            # Escape the entity so regex metacharacters in its title are matched literally.
            sent_cut = re.findall(
                rf"(when|where) (was|is) {re.escape(entities[0])} (.*)\?", question, re.IGNORECASE
            )
            if sent_cut:
                found_wh, aux_verb, rest = sent_cut[0]
                preposition = "on" if found_wh.lower() == "when" else "in"
                answer = f"{entities[0]} {aux_verb} {rest} {preposition} {entity_title}"
            else:
                answer = answer.replace(question_replace_substr, "")
                answer = f"{answer} in {entity_title}"

    # Whole-word removal: plain str.replace would also cut "did" out of "candidate".
    for old_tok, new_tok in inflect_dict.items():
        answer = _replace_whole_word(answer, old_tok, new_tok)
    answer = re.sub(r"\s+", " ", answer).strip()

    return answer + "."