dream
278 lines · 11.6 KB
1import logging
2import os
3import re
4import time
5import string
6import pickle
7import json
8from itertools import chain, product, zip_longest
9
10import nltk
11import sentry_sdk
12import spacy
13import numpy as np
14from flask import Flask, jsonify, request
15
16from deeppavlov import build_model
17from src.sentence_answer import sentence_answer
18
# Error reporting; SENTRY_DSN may be unset, in which case Sentry is a no-op.
sentry_sdk.init(os.getenv("SENTRY_DSN"))

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.DEBUG)
logger = logging.getLogger(__name__)
app = Flask(__name__)

# NLP utilities: Porter stemmer for fuzzy entity matching in get_result,
# spaCy for POS-tagging replies to decide whether they are full sentences.
stemmer = nltk.PorterStemmer()
nlp = spacy.load("en_core_web_sm")

# DeepPavlov config paths and the output-verbosity flag come from the environment.
t5_config = os.getenv("CONFIG_T5")
rel_ranker_config = os.getenv("CONFIG_REL_RANKER")
add_entity_info = int(os.getenv("ADD_ENTITY_INFO", "0"))  # 0/1 flag
31
# Load the two DeepPavlov pipelines: a generative triplet extractor (T5) and a
# binary relation ranker. A failure here is fatal for the service, so the
# exception is reported to Sentry, logged, and re-raised.
try:
    generative_ie = build_model(t5_config, download=True)
    rel_ranker = build_model(rel_ranker_config, download=True)
    logger.info("property extraction model is loaded.")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise  # bare `raise` preserves the original traceback (was `raise e`)
40
# Relation inventory: rel_list.txt holds one "<relation_name> <type_flag>" pair
# per line; flag "r" marks a relation, anything else a property.
# relations_all keeps the ranker's candidate set, rel_type_dict maps each
# relation to the output key used when formatting triplets.
rel_type_dict = {}
relations_all = []
with open("rel_list.txt", "r") as fl:
    lines = fl.readlines()
    for line in lines:
        rel, rel_type = line.strip().split()
        # Underscores in the file become spaces to match the model's surface form.
        relations_all.append(rel.replace("_", " "))
        if rel_type == "r":
            rel_type = "relation"
        else:
            rel_type = "property"
        rel_type_dict[rel.replace("_", " ")] = rel_type
53
# Resolve the relation-group pickle path from the ranker config's metadata
# variables; {ROOT_PATH} is substituted by hand, mirroring DeepPavlov's scheme.
with open(rel_ranker_config) as config_file:  # was json.load(open(...)): leaked the handle
    config_metadata = json.load(config_file)["metadata"]["variables"]
root_path = config_metadata["ROOT_PATH"]
model_path = config_metadata["MODEL_PATH"].replace("{ROOT_PATH}", root_path)
rels_path = os.path.expanduser(f"{model_path}/rel_groups.pickle")
# NOTE(review): pickle.load is acceptable here only because rel_groups.pickle
# ships with the downloaded model, not user-supplied data.
with open(rels_path, "rb") as fl:
    rel_groups_list = pickle.load(fl)
60
61
def sentrewrite(sentence, init_answer):
    """Rewrite a question `sentence` into a statement that contains the answer.

    `init_answer` is the user's reply; a trailing period is stripped before
    substitution. Questions starting with "what", "where"/"when" and
    "is there" get dedicated rewrites; anything else is answered by simple
    concatenation of sentence and the raw reply.
    """
    answer = init_answer.strip(".")
    if sentence.startswith(("what's", "what is")):
        # Replacements run in order; once one fires, the later, shorter
        # patterns usually no longer match.
        substitutions = [
            ("what's your", f"{answer} is my"),
            ("what is your", f"{answer} is my"),
            ("what is", f"{answer} is"),
            ("what's", f"{answer} is"),
        ]
        for pattern, replacement in substitutions:
            sentence = sentence.replace(pattern, replacement)
    elif sentence.startswith(("where", "when")):
        sentence = sentence_answer(sentence, answer)
    elif sentence.startswith("is there"):
        for pattern, replacement in [("is there any", f"{answer} is"), ("is there", f"{answer} is")]:
            sentence = sentence.replace(pattern, replacement)
    else:
        sentence = f"{sentence} {init_answer}"
    return sentence
80
81
def get_relations(uttr_batch, thres=0.5):
    """Predict candidate relations for each utterance.

    Scores every (utterance, relation) pair with the ranker, keeps relations
    whose positive-class score exceeds `thres`, and retains at most one
    relation per relation group (the highest-scoring one). An utterance with
    no surviving relation gets [""] so downstream batching stays aligned.
    """
    # Single ranker call over the full cartesian product utterances x relations.
    pairs = product(uttr_batch, relations_all)
    ranker_out = rel_ranker(*list(zip(*pairs)))
    scores = np.array(ranker_out).reshape((len(uttr_batch), len(relations_all), 2))

    predicted = []
    for per_uttr_scores in scores:
        # (positive score, relation) for everything above the threshold.
        above_thres = [
            (score_pair[1], rel)
            for score_pair, rel in zip(per_uttr_scores, relations_all)
            if score_pair[1] > thres
        ]
        winners = []
        for group in rel_groups_list:
            candidates = [(s, r) for s, r in above_thres if r in group]
            if candidates:
                # max over (score, relation) tuples: highest score wins,
                # relation name breaks ties — same for one or many candidates.
                winners.append(max(candidates)[1])
        predicted.append(winners if winners else [""])
    logger.debug(f"rel clf raw output: {predicted}")
    return predicted
107
108
def postprocess_triplets(triplets_init, scores_init, uttr):
    """Parse raw "<subj> ... <rel> ... <obj> ..." strings into triplet lists.

    Keeps only triplets whose relation is known in rel_type_dict, maps
    first-person subjects to "user", deduplicates by punctuation-stripped
    object (the higher-scored triplet wins), and restores the object's
    capitalization when the capitalized form occurs in `uttr`.
    """
    results, seen_objs = [], []
    best_by_obj = {}  # punctuation-stripped object -> (triplet, score)
    for raw, score in zip(triplets_init, scores_init):
        parsed = re.findall(r"<subj> (.*?)<rel> (.*?)<obj> (.*)", raw)
        if not (parsed and parsed[0][1] in rel_type_dict):
            continue
        triplet = list(parsed[0])
        if triplet[0] in ["i", "my"]:
            triplet[0] = "user"
        # Dedup key: the object with all punctuation removed.
        obj_key = triplet[2]
        for ch in string.punctuation:
            obj_key = obj_key.replace(ch, "")
        if obj_key in seen_objs:
            rival, rival_score = best_by_obj[obj_key]
            if score > rival_score:
                results.remove(rival)
            else:
                continue  # existing triplet scores at least as high; drop this one
        best_by_obj[obj_key] = (triplet, score)
        seen_objs.append(obj_key)
        if obj_key.islower() and obj_key.capitalize() in uttr:
            triplet[2] = obj_key.capitalize()
        results.append(triplet)
    return results
134
135
def generate_triplets(uttr_batch, relations_pred_batch):
    """Generate triplets per utterance via the T5 model and post-process them.

    Each utterance is repeated once per predicted relation so the generator
    receives flat, parallel utterance/relation lists; the flat outputs are
    then sliced back into per-utterance chunks.
    """
    flat_uttrs = []
    for uttr, preds in zip(uttr_batch, relations_pred_batch):
        flat_uttrs.extend([uttr] * len(preds))
    flat_rels = list(chain(*relations_pred_batch))
    pred_triplets, pred_scores = generative_ie(flat_uttrs, flat_rels)
    logger.debug(f"t5 raw output: {pred_triplets} scores: {pred_scores}")

    triplets_per_uttr = []
    cursor = 0
    for uttr, pred_rels in zip(uttr_batch, relations_pred_batch):
        count = len(pred_rels)
        chunk_triplets = pred_triplets[cursor : cursor + count]
        chunk_scores = pred_scores[cursor : cursor + count]
        cursor += count
        triplets_per_uttr.append(postprocess_triplets(chunk_triplets, chunk_scores, uttr))
    return triplets_per_uttr
155
156
def _prepare_uttrs(init_uttrs):
    """Flatten dialog histories into single sentences for the extractor.

    A one-utterance history is simply sentence-tokenized. For longer
    histories, if the previous utterance looks like a question and the current
    reply is not a full sentence (no verb, or <= 2 tokens), the reply is
    rewritten into a statement via `sentrewrite`; otherwise the reply is used
    as-is.

    Returns (uttrs, indices): uttrs[indices[i]:indices[i + 1]] are the
    sentences produced from init_uttrs[i].
    """
    uttrs, indices = [], [0]
    for uttr_list in init_uttrs:
        if len(uttr_list) == 1:
            sents = nltk.sent_tokenize(uttr_list[0]) or [""]
            uttrs.extend(sents)
        else:
            utt_prev = uttr_list[-2]
            # `or [""]` guards sent_tokenize returning [] on an empty string
            # (the original indexed [-1] unconditionally and could IndexError).
            utt_prev_sentences = nltk.sent_tokenize(utt_prev) or [""]
            utt_prev = utt_prev_sentences[-1].lower()
            utt_cur = uttr_list[-1].lower()
            is_q = (
                any([utt_prev.startswith(q_word) for q_word in ["what ", "who ", "when ", "where "]]) or "?" in utt_prev
            )

            # The reply counts as a full sentence if it has a verb and >2 tokens.
            is_sentence = False
            parsed_sentence = nlp(utt_cur)
            if parsed_sentence:
                tokens = [elem.text for elem in parsed_sentence]
                tags = [elem.tag_ for elem in parsed_sentence]
                found_verbs = any([tag in tags for tag in ["VB", "VBZ", "VBP", "VBD"]])
                if found_verbs and len(tokens) > 2:
                    is_sentence = True

            logger.info(f"is_q: {is_q} --- is_s: {is_sentence} --- utt_prev: {utt_prev} --- utt_cur: {utt_cur}")
            if is_q and not is_sentence:
                uttrs.append(sentrewrite(utt_prev, utt_cur))
            else:
                uttrs.append(utt_cur)
        indices.append(len(uttrs))
    return uttrs, indices


def get_result(request):
    """Run property extraction on a batch of dialogs from a Flask request.

    Expects request.json keys: "utterances" (list of dialog histories) and,
    optionally, parallel batches "named_entities", "entities_with_labels" and
    "entity_info". Returns one list of triplet-info dicts per dialog; entity
    metadata is attached only when the ADD_ENTITY_INFO env flag is set.
    """
    st_time = time.time()
    init_uttrs = request.json.get("utterances", [])
    named_entities_batch = request.json.get("named_entities", [[] for _ in init_uttrs])
    entities_with_labels_batch = request.json.get("entities_with_labels", [[] for _ in init_uttrs])
    entity_info_batch = request.json.get("entity_info", [[] for _ in init_uttrs])
    logger.info(
        f"init_uttrs {init_uttrs} entities_with_labels: {entities_with_labels_batch} entity_info: {entity_info_batch}"
    )
    uttrs, indices = _prepare_uttrs(init_uttrs)

    logger.info(f"input utterances: {uttrs}")
    relations_pred = get_relations(uttrs)
    triplets_batch = generate_triplets(uttrs, relations_pred)

    logger.info(f"triplets_batch {triplets_batch}")
    triplets_info_batch = []
    # Regroup per-sentence results back into per-dialog lists. The last
    # zip_longest pair is (indices[-1], None), i.e. an open-ended slice.
    triplets_batch = [list(chain(*triplets_batch[start:end])) for start, end in zip_longest(indices, indices[1:])]
    uttrs = [" ".join(uttrs[start:end]) for start, end in zip_longest(indices, indices[1:])]
    for triplets, uttr, named_entities, entities_with_labels, entity_info_list in zip(
        triplets_batch, uttrs, named_entities_batch, entities_with_labels_batch, entity_info_batch
    ):
        uttr = uttr.lower()
        entity_substr_dict = {}
        formatted_triplets, per_triplets = [], []
        if len(uttr.split()) > 2:  # skip degenerate / too-short dialogs
            # Loop-invariant: PER entities don't depend on the triplet (hoisted
            # out of the per-triplet loop below).
            named_entities_list = [entity for elem in named_entities for entity in elem]
            per_entities = [entity for entity in named_entities_list if entity.get("type", "") == "PER"]
            for triplet in triplets:
                if triplet:
                    # Attach character offsets for entities appearing in the triplet.
                    for entity in entities_with_labels:
                        entity_substr = entity.get("text", "")
                        offsets = entity.get("offsets", [])
                        if not offsets:
                            start_offset = uttr.find(entity_substr.lower())
                            end_offset = start_offset + len(entity_substr)
                            offsets = [start_offset, end_offset]
                        if entity_substr in [triplet[0], triplet[2]]:
                            entity_substr_dict[entity_substr] = {"offsets": offsets}

                    for entity_info in entity_info_list:
                        entity_substr = entity_info.get("entity_substr", "")
                        # BUGFIX: parenthesized the match test. The original
                        # `A or B and C` parsed as `A or (B and C)`, so a raw
                        # substring match on an entity_info lacking "entity_ids"
                        # raised KeyError on the access below.
                        if (
                            entity_substr in [triplet[0], triplet[2]]
                            or stemmer.stem(entity_substr) in [triplet[0], triplet[2]]
                        ) and "entity_ids" in entity_info:
                            if entity_substr not in entity_substr_dict:
                                entity_substr_dict[entity_substr] = {}
                            entity_substr_dict[entity_substr]["entity_ids"] = entity_info["entity_ids"]
                            entity_substr_dict[entity_substr]["dbpedia_types"] = entity_info.get("dbpedia_types", [])
                            entity_substr_dict[entity_substr]["finegrained_types"] = entity_info.get(
                                "entity_id_tags", []
                            )
                    # "have chidren" is kept as-is — presumably it mirrors the
                    # spelling in rel_list.txt; verify there before correcting.
                    if triplet[1] in {"have pet", "have family", "have sibling", "have chidren"} and per_entities:
                        per_triplet = {
                            "subject": triplet[2],
                            "property": "name",
                            "object": per_entities[0].get("text", ""),
                        }
                        per_triplets.append(per_triplet)

                    formatted_triplet = {
                        "subject": triplet[0],
                        rel_type_dict[triplet[1]]: triplet[1],  # key is "relation" or "property"
                        "object": triplet[2],
                    }
                    formatted_triplets.append(formatted_triplet)
        triplets_info_list = []
        if add_entity_info:
            triplets_info_list.append({"triplets": formatted_triplets, "entity_info": entity_substr_dict})
        else:
            triplets_info_list.append({"triplets": formatted_triplets})
        if per_triplets:
            per_entity_info = [{per_triplet["object"]: {"entity_id_tags": ["PER"]}} for per_triplet in per_triplets]
            if add_entity_info:
                triplets_info_list.append({"per_triplets": per_triplets, "entity_info": per_entity_info})
            else:
                # BUGFIX: key unified with the branch above (was "per_triplet").
                triplets_info_list.append({"per_triplets": per_triplets})
        triplets_info_batch.append(triplets_info_list)
    total_time = time.time() - st_time
    logger.info(triplets_info_batch)
    logger.info(f"property extraction exec time: {total_time: .3f}s")
    return triplets_info_batch
269
270
@app.route("/respond", methods=["POST"])
def respond():
    """HTTP entry point: extract property triplets for the posted dialog batch."""
    return jsonify(get_result(request))
275
276
if __name__ == "__main__":
    # Serve on all interfaces, port 3000.
    app.run(debug=False, host="0.0.0.0", port=3000)