# dream
# 190 lines · 6.8 KB
1#!/usr/bin/env python
2
3import logging
4import time
5from os import getenv
6import random
7import pathlib
8import datetime
9import copy
10import difflib
11
12from flask import Flask, request, jsonify
13from healthcheck import HealthCheck
14import sentry_sdk
15from sentry_sdk.integrations.logging import ignore_logger
16
17from common.constants import CAN_NOT_CONTINUE, CAN_CONTINUE_SCENARIO, MUST_CONTINUE
18from common.universal_templates import is_switch_topic, if_chat_about_particular_topic
19
20from common.utils import get_skill_outputs_from_dialog, is_yes
21from common.game_cooperative_skill import game_skill_was_proposed, GAMES_COMPILED_PATTERN, FALLBACK_ACKN_TEXT
22from common.gaming import find_games_in_text
23from common.dialogflow_framework.programy.text_preprocessing import clean_text
24
25from router import run_skills as skill
26
27
# Location of the game database, and how many trailing dialog utterances
# `respond` scans for this skill's recent activity.
DB_FILE = pathlib.Path(getenv("DB_FILE", "/data/game-cooperative-skill/game_db.json"))
MEMORY_LENGTH = 3

# Error reporting: records of the "root" logger are ignored by Sentry's logging
# integration; the DSN comes from the environment.
ignore_logger("root")
sentry_sdk.init(dsn=getenv("SENTRY_DSN"))

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(pathname)s - %(lineno)d - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# HTTP application with a /healthcheck endpoint; werkzeug only logs WARNING and above.
app = Flask(__name__)
health = HealthCheck(app, "/healthcheck")
logging.getLogger("werkzeug").setLevel("WARNING")
40
41
# Healthcheck probe for the game database file; registered with `health` at the end of this block.
def db_is_updated():
    """Report whether the game DB file exists and was modified within the last 25 hours.

    Returns a ``(passed, message)`` tuple as expected by ``HealthCheck.add_check``.

    A missing file fails the check; it is logged as an error and reported to Sentry.
    A file older than 25 hours still passes the check. It is only logged as a
    warning and reported to Sentry, so a late DB refresh does not mark the service
    unhealthy.
    """
    max_db_age = datetime.timedelta(hours=25)
    try:
        # stat() follows symlinks (as the old exists() check did), so a symlinked
        # DB reports the target file's mtime. Catching the error instead of
        # checking exists() first avoids a race if the file vanishes in between.
        file_modification_time = datetime.datetime.fromtimestamp(DB_FILE.stat().st_mtime)
    except (FileNotFoundError, NotADirectoryError):
        msg = "db file is not created"
        logger.error(msg)
        sentry_sdk.capture_message(msg)
        return False, msg

    data_is_expired = datetime.datetime.now() - max_db_age > file_modification_time
    msg = "db is expired" if data_is_expired else "db is updated"
    msg += f", last modified date of db is {file_modification_time.strftime('%m/%d/%Y, %H:%M:%S')}"
    if data_is_expired:
        logger.warning(msg)
        sentry_sdk.capture_message(msg)
    return True, msg


health.add_check(db_is_updated)
62
63
# Intent-catcher intents that, when detected, also mark the turn as a topic switch.
_TOPIC_SWITCHING_INTENTS = (
    "exit",
    "stupid",
    "cant_do",
    "tell_me_a_story",
    "weather_forecast_intent",
    "what_can_you_do",
    "what_is_your_job",
    "what_is_your_name",
    "what_time",
)


def get_agent_intents(last_utter):
    """Collect the intents detected for ``last_utter`` and derive a ``topic_switching`` flag.

    Args:
        last_utter: the human utterance dict. Its optional ``"annotations"`` ->
            ``"intent_catcher"`` entry maps intent names to detector dicts with a
            ``"detected"`` field.

    Returns:
        dict mapping every intent whose ``"detected"`` equals 1 to ``True``.
        ``"topic_switching"`` is also set when ``is_switch_topic(last_utter)`` holds
        or any intent in ``_TOPIC_SWITCHING_INTENTS`` was detected.
    """
    # `or {}` also tolerates keys that are present but set to None.
    annotations = last_utter.get("annotations") or {}
    intent_catcher = annotations.get("intent_catcher") or {}
    agent_intents = {
        intent_name: True
        for intent_name, intent_detector in intent_catcher.items()
        if intent_detector.get("detected", 0) == 1
    }

    # is_switch_topic is evaluated first and is skipped when the flag is already set,
    # matching the original short-circuit order.
    if not agent_intents.get("topic_switching") and (
        is_switch_topic(last_utter) or any(agent_intents.get(name) for name in _TOPIC_SWITCHING_INTENTS)
    ):
        agent_intents["topic_switching"] = True
    return agent_intents
85
86
def _compute_confidence(
    response, state, prev_state, last_utter, last_utter_text, bot_utterance, is_active_last_answer
):
    """Choose the reply text, its confidence and the state to store for one dialog.

    Returns a ``(text, confidence, state)`` tuple.

    ``text`` and ``state`` are normally the skill's own reply and new state. The
    exception is when the previous bot turn proposed a game and the user did not say
    yes: then the reply becomes ``FALLBACK_ACKN_TEXT`` and ``prev_state`` is kept.
    """
    text = response.get("text", "Sorry")
    if not response.get("confidence"):
        return text, 0, state
    if is_active_last_answer:
        # The skill was active recently, so keep driving the scenario.
        return text, 1, state
    if if_chat_about_particular_topic(last_utter, bot_utterance, compiled_pattern=GAMES_COMPILED_PATTERN):
        # The user wants to talk about games. Step aside when a concrete game is named
        # (presumably handled by another gaming skill; confirm).
        return text, (0 if find_games_in_text(last_utter_text) else 1), state
    if game_skill_was_proposed(bot_utterance):
        if is_yes(last_utter):
            return text, 1, state
        return FALLBACK_ACKN_TEXT, 0.95, prev_state
    if GAMES_COMPILED_PATTERN.search(last_utter_text):
        return text, 0.98, state
    return text, 0, state


def _repeats_last_bot_reply(text, bot_utterance, threshold=0.95):
    """Return True when ``text`` is token-wise almost identical to the previous bot reply."""
    curr_tokens = clean_text(text.lower()).split()
    last_tokens = clean_text(bot_utterance.get("text", "").lower()).split()
    return difflib.SequenceMatcher(None, curr_tokens, last_tokens).ratio() > threshold


def _continue_flag(confidence):
    """Map a confidence value to the continuation flag stored in ``attr["can_continue"]``."""
    if confidence == 1:
        return MUST_CONTINUE
    if confidence > 0.95:
        return CAN_CONTINUE_SCENARIO
    return CAN_NOT_CONTINUE


@app.route("/respond", methods=["POST"])
def respond():
    """Generate one hypothesis per dialog in the posted batch.

    Expects JSON with ``"dialogs"`` and an optional ``"rand_seed"``. Returns a JSON
    list of ``(text, confidence, human_attributes, bot_attributes, attributes)``.

    Any exception raised while processing a single dialog is reported to Sentry and
    logged. That dialog then gets an empty reply with zero confidence and its
    previous skill state is restored.
    """
    st_time = time.perf_counter()
    dialogs_batch = request.json["dialogs"]
    rand_seed = request.json.get("rand_seed")

    responses = []
    for dialog in dialogs_batch:
        prev_skill_outputs = get_skill_outputs_from_dialog(
            dialog["utterances"][-MEMORY_LENGTH:], "game_cooperative_skill", activated=True
        )
        is_active_last_answer = bool(prev_skill_outputs)
        human_attr = dialog["human"]["attributes"]
        prev_state = human_attr.get("game_cooperative_skill", {}).get("state", {})
        try:
            # Work on a copy so prev_state stays intact for the fallback/error paths.
            state = copy.deepcopy(prev_state)
            if state and not is_active_last_answer:
                # The skill was not active recently: reset the stored "messages".
                state["messages"] = []

            last_utter = dialog["human_utterances"][-1]
            last_utter_text = last_utter["text"].lower()
            agent_intents = get_agent_intents(last_utter)

            if rand_seed:
                # Deterministic replies for tests.
                random.seed(int(rand_seed))
            response, state = skill([last_utter_text], state, agent_intents)

            bot_utterance = dialog["bot_utterances"][-1] if dialog["bot_utterances"] else {}
            text, confidence, state = _compute_confidence(
                response, state, prev_state, last_utter, last_utter_text, bot_utterance, is_active_last_answer
            )
            if _repeats_last_bot_reply(text, bot_utterance):
                # Never (almost) repeat the previous bot reply.
                confidence = 0

            human_attr["game_cooperative_skill"] = {"state": state}
            attr = {"can_continue": _continue_flag(confidence)}

        except Exception as exc:
            sentry_sdk.capture_exception(exc)
            logger.exception(exc)
            text = ""
            confidence = 0.0
            human_attr["game_cooperative_skill"] = {"state": prev_state}
            attr = {}

        bot_attr = {}
        responses.append((text, confidence, human_attr, bot_attr, attr))

    total_time = time.perf_counter() - st_time
    logger.info(f"game_cooperative_skill exec time = {total_time:.3f}s")

    return jsonify(responses)
187
188
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces, port 3000, with debug mode off.
    app.run(host="0.0.0.0", port=3000, debug=False)
191