dream
81 строка · 3.5 Кб
1import logging2import os3import time4from flask import Flask, request, jsonify5import sentry_sdk6from deeppavlov import build_model7
# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Error reporting; os.getenv returns None when SENTRY_DSN is not set.
sentry_sdk.init(os.getenv("SENTRY_DSN"))

app = Flask(__name__)

# Runtime settings taken from the environment.
config_name = os.getenv("CONFIG")  # config passed to deeppavlov.build_model below
top_n = int(os.getenv("TOP_N"))  # int(None) raises TypeError at import time if TOP_N is unset

# Build the retrieval model once, at import time. Any failure is reported to
# Sentry and logged with its traceback, then re-raised so importing this module fails.
try:
    fact_retrieval = build_model(config_name, download=True)
    logger.info("model loaded")
except Exception as load_error:
    sentry_sdk.capture_exception(load_error)
    logger.exception(load_error)
    raise load_error
25
def _select_query(dialog_history):
    """Pick the text sent to the retriever for one dialog.

    The last utterance alone is used when it looks like a self-contained
    question (more than two whitespace-separated tokens and contains "?"),
    or when it is the only utterance; otherwise all utterances are joined
    with single spaces.

    Args:
        dialog_history: list of utterance strings for one dialog.

    Returns:
        The query string; an empty string for an empty history.
    """
    if not dialog_history:
        # Guard: indexing [-1] on an empty history would raise IndexError.
        return ""
    last_utterance = dialog_history[-1]
    if (len(last_utterance.split()) > 2 and "?" in last_utterance) or len(dialog_history) == 1:
        return last_utterance
    return " ".join(dialog_history)


def _rank_contexts(contexts, scores, from_linked_page_list, numbers, top_n):
    """Order one sample's retrieved contexts and tag whether each came from a linked page.

    Contexts are bucketed by their ``from_linked_page`` flag and ``number``
    (NOTE(review): presumably the paragraph position within the page — confirm):

    * linked page, number == 0: all kept, in retrieval order, emitted first;
    * linked page, number > 0: sorted by (score, number) descending, top ``top_n // 2``;
    * everything else (not linked, or a negative number): sorted the same way,
      top ``top_n // 2``.

    Args:
        contexts, scores, from_linked_page_list, numbers: parallel per-sample
            lists as returned by the retriever.
        top_n: budget shared by the two sorted buckets (each gets ``top_n // 2``).

    Returns:
        List of ``(context, score, is_linked)`` tuples; ``is_linked`` is True for
        the first two buckets and False for the last.
    """
    first, linked, not_linked = [], [], []
    for context, score, from_linked_page, number in zip(contexts, scores, from_linked_page_list, numbers):
        item = (context, score, number)
        if from_linked_page and number > 0:
            linked.append(item)
        elif from_linked_page and number == 0:
            first.append(item)
        else:
            not_linked.append(item)

    # Highest score first; ties broken by the larger number (stable sort otherwise).
    linked = sorted(linked, key=lambda x: (x[1], x[2]), reverse=True)
    not_linked = sorted(not_linked, key=lambda x: (x[1], x[2]), reverse=True)

    half = top_n // 2
    ranked = [(context, score, True) for context, score, _ in first]
    ranked += [(context, score, True) for context, score, _ in linked[:half]]
    ranked += [(context, score, False) for context, score, _ in not_linked[:half]]
    return ranked


@app.route("/model", methods=["POST"])
def respond():
    """Run fact retrieval for a batch of dialogs.

    Reads from the request JSON:
        dialog_history: list (one per dialog) of utterance lists.
        entity_substr, entity_tags, entity_pages: optional per-dialog lists,
            passed through unchanged to ``fact_retrieval``; each defaults to one
            empty list per dialog.

    Returns:
        JSON list with one entry per dialog; each entry is a list of
        ``[context, score, is_linked]``. If retrieval fails, the error is sent
        to Sentry and logged, and every entry is an empty list, so the response
        stays aligned with the request batch.
    """
    st_time = time.time()
    # A missing/null JSON body is treated like an empty request instead of
    # crashing on None.get.
    inp = request.json or {}
    dialog_history_batch = inp.get("dialog_history", [])
    entity_substr_batch = inp.get("entity_substr", [[] for _ in dialog_history_batch])
    entity_tags_batch = inp.get("entity_tags", [[] for _ in dialog_history_batch])
    entity_pages_batch = inp.get("entity_pages", [[] for _ in dialog_history_batch])
    sentences_batch = [_select_query(dialog_history) for dialog_history in dialog_history_batch]

    # Fallback result: one empty list per dialog. Results are built in a separate
    # local and committed only on success, so a failure never yields a short
    # or partially filled batch.
    contexts_with_scores_batch = [[] for _ in sentences_batch]
    try:
        contexts_batch, scores_batch, from_linked_page_batch, numbers_batch = fact_retrieval(
            sentences_batch, entity_substr_batch, entity_tags_batch, entity_pages_batch
        )
        ranked_batch = [
            _rank_contexts(contexts, scores, from_linked_page_list, numbers, top_n)
            for contexts, scores, from_linked_page_list, numbers in zip(
                contexts_batch, scores_batch, from_linked_page_batch, numbers_batch
            )
        ]
    except Exception as e:
        sentry_sdk.capture_exception(e)
        logger.exception(e)
    else:
        contexts_with_scores_batch = ranked_batch
    total_time = time.time() - st_time
    logger.info("fact retrieval exec time = %.3fs", total_time)
    return jsonify(contexts_with_scores_batch)
79
# Direct execution: serve on all interfaces, port 3000, with the Flask debugger off.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=3000, debug=False)