import json
import logging
import os
from itertools import chain
from typing import List

import numpy as np
import sentry_sdk
from deeppavlov import build_model
from deeppavlov.core.commands.utils import parse_config, expand_path
from flask import Flask, jsonify, request
from sentry_sdk.integrations.flask import FlaskIntegration
from utils import get_regexp, unite_responses

# logging is configured here because it conflicts with tf
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
app = Flask(__name__)


INTENT_PHRASES_PATH = os.environ.get("INTENT_PHRASES_PATH", "intent_phrases.json")
CONFIG_NAME = os.environ.get("CONFIG_NAME", None)
if CONFIG_NAME is None:
    raise NotImplementedError("No config file name is given.")

try:
    intents_model = build_model(CONFIG_NAME, download=True)
    logger.info("Model loaded")
    regexp = get_regexp(INTENT_PHRASES_PATH)
    logger.info("Regexp model loaded")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e

parsed = parse_config(CONFIG_NAME)
with open(expand_path(parsed["metadata"]["variables"]["MODEL_PATH"]).joinpath("classes.dict"), "r") as f:
    intents = f.read().strip().splitlines()
CLS_INTENTS = [el.strip().split("\t")[0] for el in intents]
with open(INTENT_PHRASES_PATH, "r") as f:
    ALL_INTENTS = list(json.load(f)["intent_phrases"].keys())
logger.info(f"Considered intents for classifier: {CLS_INTENTS}")
logger.info(f"Considered intents from json file: {ALL_INTENTS}")
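# Illustrative layout of intent_phrases.json. Only the top-level "intent_phrases"
# mapping of intent name -> phrase spec is assumed from this file; the inner
# structure of each entry is consumed by utils.get_regexp and the intent names
# below are examples only:
#
#   {
#       "intent_phrases": {
#           "exit": {...},
#           "what_can_you_do": {...}
#       }
#   }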


def get_classifier_predictions(batch_texts: List[List[str]], intents_model, thresholds):
    """Run the intent classifier on a batch of dialogs, each given as a list of sentences.

    Returns one dict per dialog mapping intent name to {"detected": 0/1, "confidence": float}.
    """
    global CLS_INTENTS
    if thresholds is None:
        # if no thresholds are given, use 0.5 as the default for every intent
        thresholds = [0.5] * len(CLS_INTENTS)
    thresholds = np.array(thresholds)
    # flatten the batch into a 1d list of sentences for the classifier
    sentences = list(chain.from_iterable(batch_texts))
    # remember which dialog each sentence came from
    sentences_text_ids = []
    for text_id, text in enumerate(batch_texts):
        sentences_text_ids += [text_id] * len(text)
    sentences_text_ids = np.array(sentences_text_ids)

    result = []
    # classify with intent catcher classifier
    if len(sentences) > 0:
        _, pred_probas = intents_model(sentences)
        for text_id, text in enumerate(batch_texts):
            # take the per-intent maximum probability over the dialog's sentences
            maximized_probas = np.max(pred_probas[sentences_text_ids == text_id], axis=0)
            resp = {
                intent: {"detected": int(float(proba) > thresh), "confidence": round(float(proba), 3)}
                for intent, thresh, proba in zip(CLS_INTENTS, thresholds, maximized_probas)
            }
            result += [resp]
    return result


def predict_intents(batch_texts: List[List[str]], regexp, intents_model, thresholds=None):
    """Detect intents first by regexp full match, then classify the remaining utterances."""
    global ALL_INTENTS
    responds = []
    not_detected_utterances = []
    for text_id, text in enumerate(batch_texts):
        resp = {intent: {"detected": 0, "confidence": 0.0} for intent in ALL_INTENTS}
        not_detected_utterance = text.copy()
        for intent, reg in regexp.items():
            for i, utt in enumerate(text):
                if reg.fullmatch(utt):
                    logger.info(f"Full match of `{utt}` with `{reg}`.")
                    resp[intent]["detected"] = 1
                    resp[intent]["confidence"] = 1.0
                    # the utterance is already covered by regexp, exclude it from the classifier input
                    not_detected_utterance[i] = None
        not_detected_utterance = [utt for utt in not_detected_utterance if utt]
        not_detected_utterances.append(not_detected_utterance)
        responds.append(resp)

    if len(not_detected_utterances) > 0 and len(not_detected_utterances[0]) > 0:
        classifier_result = get_classifier_predictions(not_detected_utterances, intents_model, thresholds)
        return unite_responses(classifier_result, responds, ALL_INTENTS)
    else:
        return responds


@app.route("/detect", methods=["POST"])
def detect():
    utterances = request.json["sentences"]
    logger.info(f"Input: `{utterances}`.")
    results = predict_intents(utterances, regexp, intents_model)
    logger.info(f"Output: `{results}`.")
    return jsonify(results)
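# Illustrative request/response shape for /detect. The payload is a batch of dialogs,
# each a list of utterances; intent names in the response come from classes.dict and
# intent_phrases.json, the ones below are examples only:
#
#   POST /detect
#   {"sentences": [["hi there", "what can you do?"]]}
#
#   -> [{"what_can_you_do": {"detected": 1, "confidence": 1.0},
#        "exit": {"detected": 0, "confidence": 0.0}, ...}]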


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=8014)
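# Minimal client sketch (assumes the service runs locally on port 8014, as above):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8014/detect",
#       json={"sentences": [["hello", "please tell me a joke"]]},
#   )
#   print(resp.json())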
