dream
73 строки · 2.5 Кб
1#!/usr/bin/env python
2
3import logging
4import numpy as np
5import time
6import requests
7
8from flask import Flask, request, jsonify
9from os import getenv
10import sentry_sdk
11
12
# Initialise Sentry with the DSN taken from the SENTRY_DSN environment variable.
sentry_sdk.init(getenv("SENTRY_DSN"))

# Root logging setup: timestamp, logger name, level, then the message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Flask application object that the route handlers below are attached to.
app = Flask(__name__)
19
# Endpoint of the badlisted-words annotator (batch API); overridable via environment.
BADLIST_URL = getenv("BADLIST_ANNOTATOR_URL", "http://badlisted-words:8018/badlisted_words_batch")

# Whether hypotheses containing badlisted words are dropped before selection.
# Environment variables are always strings, so the raw value cannot be used for
# truthiness directly ("0" and "false" are non-empty and therefore truthy).
# Anything other than an explicit "off" spelling keeps filtering enabled, which
# preserves the previous behaviour for values such as "1" or "true".
FILTER_BADLISTED_WORDS = getenv("FILTER_BADLISTED_WORDS", "0").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
22
23
def _has_badlisted_words(text):
    """Ask the badlist annotator whether ``text`` contains badlisted words.

    The check is best-effort: if the request fails or the response cannot be
    decoded, the error is logged and reported to Sentry and the text is
    treated as clean, so an annotator failure never removes a hypothesis.
    """
    try:
        badlist_result = requests.post(BADLIST_URL, json={"sentences": [text]}, timeout=1.5).json()[0]["batch"][0]
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        return False
    return bool(badlist_result["bad_words"])


@app.route("/respond", methods=["POST"])
def respond():
    """Pick the highest-confidence hypothesis for every dialog in the request.

    Expects a JSON body with a ``"dialogs"`` list. For each dialog, the
    ``"hypotheses"`` of its last utterance are read; every hypothesis must
    provide ``"skill_name"``, ``"text"`` and ``"confidence"``. When
    ``FILTER_BADLISTED_WORDS`` is truthy, hypotheses flagged by the badlist
    annotator are discarded first.

    Returns:
        A JSON list with one ``[skill_name, text, confidence]`` entry per
        dialog, in request order. If a dialog has no usable hypothesis, an
        empty placeholder ``["", "", 0.0]`` is returned for it so that the
        output stays aligned with the input instead of failing the batch.
    """
    st_time = time.time()
    dialogs = request.json["dialogs"]

    selected = []
    for dialog in dialogs:
        confidences = []
        responses = []
        skill_names = []

        for skill_data in dialog["utterances"][-1]["hypotheses"]:
            if skill_data["text"] and skill_data["confidence"]:
                logger.info(
                    "Skill %s returned non-empty hypothesis with non-zero confidence.",
                    skill_data["skill_name"],
                )

            if FILTER_BADLISTED_WORDS and _has_badlisted_words(skill_data["text"]):
                continue

            confidences.append(skill_data["confidence"])
            responses.append(skill_data["text"])
            skill_names.append(skill_data["skill_name"])

        if not confidences:
            # np.argmax raises ValueError on an empty sequence; this happens when
            # the dialog has no hypotheses or all of them were filtered out.
            logger.warning("No acceptable hypotheses for a dialog; returning an empty response.")
            selected.append(("", "", 0.0))
            continue

        # argmax returns the first maximum, so ties go to the earliest hypothesis.
        best_id = int(np.argmax(confidences))
        selected.append((skill_names[best_id], responses[best_id], confidences[best_id]))

    total_time = time.time() - st_time
    logger.info(f"confidence_based_response_selector exec time = {total_time:.3f}s")
    return jsonify(selected)
70
71
# Direct-execution entry point: serve the app on every interface, port 3000,
# with Flask's debug mode switched off.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=3000, debug=False)
74