dream
204 строки · 7.0 Кб
1import logging
2import os
3import time
4from flask import Flask, request, jsonify
5import sentry_sdk
6from deeppavlov import build_model
7
# Configure process-wide logging and Sentry error reporting first, so that
# anything going wrong during the rest of module initialisation is captured.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
sentry_sdk.init(os.getenv("SENTRY_DSN"))

app = Flask(__name__)

# Model config passed to deeppavlov.build_model, read from the CONFIG env var.
config_name = os.getenv("CONFIG")

# Relation names, one per line, with surrounding whitespace stripped.
# process_entity_info() checks relation names against this list.
with open("abstract_rels.txt", "r") as rels_file:
    abstract_rels = [raw_line.strip() for raw_line in rels_file]

# The model is built once at import time. A failure is reported to Sentry,
# logged with its traceback, and re-raised so the module does not load.
try:
    el = build_model(config_name, download=True)
    logger.info("model loaded")
except Exception as load_error:
    sentry_sdk.capture_exception(load_error)
    logger.exception(load_error)
    raise load_error
26
27
def preprocess_context(context_batch):
    """Choose the text used as linking context for each dialogue history.

    The last element of a history is treated as the current utterance and the
    one before it as the previous utterance. When the previous utterance looks
    like a question (starts with "what ", "who ", "when " or "where ", or
    contains "?") and the current utterance has fewer than three
    whitespace-separated words, the two are joined with a space so that a
    short answer keeps the question it refers to. Otherwise only the current
    utterance is used.

    Args:
        context_batch (list): List of dialogue histories; each history is a
            list of utterance strings.

    Returns:
        list: One context string per history, in the same order. An empty
            history yields an empty string.
    """
    question_prefixes = ("what ", "who ", "when ", "where ")
    optimized_context_batch = []
    for hist_uttr in context_batch:
        if not hist_uttr:
            # No utterances at all: emit a placeholder instead of raising
            # IndexError, keeping the output aligned with the input batch.
            optimized_context_batch.append("")
            continue

        cur_uttr = hist_uttr[-1]
        if len(hist_uttr) == 1:
            optimized_context_batch.append(cur_uttr)
            continue

        prev_uttr = hist_uttr[-2]
        is_q = prev_uttr.startswith(question_prefixes) or "?" in prev_uttr
        if is_q and len(cur_uttr.split()) < 3:
            optimized_context_batch.append(f"{prev_uttr} {cur_uttr}")
        else:
            optimized_context_batch.append(cur_uttr)

    return optimized_context_batch
54
55
def process_entity_info(
    entity_substr_batch, entity_ids_batch, conf_batch, entity_id_tags_batch, prex_info_batch, optimized_context_batch
):
    """Filter linked entities and build the per-sample response structure.

    For every sample the property-extraction triplets are turned into an
    ``object -> relation`` lookup. An entity substring counts as "abstract"
    when its relation (lower-cased, underscores replaced by spaces) is listed
    in the module-level ``abstract_rels`` and the context does not contain
    the substring directly preceded by "the", "my", "his" or "her".
    Candidate ids tagged ``"Abstract"`` are dropped (and logged) when the
    substring does not count as abstract. An entity is reported only if at
    least one candidate id survives and the substring occurs in the
    lower-cased context.

    Args:
        entity_substr_batch (list): Per sample, a list of entity substrings.
        entity_ids_batch (list): Per sample and substring, candidate entity ids.
        conf_batch (list): Per sample and substring, one confidence entry per
            candidate id; index 0 is reported as ``tokens_match_conf`` and
            index 2 as ``confidences``.
        entity_id_tags_batch (list): Per sample and substring, one tag
            (entity kind) per candidate id.
        prex_info_batch (list): Per sample, property-extraction output: a dict
            with an optional ``"triplets"`` entry, or a list whose first
            element is such a dict.
        optimized_context_batch (list): Per sample, the context string.

    Returns:
        list: Per sample, a list of dicts with the keys ``entity_substr``,
            ``entity_ids``, ``confidences``, ``tokens_match_conf`` and
            ``entity_id_tags``.
    """
    determiners = ("the", "my", "his", "her")
    entity_info_batch = []
    for (
        entity_substr_list,
        entity_ids_list,
        conf_list,
        entity_id_tags_list,
        prex_info,
        context,
    ) in zip(
        entity_substr_batch,
        entity_ids_batch,
        conf_batch,
        entity_id_tags_batch,
        prex_info_batch,
        optimized_context_batch,
    ):
        # Lower-case the context once per sample rather than once per entity.
        context = context.lower()

        # Property-extraction output may arrive wrapped in a list; unwrap it.
        if isinstance(prex_info, list) and prex_info:
            prex_info = prex_info[0]
        triplets = prex_info.get("triplets", {}) if prex_info else {}

        # Map each lower-cased triplet object to its relation (or property).
        obj2rel_dict = {}
        for triplet in triplets:
            obj = triplet["object"].lower()
            if "relation" in triplet:
                rel = triplet["relation"]
            elif "property" in triplet:
                rel = triplet["property"]
            else:
                rel = ""
            obj2rel_dict[obj] = rel

        entity_info_list = []
        for entity_substr, entity_ids, confs, entity_id_tags in zip(
            entity_substr_list,
            entity_ids_list,
            conf_list,
            entity_id_tags_list,
        ):
            entity_substr = entity_substr.lower()
            curr_rel = obj2rel_dict.get(entity_substr, "")
            is_abstract = curr_rel.lower().replace("_", " ") in abstract_rels and not any(
                f" {word} {entity_substr}" in context for word in determiners
            )

            filtered_entity_ids, filtered_confs, filtered_entity_id_tags = [], [], []
            for entity_id, conf, entity_id_tag in zip(entity_ids, confs, entity_id_tags):
                # Drop candidates tagged "Abstract" when the relation/context
                # check says this mention is not abstract.
                if entity_id_tag == "Abstract" and not is_abstract:
                    logger.info(
                        "Contradiction between the entity_kind 'Abstract' and relationship '%s'", curr_rel
                    )
                    continue
                filtered_entity_ids.append(entity_id)
                filtered_confs.append(conf)
                filtered_entity_id_tags.append(entity_id_tag)

            if filtered_entity_ids and entity_substr in context:
                entity_info_list.append(
                    {
                        "entity_substr": entity_substr,
                        "entity_ids": filtered_entity_ids,
                        "confidences": [float(elem[2]) for elem in filtered_confs],
                        "tokens_match_conf": [float(elem[0]) for elem in filtered_confs],
                        "entity_id_tags": filtered_entity_id_tags,
                    }
                )

        entity_info_batch.append(entity_info_list)
    return entity_info_batch
149
150
@app.route("/model", methods=["POST"])
def respond():
    """Handle a POST to /model: link entities and return the filtered results.

    Reads the following keys from the JSON body, each with a default when
    absent: ``user_id``, ``entity_substr``, ``entity_tags``, ``contexts``
    and ``property_extraction``. The contexts are condensed with
    preprocess_context(), the model ``el`` is called with the user ids,
    entity substrings and entity tags, and its output is filtered by
    process_entity_info().

    If the model call or the post-processing raises, the exception is sent to
    Sentry and logged, and an empty result list is returned for every sample.

    Returns:
        flask.Response: JSON list with one list of entity-info dicts per sample.
    """
    st_time = time.time()
    payload = request.json
    user_ids = payload.get("user_id", [""])
    entity_substr_batch = payload.get("entity_substr", [[""]])
    # Default tags: one empty tag per entity substring.
    entity_tags_batch = payload.get(
        "entity_tags",
        [["" for _ in entity_substr_list] for entity_substr_list in entity_substr_batch],
    )
    context_batch = payload.get("contexts", [[""]])
    prex_info_batch = payload.get("property_extraction", [{} for _ in entity_substr_batch])

    # Preprocess the conversation context
    optimized_context_batch = preprocess_context(context_batch)

    entity_info_batch = []
    try:
        (
            entity_substr_batch,
            entity_ids_batch,
            conf_batch,
            entity_id_tags_batch,
        ) = el(user_ids, entity_substr_batch, entity_tags_batch)

        entity_info_batch = process_entity_info(
            entity_substr_batch,
            entity_ids_batch,
            conf_batch,
            entity_id_tags_batch,
            prex_info_batch,
            optimized_context_batch,
        )
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
        # One independent empty list per sample; `[[]] * n` would repeat a
        # single shared list object n times.
        entity_info_batch = [[] for _ in entity_substr_batch]

    total_time = time.time() - st_time
    logger.info("entity_info_batch: %s", entity_info_batch)
    logger.info("custom entity linking exec time = %.3fs", total_time)

    return jsonify(entity_info_batch)
201
202
if __name__ == "__main__":
    # Run directly as a script: start Flask's built-in server with debug mode
    # off, bound to 0.0.0.0 on port 3000.
    app.run(host="0.0.0.0", port=3000, debug=False)
205