import importlib
import re
from logging import getLogger

import pkg_resources
import spacy

log = getLogger(__name__)

# en_core_web_sm is installed and used by test_inferring_pretrained_model in the same interpreter session during tests.
# spaCy checks en_core_web_sm package presence with pkg_resources, but pkg_resources is initialized with the interpreter,
# so it doesn't see en_core_web_sm installed after interpreter initialization; hence the importlib.reload below.

if "en-core-web-sm" not in pkg_resources.working_set.by_key.keys():
    importlib.reload(pkg_resources)

# TODO: move nlp to sentence_answer, sentence_answer to rel_ranking_infer and revise en_core_web_sm requirement,
# TODO: make proper downloading with spacy.cli.download
nlp = spacy.load("en_core_web_sm")
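# One possible shape for the download TODO above (a sketch, not wired in):
# fall back to spacy.cli.download when the model is missing, then retry the load.
#
#     try:
#         nlp = spacy.load("en_core_web_sm")
#     except OSError:
#         spacy.cli.download("en_core_web_sm")
#         nlp = spacy.load("en_core_web_sm")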

pronouns = ["who", "what", "when", "where", "how"]


def find_tokens(tokens, node, not_inc_node):
    # Recursively collect the text of `node` and of every syntactic descendant,
    # skipping the subtree rooted at `not_inc_node`.
    if node != not_inc_node:
        tokens.append(node.text)
        for elem in node.children:
            tokens = find_tokens(tokens, elem, not_inc_node)
    return tokens


def find_inflect_dict(sent_nodes):
    # Map auxiliary verbs such as "did" or "does" to an empty string when the
    # head verb is in base or present form; sentence_answer later deletes them
    # from the rewritten answer.
    inflect_dict = {}
    for node in sent_nodes:
        if node.dep_ == "aux" and node.tag_ == "VBD" and (node.head.tag_ == "VBP" or node.head.tag_ == "VB"):
            inflect_dict[node.text] = ""
        if node.dep_ == "aux" and node.tag_ == "VBZ" and node.head.tag_ == "VB":
            inflect_dict[node.text] = ""
    return inflect_dict


def find_wh_node(sent_nodes):
    # Find the wh-word of the question, its syntactic head, and, if that head
    # is a clausal complement (ccomp), the head of the main clause.
    wh_node = ""
    main_head = ""
    wh_node_head = ""
    for node in sent_nodes:
        if node.text.lower() in pronouns:
            wh_node = node
            break

    if wh_node:
        wh_node_head = wh_node.head
        if wh_node_head.dep_ == "ccomp":
            main_head = wh_node_head.head

    return wh_node, wh_node_head, main_head


def find_tokens_to_replace(wh_node_head, main_head, question_tokens, question):
    # Determine two substrings of the question: redundant tokens from the main
    # clause that can simply be dropped, and tokens around the wh-word (the
    # wh-word itself, its auxiliaries and subject) to be replaced with the answer.
    redundant_tokens_to_replace = []
    question_tokens_to_replace = []

    if main_head:
        redundant_tokens_to_replace = find_tokens([], main_head, wh_node_head)
    what_tokens_fnd = re.findall("what (.*) (is|was|does|did) (.*)", question, re.IGNORECASE)
    if what_tokens_fnd:
        what_tokens = what_tokens_fnd[0][0].split()
        if len(what_tokens) <= 2:
            redundant_tokens_to_replace += what_tokens

    wh_node_head_desc = []
    wh_node_head_dep = []
    if wh_node_head:
        wh_node_head_desc = [node for node in wh_node_head.children if node.text != "?"]
        wh_node_head_dep = [
            node.dep_
            for node in wh_node_head.children
            if (node.text != "?" and node.dep_ not in ["aux", "prep"] and node.text.lower() not in pronouns)
        ]
    for node in wh_node_head_desc:
        if (node.dep_ == "nsubj" and len(wh_node_head_dep) > 1) or node.text.lower() in pronouns or node.dep_ == "aux":
            question_tokens_to_replace.append(node.text)
            for elem in node.subtree:
                question_tokens_to_replace.append(elem.text)

    question_tokens_to_replace = list(set(question_tokens_to_replace))

    # Keep only the first contiguous run of tokens that belong to each set.
    redundant_replace_substr = []
    for token in question_tokens:
        if token in redundant_tokens_to_replace:
            redundant_replace_substr.append(token)
        elif redundant_replace_substr:
            break

    redundant_replace_substr = " ".join(redundant_replace_substr)

    question_replace_substr = []

    for token in question_tokens:
        if token in question_tokens_to_replace:
            question_replace_substr.append(token)
        elif question_replace_substr:
            break

    question_replace_substr = " ".join(question_replace_substr)

    return redundant_replace_substr, question_replace_substr
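
# Illustrative only (an assumed example, not taken from the project's tests):
# for "Who wrote War and Peace?", find_wh_node picks "Who", and with a typical
# spaCy parse find_tokens_to_replace returns ("", "Who"), so only the wh-word
# gets replaced with the answer entity in sentence_answer below.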


def sentence_answer(question, entity_title, entities=None, template_answer=None):
    # Rewrite a wh-question into a declarative sentence that states the answer,
    # typically by substituting the answer entity for the wh-phrase.
    log.debug(f"question {question} entity_title {entity_title} entities {entities} template_answer {template_answer}")
    sent_nodes = nlp(question)
    reverse = False
    # If the token before the final "?" is a preposition, the entity is
    # appended after the question instead of substituted into it.
    if sent_nodes[-2].tag_ == "IN":
        reverse = True
    question_tokens = [elem.text for elem in sent_nodes]
    log.debug(f"spacy tags: {[(elem.text, elem.tag_, elem.dep_, elem.head.text) for elem in sent_nodes]}")

    inflect_dict = find_inflect_dict(sent_nodes)
    wh_node, wh_node_head, main_head = find_wh_node(sent_nodes)
    redundant_replace_substr, question_replace_substr = find_tokens_to_replace(
        wh_node_head, main_head, question_tokens, question
    )
    log.debug(f"redundant_replace_substr {redundant_replace_substr} question_replace_substr {question_replace_substr}")
    if redundant_replace_substr:
        answer = question.replace(redundant_replace_substr, "")
    else:
        answer = question

    if answer.endswith("?"):
        answer = answer.replace("?", "").strip()

    if question_replace_substr:
        if template_answer and entities:
            answer = template_answer.replace("[ent]", entities[0]).replace("[ans]", entity_title)
        elif wh_node.text.lower() in ["what", "who", "how"]:
            fnd_date = re.findall(r"what (day|year) (.*)\?", question, re.IGNORECASE)
            fnd_wh = re.findall(r"what (is|was) the name of (.*) (which|that) (.*)\?", question, re.IGNORECASE)
            fnd_name = re.findall(r"what (is|was) the name (.*)\?", question, re.IGNORECASE)
            if fnd_date and entities:  # entities[0] is needed to build the date answer
                fnd_date_aux = re.findall(rf"what (day|year) (is|was) ({entities[0]}) (.*)\?", question, re.IGNORECASE)
                if fnd_date_aux:
                    answer = f"{entities[0]} {fnd_date_aux[0][1]} {fnd_date_aux[0][3]} on {entity_title}"
                else:
                    answer = f"{fnd_date[0][1]} on {entity_title}"
            elif fnd_wh:
                answer = f"{entity_title} {fnd_wh[0][3]}"
            elif fnd_name:
                aux_verb, sent_cut = fnd_name[0]
                if sent_cut.startswith("of "):
                    sent_cut = sent_cut[3:]
                answer = f"{entity_title} {aux_verb} {sent_cut}"
            else:
                if reverse:
                    answer = answer.replace(question_replace_substr, "")
                    answer = f"{answer} {entity_title}"
                else:
                    answer = answer.replace(question_replace_substr, entity_title)
        elif wh_node.text.lower() in ["when", "where"] and entities:
            sent_cut = re.findall(rf"(when|where) (was|is) {entities[0]} (.*)\?", question, re.IGNORECASE)
            if sent_cut:
                if sent_cut[0][0].lower() == "when":
                    answer = f"{entities[0]} {sent_cut[0][1]} {sent_cut[0][2]} on {entity_title}"
                else:
                    answer = f"{entities[0]} {sent_cut[0][1]} {sent_cut[0][2]} in {entity_title}"
            else:
                answer = answer.replace(question_replace_substr, "")
                answer = f"{answer} in {entity_title}"

    # Drop the auxiliaries collected by find_inflect_dict and normalize whitespace.
    for old_tok, new_tok in inflect_dict.items():
        answer = answer.replace(old_tok, new_tok)
    answer = re.sub(r"\s+", " ", answer).strip()

    answer = answer + "."

    return answer
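

# A minimal usage sketch (assumption: the sample question and the expected
# output are illustrative and not taken from the project's tests):
if __name__ == "__main__":
    # With a typical spaCy parse this should print something like
    # "Leo Tolstoy wrote War and Peace."
    print(sentence_answer("Who wrote War and Peace?", "Leo Tolstoy"))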