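"""Flask service for property extraction from dialogue utterances.

Runs two DeepPavlov models: a relation ranker that scores which relations an
utterance may express, and a generative (T5-based) model that produces
"<subj> ... <rel> ... <obj> ..." triplets for the predicted relations.
Serves a single POST /respond endpoint on port 3000.
"""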
import logging
import os
import re
import time
import string
import pickle
import json
from itertools import chain, product, zip_longest

import nltk
import sentry_sdk
import spacy
import numpy as np
from flask import Flask, jsonify, request

from deeppavlov import build_model
from src.sentence_answer import sentence_answer

sentry_sdk.init(os.getenv("SENTRY_DSN"))

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = Flask(__name__)

stemmer = nltk.PorterStemmer()
nlp = spacy.load("en_core_web_sm")

t5_config = os.getenv("CONFIG_T5")
rel_ranker_config = os.getenv("CONFIG_REL_RANKER")
add_entity_info = int(os.getenv("ADD_ENTITY_INFO", "0"))

# Build both DeepPavlov pipelines at startup; fail fast if configs or weights are unavailable.
try:
    generative_ie = build_model(t5_config, download=True)
    rel_ranker = build_model(rel_ranker_config, download=True)
    logger.info("property extraction model is loaded.")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise

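# rel_list.txt is expected to hold one "<relation_name> <type>" pair per line
# (illustrative line: "have_pet r"); underscores in names become spaces, and the
# type marker "r" denotes a relation while any other marker denotes a property.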
rel_type_dict = {}
relations_all = []
with open("rel_list.txt", "r") as fl:
    for line in fl:
        rel, rel_type = line.strip().split()
        relations_all.append(rel.replace("_", " "))
        rel_type_dict[rel.replace("_", " ")] = "relation" if rel_type == "r" else "property"

# Locate rel_groups.pickle relative to the model path declared in the ranker config.
with open(rel_ranker_config) as fl:
    config_metadata = json.load(fl)["metadata"]["variables"]
root_path = config_metadata["ROOT_PATH"]
model_path = config_metadata["MODEL_PATH"].replace("{ROOT_PATH}", root_path)
rels_path = os.path.expanduser(f"{model_path}/rel_groups.pickle")
with open(rels_path, "rb") as fl:
    rel_groups_list = pickle.load(fl)


def sentrewrite(sentence, init_answer):
    """Fold a short answer into the preceding question, producing a declarative sentence."""
    answer = init_answer.strip(".")
    if any(sentence.startswith(elem) for elem in ["what's", "what is"]):
        for old_tok, new_tok in [
            ("what's your", f"{answer} is my"),
            ("what is your", f"{answer} is my"),
            ("what is", f"{answer} is"),
            ("what's", f"{answer} is"),
        ]:
            sentence = sentence.replace(old_tok, new_tok)
    elif any(sentence.startswith(elem) for elem in ["where", "when"]):
        sentence = sentence_answer(sentence, answer)
    elif sentence.startswith("is there"):
        for old_tok, new_tok in [("is there any", f"{answer} is"), ("is there", f"{answer} is")]:
            sentence = sentence.replace(old_tok, new_tok)
    else:
        sentence = f"{sentence} {init_answer}"
    return sentence

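# Example usage of sentrewrite (illustrative):
#     sentrewrite("what is your favorite food", "pizza.") -> "pizza is my favorite food"
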
def get_relations(uttr_batch, thres=0.5):
    """Score every (utterance, relation) pair and keep at most one relation per relation group."""
    relations_pred_batch = []
    # Score all utterance x relation combinations in a single ranker call.
    input_batch = list(zip(*product(uttr_batch, relations_all)))
    rels_scores = rel_ranker(*input_batch)
    rels_scores = np.array(rels_scores).reshape((len(uttr_batch), len(relations_all), 2))
    for curr_scores in rels_scores:
        pred_rels = []
        rels_with_scores = [
            (curr_score[1], curr_rel)
            for curr_score, curr_rel in zip(curr_scores, relations_all)
            if curr_score[1] > thres
        ]
        # Within each relation group (loaded from rel_groups.pickle), keep only the
        # top-scoring candidate above the threshold.
        for rel_group in rel_groups_list:
            pred_rel_group = [
                (curr_score, curr_rel) for curr_score, curr_rel in rels_with_scores if curr_rel in rel_group
            ]
            if pred_rel_group:
                pred_rels.append(max(pred_rel_group)[1])
        relations_pred_batch.append(pred_rels or [""])
    logger.debug(f"rel clf raw output: {relations_pred_batch}")
    return relations_pred_batch

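# Illustrative behavior of get_relations (hypothetical scores): an utterance like
# "i have a dog" might map to ["have pet"]; utterances with no relation scoring
# above the threshold map to [""] so downstream batching stays aligned.
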
def postprocess_triplets(triplets_init, scores_init, uttr):
    """Parse generated triplet strings and deduplicate triplets that share an object."""
    triplets, existing_obj = [], []
    scores_dict = {}
    for triplet_init, score in zip(triplets_init, scores_init):
        triplet = ""
        fnd = re.findall(r"<subj> (.*?)<rel> (.*?)<obj> (.*)", triplet_init)
        if fnd and fnd[0][1] in rel_type_dict:
            triplet = list(fnd[0])
            # First-person subjects are normalized to the "user" node.
            if triplet[0] in ["i", "my"]:
                triplet[0] = "user"
            obj = triplet[2]
            for punc in string.punctuation:
                obj = obj.replace(punc, "")
            # If another triplet already extracted this object, keep the higher-scoring one.
            if obj in existing_obj:
                prev_triplet, prev_score = scores_dict[obj]
                if score > prev_score:
                    triplets.remove(prev_triplet)
                else:
                    continue
            scores_dict[obj] = (triplet, score)
            existing_obj.append(obj)
            # Restore capitalization if the object appears capitalized in the utterance.
            if obj.islower() and obj.capitalize() in uttr:
                triplet[2] = obj.capitalize()
        triplets.append(triplet)
    return triplets

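# generative_ie is expected to emit strings of the form
# "<subj> i <rel> like activity <obj> reading" (values illustrative), which
# postprocess_triplets parses with the regex above; unparseable or unknown-relation
# outputs are kept as empty strings and filtered out later.
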
def generate_triplets(uttr_batch, relations_pred_batch):
    """Generate triplets for each (utterance, predicted relation) pair and regroup per utterance."""
    triplets_corr_batch = []
    # Repeat each utterance once per predicted relation so inputs align with the flat relation list.
    t5_input_uttrs = []
    for uttr, preds in zip(uttr_batch, relations_pred_batch):
        t5_input_uttrs.extend([uttr] * len(preds))
    relations_pred_flat = list(chain(*relations_pred_batch))
    t5_pred_triplets, t5_pred_scores = generative_ie(t5_input_uttrs, relations_pred_flat)
    logger.debug(f"t5 raw output: {t5_pred_triplets} scores: {t5_pred_scores}")

    # Slice the flat predictions back into per-utterance chunks.
    offset_start = 0
    for uttr, pred_rels in zip(uttr_batch, relations_pred_batch):
        rels_len = len(pred_rels)
        triplets_init = t5_pred_triplets[offset_start : offset_start + rels_len]
        scores_init = t5_pred_scores[offset_start : offset_start + rels_len]
        offset_start += rels_len
        triplets = postprocess_triplets(triplets_init, scores_init, uttr)
        triplets_corr_batch.append(triplets)
    return triplets_corr_batch

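# Request payload handled below: "utterances" is a batch of dialogue contexts (each
# a list of turns); "named_entities", "entities_with_labels" and "entity_info" are
# optional outputs of upstream annotators, aligned with the utterance batch.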
def get_result(request):
    st_time = time.time()
    init_uttrs = request.json.get("utterances", [])
    named_entities_batch = request.json.get("named_entities", [[] for _ in init_uttrs])
    entities_with_labels_batch = request.json.get("entities_with_labels", [[] for _ in init_uttrs])
    entity_info_batch = request.json.get("entity_info", [[] for _ in init_uttrs])
    logger.info(
        f"init_uttrs {init_uttrs} entities_with_labels: {entities_with_labels_batch} entity_info: {entity_info_batch}"
    )
    # `indices` records span boundaries of each utterance's sentences in the flat `uttrs` list.
    uttrs, indices = [], [0]
    for uttr_list in init_uttrs:
        if len(uttr_list) == 1:
            sents = nltk.sent_tokenize(uttr_list[0]) or [""]
            uttrs.extend(sents)
        else:
            # With dialogue context, check whether the previous turn ended in a question.
            utt_prev = uttr_list[-2]
            utt_prev_sentences = nltk.sent_tokenize(utt_prev)
            utt_prev = utt_prev_sentences[-1].lower()
            utt_cur = uttr_list[-1].lower()
            is_q = (
                any(utt_prev.startswith(q_word) for q_word in ["what ", "who ", "when ", "where "])
                or "?" in utt_prev
            )

            # Treat the current turn as a full sentence if it has a verb and more than two tokens.
            is_sentence = False
            parsed_sentence = nlp(utt_cur)
            if parsed_sentence:
                tokens = [elem.text for elem in parsed_sentence]
                tags = [elem.tag_ for elem in parsed_sentence]
                found_verbs = any(tag in tags for tag in ["VB", "VBZ", "VBP", "VBD"])
                if found_verbs and len(tokens) > 2:
                    is_sentence = True

            logger.info(f"is_q: {is_q} --- is_s: {is_sentence} --- utt_prev: {utt_prev} --- utt_cur: {utt_cur}")
            # A short answer to a question is rewritten into a declarative sentence.
            if is_q and not is_sentence:
                uttrs.append(sentrewrite(utt_prev, utt_cur))
            else:
                uttrs.append(utt_cur)
        indices.append(len(uttrs))

    logger.info(f"input utterances: {uttrs}")
    relations_pred = get_relations(uttrs)
    triplets_batch = generate_triplets(uttrs, relations_pred)

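    # Regroup sentence-level triplets back into one list per original utterance using
    # the span boundaries in `indices`, then attach entity metadata to each triplet.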
    logger.info(f"triplets_batch {triplets_batch}")
    triplets_info_batch = []
    triplets_batch = [list(chain(*triplets_batch[start:end])) for start, end in zip(indices, indices[1:])]
    uttrs = [" ".join(uttrs[start:end]) for start, end in zip(indices, indices[1:])]
    for triplets, uttr, named_entities, entities_with_labels, entity_info_list in zip(
        triplets_batch, uttrs, named_entities_batch, entities_with_labels_batch, entity_info_batch
    ):
        uttr = uttr.lower()
        entity_substr_dict = {}
        formatted_triplets, per_triplets = [], []
        # Only utterances longer than two words are considered for triplet formatting.
        if len(uttr.split()) > 2:
            for triplet in triplets:
                if triplet:
                    # Record character offsets for entities that occur in the triplet.
                    for entity in entities_with_labels:
                        entity_substr = entity.get("text", "")
                        offsets = entity.get("offsets", [])
                        if not offsets:
                            start_offset = uttr.find(entity_substr.lower())
                            end_offset = start_offset + len(entity_substr)
                            offsets = [start_offset, end_offset]
                        if entity_substr in [triplet[0], triplet[2]]:
                            entity_substr_dict[entity_substr] = {"offsets": offsets}

                    # Merge entity-linker info for entities matching the subject or object
                    # by surface form or stem, when entity ids are available.
                    for entity_info in entity_info_list:
                        entity_substr = entity_info.get("entity_substr", "")
                        if (
                            entity_substr in [triplet[0], triplet[2]]
                            or stemmer.stem(entity_substr) in [triplet[0], triplet[2]]
                        ) and "entity_ids" in entity_info:
                            if entity_substr not in entity_substr_dict:
                                entity_substr_dict[entity_substr] = {}
                            entity_substr_dict[entity_substr]["entity_ids"] = entity_info["entity_ids"]
                            entity_substr_dict[entity_substr]["dbpedia_types"] = entity_info.get("dbpedia_types", [])
                            entity_substr_dict[entity_substr]["finegrained_types"] = entity_info.get(
                                "entity_id_tags", []
                            )
                    named_entities_list = [entity for elem in named_entities for entity in elem]
                    per_entities = [entity for entity in named_entities_list if entity.get("type", "") == "PER"]
                    # For family/pet relations, add a secondary triplet naming the mentioned person
                    # ("have chidren" is spelled to match the relation name loaded from rel_list.txt).
                    if triplet[1] in {"have pet", "have family", "have sibling", "have chidren"} and per_entities:
                        per_triplet = {
                            "subject": triplet[2],
                            "property": "name",
                            "object": per_entities[0].get("text", ""),
                        }
                        per_triplets.append(per_triplet)

                    # Key the relation under its type: "relation" or "property".
                    formatted_triplet = {
                        "subject": triplet[0],
                        rel_type_dict[triplet[1]]: triplet[1],
                        "object": triplet[2],
                    }
                    formatted_triplets.append(formatted_triplet)
        triplets_info_list = []
        if add_entity_info:
            triplets_info_list.append({"triplets": formatted_triplets, "entity_info": entity_substr_dict})
        else:
            triplets_info_list.append({"triplets": formatted_triplets})
        if per_triplets:
            per_entity_info = [{per_triplet["object"]: {"entity_id_tags": ["PER"]}} for per_triplet in per_triplets]
            if add_entity_info:
                triplets_info_list.append({"per_triplets": per_triplets, "entity_info": per_entity_info})
            else:
                triplets_info_list.append({"per_triplets": per_triplets})
        triplets_info_batch.append(triplets_info_list)
    total_time = time.time() - st_time
    logger.info(triplets_info_batch)
    logger.info(f"property extraction exec time: {total_time:.3f}s")
    return triplets_info_batch


@app.route("/respond", methods=["POST"])
def respond():
    result = get_result(request)
    return jsonify(result)

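# Example request (illustrative payload; the annotator fields are optional):
#   curl -X POST http://0.0.0.0:3000/respond -H "Content-Type: application/json" \
#        -d '{"utterances": [["what is your favorite food?", "pizza"]]}'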

if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=3000)
