dream
109 lines · 4.2 KB
1import json2import logging3import os4from itertools import chain5from typing import List6
7import numpy as np8import sentry_sdk9from deeppavlov import build_model10from deeppavlov.core.commands.utils import parse_config, expand_path11from flask import Flask, jsonify, request12from sentry_sdk.integrations.flask import FlaskIntegration13from utils import get_regexp, unite_responses14
# logging here because it conflicts with tf
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Error reporting: unhandled exceptions in Flask handlers go to Sentry.
# SENTRY_DSN may be unset, in which case sentry_sdk.init is a no-op.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
app = Flask(__name__)
22
# Path to the JSON file describing intents and their regexp phrases; the
# DeepPavlov classifier config name is mandatory.
INTENT_PHRASES_PATH = os.environ.get("INTENT_PHRASES_PATH", "intent_phrases.json")
CONFIG_NAME = os.environ.get("CONFIG_NAME", None)
if CONFIG_NAME is None:
    raise NotImplementedError("No config file name is given.")

try:
    # download=True fetches model files if they are not present locally
    intents_model = build_model(CONFIG_NAME, download=True)
    logger.info("Model loaded")
    regexp = get_regexp(INTENT_PHRASES_PATH)
    logger.info("Regexp model loaded")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    # bare raise re-raises the active exception with its original traceback
    raise

# Intent labels in classifier output order, read from the trained model's classes.dict.
parsed = parse_config(CONFIG_NAME)
with open(expand_path(parsed["metadata"]["variables"]["MODEL_PATH"]).joinpath("classes.dict"), "r") as f:
    intents = f.read().strip().splitlines()
CLS_INTENTS = [el.strip().split("\t")[0] for el in intents]
# All intents declared in the phrases file; use a context manager instead of
# json.load(open(...)), which leaked the file handle.
with open(INTENT_PHRASES_PATH) as phrases_file:
    ALL_INTENTS = list(json.load(phrases_file)["intent_phrases"].keys())
logger.info(f"Considered intents for classifier: {CLS_INTENTS}")
logger.info(f"Considered intents from json file: {ALL_INTENTS}")
46
def get_classifier_predictions(batch_texts: List[List[str]], intents_model, thresholds):
    """Run the intent classifier over a batch of utterance lists.

    Args:
        batch_texts: one list of utterances per dialog turn.
        intents_model: callable returning ``(labels, probabilities)`` for a flat
            list of sentences; probability columns are ordered like ``CLS_INTENTS``.
        thresholds: per-intent detection thresholds; ``None`` means 0.5 for all.

    Returns:
        One dict per entry of ``batch_texts`` mapping intent name to
        ``{"detected": 0/1, "confidence": float}``. Confidence per intent is the
        max over that entry's utterances, or 0.0 if the entry has no utterances.
        Returns ``[]`` when the whole batch contains no utterances at all.
    """
    if thresholds is None:
        # no thresholds given -> use 0.5 as the default for every intent
        thresholds = [0.5] * len(CLS_INTENTS)
    thresholds = np.array(thresholds)

    # flatten the batch into one list of sentences, remembering each sentence's origin
    sentences = list(chain.from_iterable(batch_texts))
    sentences_text_ids = np.array(
        [text_id for text_id, text in enumerate(batch_texts) for _ in text]
    )

    result = []
    if len(sentences) > 0:
        _, pred_probas = intents_model(sentences)
        pred_probas = np.asarray(pred_probas)
        for text_id in range(len(batch_texts)):
            text_probas = pred_probas[sentences_text_ids == text_id]
            if len(text_probas) == 0:
                # This entry had no utterances: report nothing detected.
                # (Previously np.max over the empty slice raised ValueError.)
                result.append({intent: {"detected": 0, "confidence": 0.0} for intent in CLS_INTENTS})
                continue
            maximized_probas = np.max(text_probas, axis=0)
            result.append(
                {
                    intent: {"detected": int(float(proba) > thresh), "confidence": round(float(proba), 3)}
                    for intent, thresh, proba in zip(CLS_INTENTS, thresholds, maximized_probas)
                }
            )
    return result


def predict_intents(batch_texts: List[List[str]], regexp, intents_model, thresholds=None):
    """Detect intents: regexp full-match first, classifier for the leftovers.

    Args:
        batch_texts: one list of utterances per dialog turn.
        regexp: mapping intent name -> compiled regexp used via ``fullmatch``.
        intents_model: classifier callable (see ``get_classifier_predictions``).
        thresholds: optional per-intent classifier thresholds.

    Returns:
        One dict per entry of ``batch_texts`` mapping every intent in
        ``ALL_INTENTS`` to ``{"detected": 0/1, "confidence": float}``.
    """
    responds = []
    not_detected_utterances = []
    for text in batch_texts:
        resp = {intent: {"detected": 0, "confidence": 0.0} for intent in ALL_INTENTS}
        not_detected_utterance = text.copy()
        for intent, reg in regexp.items():
            for i, utt in enumerate(text):
                if reg.fullmatch(utt):
                    logger.info(f"Full match of `{utt}` with `{reg}`.")
                    resp[intent]["detected"] = 1
                    resp[intent]["confidence"] = 1.0
                    # mark as handled so the classifier does not see it
                    not_detected_utterance[i] = None
        not_detected_utterances.append([utt for utt in not_detected_utterance if utt])
        responds.append(resp)

    # Run the classifier if ANY entry still has unhandled utterances.
    # (The original checked only not_detected_utterances[0], silently skipping
    # the classifier whenever the first turn was fully matched by regexps.)
    if any(len(utts) > 0 for utts in not_detected_utterances):
        classifier_result = get_classifier_predictions(not_detected_utterances, intents_model, thresholds)
        return unite_responses(classifier_result, responds, ALL_INTENTS)
    return responds
98
@app.route("/detect", methods=["POST"])
def detect():
    """HTTP endpoint: detect intents for a JSON batch of utterance lists.

    Expects a POST body of the form ``{"sentences": [[...], ...]}`` and
    returns the per-turn intent dicts produced by ``predict_intents``.
    """
    batch = request.json["sentences"]
    logger.info(f"Input: `{batch}`.")
    predictions = predict_intents(batch, regexp, intents_model)
    logger.info(f"Output: `{predictions}`.")
    return jsonify(predictions)
107
if __name__ == "__main__":
    # Serve on all interfaces, port 8014, via Flask's built-in server.
    app.run(debug=False, host="0.0.0.0", port=8014)