import json
import logging
import os
import re
import time
from copy import deepcopy

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

from common.prompts import META_GOALS_PROMPT
from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE


sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
# optional "0"/"1" flags controlling fp16 weights and FlashAttention-2
HALF_PRECISION = bool(int(os.environ.get("HALF_PRECISION", 0)))
USE_FLASH_ATTENTION_2 = bool(int(os.environ.get("USE_FLASH_ATTENTION_2", 0)))

logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
LANGUAGE = os.getenv("LANGUAGE", "EN")
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN", None)
NAMING = {
    "EN": ["AI", "Human"],
    "RU": ["Assistant", "Human"],
}
ADDITIONAL_EOS_TOKENS = os.environ.get("ADDITIONAL_EOS_TOKENS", None)  # for RuXGLM: "<|endoftext|>,Human:"
if ADDITIONAL_EOS_TOKENS:
    ADDITIONAL_EOS_TOKENS = ADDITIONAL_EOS_TOKENS.split(",")
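# Example: ADDITIONAL_EOS_TOKENS="<|endoftext|>,Human:" is parsed into ["<|endoftext|>", "Human:"]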

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")


def _load_config(path):
    # read a generative config once at import time, closing the file handle afterwards
    with open(path, "r") as f:
        return json.load(f)


DEFAULT_CONFIGS = {
    "EleutherAI/gpt-j-6B": _load_config("common/generative_configs/default_generative_config.json"),
    "OpenAssistant/pythia-12b-sft-v8-7k-steps": _load_config("common/generative_configs/default_generative_config.json"),
    "togethercomputer/GPT-JT-6B-v1": _load_config("common/generative_configs/default_generative_config.json"),
    "lmsys/vicuna-13b-v1.3": _load_config("common/generative_configs/default_generative_config.json"),
    "dim/xglm-4.5B_ru_v10_epoch_6_step_41141": _load_config("common/generative_configs/ruxglm_config.json"),
    "ai-forever/ruGPT-3.5-13B": _load_config("common/generative_configs/rugpt35_config.json"),
    "NousResearch/Yarn-Mistral-7b-128k": _load_config("common/generative_configs/transformers_mistral.json"),
}
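
# Each config is a dict of kwargs forwarded to model.generate(), plus an optional
# "replacement" list of [pattern, token] pairs (hypothetical example values):
#   {"max_new_tokens": 64, "do_sample": True, "replacement": [["Human:", "<human>"]]}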


def add_replacement_tokens(text, replacement):
    # replace every occurrence of pattern pair[0] with its substitute token pair[1] plus a space
    for pair in replacement:
        text = re.sub(pair[0], f"{pair[1]} ", text)
    return text


def remove_replacement_tokens(text, replacement):
    # inverse of add_replacement_tokens: map each substitute token pair[1] back to pair[0]
    for pair in replacement:
        text = re.sub(pair[1], pair[0], text)
    # drop the extra space that add_replacement_tokens introduced after line breaks
    text = text.replace("\n ", "\n")
    return text
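
# Illustrative round trip with a hypothetical pair ["Human:", "<human>"]:
#   add_replacement_tokens("Human: hi", [["Human:", "<human>"]]) == "<human>  hi"
#   (the doubled space gets collapsed later in generate_responses)
#   remove_replacement_tokens("<human> hi", [["Human:", "<human>"]]) == "Human: hi"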


def cut_predictions_by_additional_eos(text):
    # truncate the hypothesis at the first occurrence of any additional EOS token
    if ADDITIONAL_EOS_TOKENS:
        for token in ADDITIONAL_EOS_TOKENS:
            text = text.split(token)[0]
    return text
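
# Illustrative cut, assuming ADDITIONAL_EOS_TOKENS == ["<|endoftext|>", "Human:"]:
#   cut_predictions_by_additional_eos("Sure!Human: and you?") == "Sure!"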


class StoppingCriteriaSub(StoppingCriteria):
    # stops generation as soon as any of the stop substrings appears after the prompt
    def __init__(self, stops, tokenizer, prompt, replacement):
        super().__init__()
        self.stops = stops if stops else []
        self.tokenizer = tokenizer
        self.prompt = add_replacement_tokens(prompt, replacement)
        # re-encode/decode the prompt so it matches the tokenizer's own rendering
        self.prompt = tokenizer.decode(tokenizer.encode(self.prompt))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        generated_temp_ids = input_ids.tolist()[0]
        # check only the freshly generated tail, not the prompt itself
        generated_tail = self.tokenizer.decode(generated_temp_ids)[len(self.prompt) :]
        for stop in self.stops:
            if stop in generated_tail:
                return True

        return False


def generate_responses(context, model, tokenizer, prompt, generation_params, continue_last_uttr=False):
    outputs = []
    dialog_context = ""
    if prompt:
        dialog_context += prompt + "\n"
    # offset the speaker labels so that the last context utterance is always the human's
    s = len(context) % 2
    context = [f"{NAMING[LANGUAGE][(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
    if continue_last_uttr:
        dialog_context += "\n".join(context)
    else:
        dialog_context += "\n".join(context) + f"\n{NAMING[LANGUAGE][0]}:"

    replacement = generation_params.pop("replacement", [])
    logger.info(f"replacement: {replacement}")
    logger.info(f"generation_params: {generation_params}")
    dialog_context = add_replacement_tokens(dialog_context, replacement)
    logger.info(f"context inside generate_responses seen as: {dialog_context}")
    bot_input_ids = tokenizer([dialog_context], return_tensors="pt").input_ids
    stopping_criteria = StoppingCriteriaList(
        [
            StoppingCriteriaSub(
                stops=ADDITIONAL_EOS_TOKENS,
                tokenizer=tokenizer,
                prompt=dialog_context,
                replacement=replacement,
            )
        ]
    )
    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
            **generation_params,
        )
    if torch.cuda.is_available():
        chat_history_ids = chat_history_ids.cpu()
    for result in chat_history_ids:
        # keep special tokens when replacement pairs are in play so they can be mapped back
        skip_special_tokens = not replacement
        output = tokenizer.decode(result, skip_special_tokens=skip_special_tokens)
        # normalize whitespace in the dialog context to correctly remove it from the output
        dialog_context = re.sub(r"  +", " ", dialog_context)
        dialog_context = dialog_context.replace("\n ", "\n")
        output = re.sub(r"  +", " ", output)
        output = output.replace("\n ", "\n")

        result_cut = output.replace(dialog_context + " ", "")
        result_cut = cut_predictions_by_additional_eos(result_cut)
        result_cut = remove_replacement_tokens(result_cut, replacement)
        result_cut = [x.strip() for x in GENERATIVE_ROBOT_TEMPLATE.split(result_cut) if x.strip()][0]
        logger.info(f"hypothesis: {result_cut}")
        outputs.append(result_cut)

    return outputs
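
# Illustrative layout with LANGUAGE="EN": generate_responses(["hi", "hi there", "how are you?"], ...)
# presents the model with
#   Human: hi
#   AI: hi there
#   Human: how are you?
#   AI:
# and everything generated after the final "AI:" becomes the hypothesis.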


try:
    additional_kwargs = {}
    if HF_ACCESS_TOKEN:
        additional_kwargs["use_auth_token"] = HF_ACCESS_TOKEN

    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, **additional_kwargs)

    if HALF_PRECISION:
        additional_kwargs["torch_dtype"] = torch.float16
    if USE_FLASH_ATTENTION_2:
        additional_kwargs["use_flash_attention_2"] = True
        additional_kwargs["trust_remote_code"] = True

    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, **additional_kwargs)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("transformers_lm is set to run on cuda")

    # warm-up request: fail fast at startup if the model or its default config is broken
    example_response = generate_responses(
        ["What is the goal of SpaceX?"],
        model,
        tokenizer,
        "You are a SpaceX Assistant.",
        deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]),
    )
    logger.info(f"example response: {example_response}")
    logger.info("transformers_lm is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e


@app.route("/ping", methods=["POST"])
def ping():
    return "pong"


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("dialog_contexts", [])
    prompts = request.json.get("prompts", [])
    # backfill empty prompts first so that the configs list below gets the right length
    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]) if el is None else el for el in configs]

    try:
        responses = []
        for context, prompt, config in zip(contexts, prompts, configs):
            curr_responses = []
            outputs = generate_responses(context, model, tokenizer, prompt, config)
            for response in outputs:
                # replace degenerate hypotheses (shorter than two characters) with ""
                if len(response) >= 2:
                    curr_responses += [response]
                else:
                    curr_responses += [""]
            responses += [curr_responses]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"transformers_lm output: {responses}")
    total_time = time.time() - st_time
    logger.info(f"transformers_lm exec time: {total_time:.3f}s")
    return jsonify(responses)
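
# Illustrative /respond payload (field names taken from the handler above):
#   {"dialog_contexts": [["hi", "hi there! how can I help?"]],
#    "prompts": ["Respond like a friendly assistant."],
#    "configs": null}
# The reply is a JSON list with one list of hypotheses per dialog context.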


@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]) if el is None else el for el in configs]

    try:
        responses = []
        for prompt, config in zip(prompts, configs):
            # wrap the skill prompt into the meta-goals instruction; keep the first hypothesis
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, model, tokenizer, "", config)[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"transformers_lm generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)
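
# Illustrative /generate_goals payload (field names taken from the handler above):
#   {"prompts": ["You are a SpaceX Assistant."], "configs": null}
# The reply is a JSON list with one goals string per prompt.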