dream

Форк
0
158 строк · 5.4 Кб
1
import logging
2
import os
3
import time
4

5
import sentry_sdk
6
import torch
7
from flask import Flask, request, jsonify
8
from sentry_sdk.integrations.flask import FlaskIntegration
9
from transformers import AutoModelForCausalLM, AutoTokenizer
10

11
from common.utils import get_intents
12

13

14
# Error reporting: SENTRY_DSN may be unset, in which case sentry_sdk is a no-op.
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

logging.getLogger("werkzeug").setLevel("INFO")
app = Flask(__name__)

# Required environment configuration.
# NOTE(review): MAX_PERSONA_SENTENCES is mandatory — int(None) raises TypeError
# at import time if the variable is unset; presumably deliberate fail-fast.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
MAX_PERSONA_SENTENCES = int(os.environ.get("MAX_PERSONA_SENTENCES"))

# Confidence assigned to replies for direct personal questions vs. everything else
# (see respond() below).
SUPER_CONFIDENCE = 1.0
DEFAULT_CONFIDENCE = 0.9

# Speaker / persona markup tokens; assumed to be in the tokenizer's added
# vocabulary (generate_response looks "</sp_2>" up via get_added_vocab()).
SPECIAL_TOKENS = {
    "<sp_1>": "<sp_1>",
    "</sp_1>": "</sp_1>",
    "<sp_2>": "<sp_2>",
    "</sp_2>": "</sp_2>",
    "<persona>": "<persona>",
    "</persona>": "</persona>",
}

# Load model and tokenizer once at startup; any failure is reported to sentry
# and re-raised so the service does not start in a broken state.
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("dialogpt_persona_based is set to run on cuda")

    logger.info("dialogpt_persona_based is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e
50

51

52
def generate_response(persona: dict = None, model=None, tokenizer=None, utterances_histories=None):
    """Generate the bot's next utterance conditioned on a short persona.

    Args:
        persona (dict): mapping with a "persona" key holding ranked persona
            sentences (most relevant first). Defaults to None.
        model (AutoModelForCausalLM): causal LM used for generation. Defaults to None.
        tokenizer (AutoTokenizer): tokenizer matching the model. Defaults to None.
        utterances_histories (List[List[str]]): batch of dialog histories;
            only the first history is used. Defaults to None.

    Returns:
        str: generated reply, or "" when the model produced no closing speaker tag.
    """
    added_vocab = tokenizer.get_added_vocab()

    # Wrap the top persona sentences in the persona special tokens.
    persona_sentences = " ".join(persona["persona"][:MAX_PERSONA_SENTENCES])
    persona_text = SPECIAL_TOKENS["<persona>"] + persona_sentences + SPECIAL_TOKENS["</persona>"]
    persona_ids = tokenizer.encode(persona_text, return_tensors="pt").to(device)

    # Only the most recent turn of the first dialog is kept; turns are tagged
    # <sp_1>/<sp_2> alternating backwards from the latest one.
    recent_turns = utterances_histories[0][-1:]
    tagged_turns = [
        f"<sp_{i % 2 + 1}>{turn}</sp_{i % 2 + 1}>"
        for i, turn in enumerate(reversed(recent_turns))
    ]
    tagged_turns.reverse()
    # Open the bot speaker tag so the model continues speaking as <sp_2>.
    dialog_text = "".join(tagged_turns) + "<sp_2>"
    history_ids = tokenizer.encode(dialog_text, return_tensors="pt").to(device)

    prompt_ids = torch.cat([persona_ids, history_ids], dim=-1)

    generated = model.generate(
        prompt_ids,
        max_length=150,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.95,
        top_k=50,
        top_p=0.95,
    ).to(device)

    generated_tokens = list(generated[0])
    close_tag_id = added_vocab["</sp_2>"]
    if close_tag_id not in generated_tokens:
        # The model never closed its speaker span — treat this as no answer.
        return ""

    # Truncate at the closing </sp_2> tag, then decode only the newly
    # generated continuation (everything past the prompt).
    cutoff = generated_tokens.index(close_tag_id) + 1
    reply_ids = generated[:, :cutoff]
    return tokenizer.decode(reply_ids[0][len(prompt_ids[0]) - 1 :], skip_special_tokens=True)
113

114

115
@app.route("/respond", methods=["POST"])
def respond():
    """Flask handler: generate a persona-conditioned reply per batch item.

    Expects JSON with "last_annotated_utterances" (a batch of annotated
    utterance dicts) and "utterances_histories" (dialog histories).
    Returns a JSON list of ([response], [confidence]) pairs, one per item.
    """
    start_time = time.time()
    responses = []
    confidences = []

    last_annotated_utterances_batch = request.json["last_annotated_utterances"]
    utterances_histories = request.json["utterances_histories"]
    try:
        for utterance in last_annotated_utterances_batch:
            persona = utterance.get("annotations", {}).get("relative_persona_extractor", [])

            response = ""
            try:
                response = generate_response(
                    model=model,
                    tokenizer=tokenizer,
                    persona=persona,
                    utterances_histories=utterances_histories,
                )
            except Exception as e:
                # Best-effort per item: a failed generation yields an empty
                # reply instead of failing the whole batch.
                logger.exception(e)
                response = ""

            # Direct personal questions get maximal confidence so this reply
            # wins response selection.
            if "open_question_personal" in get_intents(utterance):
                logger.info("open_question_personal")
                confidence = SUPER_CONFIDENCE
            else:
                logger.info("NOT open_question_personal")
                confidence = DEFAULT_CONFIDENCE

            responses.append([response])
            confidences.append([confidence])

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # FIX: keep the fallback payload shape consistent with the success
        # path (list-of-lists); previously bare strings/floats were returned
        # here, changing the response schema on error.
        responses = [[""] for _ in last_annotated_utterances_batch]
        confidences = [[0.0] for _ in last_annotated_utterances_batch]

    total_time = time.time() - start_time
    logger.info(f"dialogpt_persona_based exec time: {total_time:.3f}s")

    return jsonify(list(zip(responses, confidences)))
159

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.