dream
77 строк · 2.3 Кб
1import logging
2import time
3import os
4
5import sentry_sdk
6from catboost import CatBoostClassifier
7from flask import Flask, request, jsonify
8from score import get_features
9
# Error reporting: the Sentry DSN is taken from the SENTRY_DSN environment variable.
sentry_sdk.init(os.getenv("SENTRY_DSN"))

# Process-wide log format; INFO and above are emitted.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Raise werkzeug's threshold so only its WARNING-and-above records are shown.
logging.getLogger("werkzeug").setLevel(logging.WARNING)

# WSGI application object; routes are registered on it below.
app = Flask(__name__)
17
18
def get_probas(contexts, hypotheses):
    """Score hypotheses with the module-level CatBoost model ``cb``.

    ``get_features`` builds the feature matrix from the dialog contexts and the
    candidate hypotheses; the return value is column 1 of ``cb.predict_proba``
    for that matrix (presumably the positive-class probability, one entry per
    feature row / hypothesis — confirm against ``get_features``).
    """
    return cb.predict_proba(get_features(contexts, hypotheses))[:, 1]
23
24
# Load the confidence model at import time and push one hand-written sample
# through the full scoring path (features + predict_proba). Any failure is
# logged, sent to Sentry and re-raised, so the module fails to import instead
# of serving requests with a broken model.
try:
    cb = CatBoostClassifier()
    cb.load_model("model-confidence-convert-old_midas.cbm")
    # Warm-up input: one dialog history (a list of utterance strings) and one
    # candidate hypothesis dict. The layout presumably mirrors what the
    # /batch_model endpoint receives — confirm against score.get_features.
    contexts = [
        [
            "i'm good how are you",
            "Spectacular, by all reports! Do you want to know what I can do?",
            "absolutely",
            "I'm a socialbot, and I'm all about chatting with people like you. "
            "I can answer questions, share fun facts, discuss movies, books and news. What do you want to talk about?",
            "let's talk about movies",
        ]
    ]
    hypotheses = [
        {
            "is_best": True,
            "text": "Kong: Skull Island is a good action movie. What do you think about it?",
            "confidence": 1.0,
            "convers_evaluator_annotator": {
                "isResponseOnTopic": 0.505,
                "isResponseErroneous": 0.938,
                "responseEngagesUser": 0.344,
                "isResponseInteresting": 0.084,
                "isResponseComprehensible": 0.454,
            },
        }
    ]
    # The result is discarded; the call only checks that scoring runs.
    get_probas(contexts, hypotheses)
except Exception as e:
    logger.exception("Scorer not loaded")
    sentry_sdk.capture_exception(e)
    # Bare `raise` re-raises the active exception with its original traceback.
    raise
57
58
@app.route("/batch_model", methods=["POST"])
def batch_respond():
    """Score a batch of hypotheses against their dialog contexts.

    Reads ``contexts`` and ``hypotheses`` from the JSON request body, scores
    them with ``get_probas`` and responds with ``[{"batch": [score, ...]}]``.
    If scoring raises, the exception is logged and reported to Sentry, and every
    hypothesis gets a score of 0 so the response still has one entry per
    hypothesis.
    """
    # perf_counter is monotonic: the measured duration cannot be skewed by
    # system clock adjustments the way a time.time() delta can.
    st_time = time.perf_counter()
    payload = request.json
    contexts = payload["contexts"]
    hypotheses = payload["hypotheses"]

    try:
        responses = get_probas(contexts, hypotheses).tolist()
    except Exception as e:
        # Deliberate best-effort fallback: keep answering with zero scores.
        responses = [0] * len(hypotheses)
        sentry_sdk.capture_exception(e)
        logger.exception(e)

    # Use the configured module logger (not the root logger) with lazy
    # %-formatting; a per-request timing line is informational, not a warning.
    logger.info("hypothesis_scorer exec time %s", time.perf_counter() - st_time)
    return jsonify([{"batch": responses}])
74
75
if __name__ == "__main__":
    # Executed directly: start Flask's built-in server on every interface,
    # port 3000, with debug mode off.
    app.run(host="0.0.0.0", port=3000, debug=False)
78