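"""Flask service wrapping a Hugging Face causal LM for the dream pipeline.

Serves /respond (dialog response hypotheses) and /generate_goals (meta-goal
extraction); the model and its generation configs are selected via environment
variables.
"""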
import json
import logging
import os
import re
import time
from copy import deepcopy

import sentry_sdk
import torch
from flask import Flask, request, jsonify
from sentry_sdk.integrations.flask import FlaskIntegration
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

from common.prompts import META_GOALS_PROMPT
from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE


sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
# These flags arrive as strings ("0"/"1"); `os.environ.get` with a default never
# returns None, so a single conversion suffices.
HALF_PRECISION = bool(int(os.environ.get("HALF_PRECISION", 0)))
USE_FLASH_ATTENTION_2 = bool(int(os.environ.get("USE_FLASH_ATTENTION_2", 0)))

logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
LANGUAGE = os.getenv("LANGUAGE", "EN")
HF_ACCESS_TOKEN = os.environ.get("HF_ACCESS_TOKEN", None)
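# Speaker labels used to serialize the dialog; index 0 is the bot's side, index 1 the user's.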
NAMING = {
    "EN": ["AI", "Human"],
    "RU": ["Assistant", "Human"],
}
ADDITIONAL_EOS_TOKENS = os.environ.get("ADDITIONAL_EOS_TOKENS", None)  # for RuXGLM: "<|endoftext|>,Human:"
if ADDITIONAL_EOS_TOKENS:
    ADDITIONAL_EOS_TOKENS = ADDITIONAL_EOS_TOKENS.split(",")

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")

def _load_config(path):
    # Load a generation config once; `with` ensures the file handle is closed.
    with open(path, "r") as f:
        return json.load(f)


DEFAULT_CONFIGS = {
    "EleutherAI/gpt-j-6B": _load_config("common/generative_configs/default_generative_config.json"),
    "OpenAssistant/pythia-12b-sft-v8-7k-steps": _load_config(
        "common/generative_configs/default_generative_config.json"
    ),
    "togethercomputer/GPT-JT-6B-v1": _load_config("common/generative_configs/default_generative_config.json"),
    "lmsys/vicuna-13b-v1.3": _load_config("common/generative_configs/default_generative_config.json"),
    "dim/xglm-4.5B_ru_v10_epoch_6_step_41141": _load_config("common/generative_configs/ruxglm_config.json"),
    "ai-forever/ruGPT-3.5-13B": _load_config("common/generative_configs/rugpt35_config.json"),
    "NousResearch/Yarn-Mistral-7b-128k": _load_config("common/generative_configs/transformers_mistral.json"),
}
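# NOTE: PRETRAINED_MODEL_NAME_OR_PATH must be one of the keys above; the
# warm-up call at startup looks up its default config and raises KeyError otherwise.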


def add_replacement_tokens(text, replacement):
    # Map raw substrings (e.g., newlines) to the model's replacement tokens before encoding.
    for pair in replacement:
        text = re.sub(pair[0], f"{pair[1]} ", text)
    return text


def remove_replacement_tokens(text, replacement):
    # Invert add_replacement_tokens and drop the extra space it introduces.
    for pair in replacement:
        text = re.sub(pair[1], pair[0], text)

    text = text.replace("\n ", "\n")
    return text
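# Illustrative round trip, assuming a replacement pair like ["\n", "<LINEBREAK>"]
# (the actual pairs come from the model's generation config):
#   add_replacement_tokens("a\nb", [["\n", "<LINEBREAK>"]])              -> "a<LINEBREAK> b"
#   remove_replacement_tokens("a<LINEBREAK> b", [["\n", "<LINEBREAK>"]]) -> "a\nb"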


def cut_predictions_by_additional_eos(text):
    # Truncate the generation at the first occurrence of any extra EOS marker.
    if ADDITIONAL_EOS_TOKENS:
        for token in ADDITIONAL_EOS_TOKENS:
            text = text.split(token)[0]
    return text


class StoppingCriteriaSub(StoppingCriteria):
    # Stops generation once any stop substring appears after the prompt.
    def __init__(self, stops, tokenizer, prompt, replacement):
        super().__init__()
        self.stops = stops if stops else []  # guard against ADDITIONAL_EOS_TOKENS being unset
        self.tokenizer = tokenizer
        # Normalize the prompt the same way the model input is normalized so
        # that its decoded length can be sliced off the generated sequence.
        self.prompt = add_replacement_tokens(prompt, replacement)
        self.prompt = tokenizer.decode(tokenizer.encode(self.prompt))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            generated_temp_ids = input_ids.tolist()[0]
            # use the stored tokenizer rather than the module-level global
            if stop in self.tokenizer.decode(generated_temp_ids)[len(self.prompt) :]:
                return True

        return False


def generate_responses(context, model, tokenizer, prompt, generation_params, continue_last_uttr=False):
    outputs = []
    dialog_context = ""
    if prompt:
        dialog_context += prompt + "\n"
    # label utterances so that the last one always belongs to the user side
    s = len(context) % 2
    context = [f"{NAMING[LANGUAGE][(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
    if continue_last_uttr:
        dialog_context += "\n".join(context)
    else:
        dialog_context += "\n".join(context) + f"\n{NAMING[LANGUAGE][0]}:"

    replacement = generation_params.pop("replacement", [])
    logger.info(f"replacement: {replacement}")
    logger.info(f"generation_params: {generation_params}")
    dialog_context = add_replacement_tokens(dialog_context, replacement)
    logger.info(f"context inside generate_responses seen as: {dialog_context}")
    bot_input_ids = tokenizer([dialog_context], return_tensors="pt").input_ids
    stopping_criteria = StoppingCriteriaList(
        [
            StoppingCriteriaSub(
                stops=ADDITIONAL_EOS_TOKENS,
                tokenizer=tokenizer,
                prompt=dialog_context,
                replacement=replacement,
            )
        ]
    )
    with torch.no_grad():
        if torch.cuda.is_available():
            bot_input_ids = bot_input_ids.to("cuda")
        chat_history_ids = model.generate(
            bot_input_ids,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria,
            **generation_params,
        )
        if torch.cuda.is_available():
            chat_history_ids = chat_history_ids.cpu()
    for result in chat_history_ids:
        # keep special tokens whenever replacement pairs must be mapped back
        skip_special_tokens = not replacement
        output = tokenizer.decode(result, skip_special_tokens=skip_special_tokens)
        # preprocess dialog context to correctly remove it from output
        dialog_context = re.sub(r" +", " ", dialog_context)
        dialog_context = dialog_context.replace("\n ", "\n")
        output = re.sub(r" +", " ", output)
        output = output.replace("\n ", "\n")

        result_cut = output.replace(dialog_context + " ", "")
        result_cut = cut_predictions_by_additional_eos(result_cut)
        result_cut = remove_replacement_tokens(result_cut, replacement)
        result_cut = [x.strip() for x in GENERATIVE_ROBOT_TEMPLATE.split(result_cut) if x.strip()][0]
        logger.info(f"hypothesis: {result_cut}")
        outputs.append(result_cut)

    return outputs
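# Illustrative single-turn call (mirrors the warm-up below):
#   generate_responses(["What is the goal of SpaceX?"], model, tokenizer,
#                      "You are a SpaceX Assistant.", deepcopy(default_config))
# serializes the dialog as
#   You are a SpaceX Assistant.
#   Human: What is the goal of SpaceX?
#   AI:
# and returns one decoded hypothesis per generated sequence.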


try:
    additional_kwargs = {}
    if HF_ACCESS_TOKEN:
        additional_kwargs["use_auth_token"] = HF_ACCESS_TOKEN

    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, **additional_kwargs)

    if HALF_PRECISION:
        additional_kwargs["torch_dtype"] = torch.float16
    if USE_FLASH_ATTENTION_2:
        additional_kwargs["use_flash_attention_2"] = True
    additional_kwargs["trust_remote_code"] = True

    model = AutoModelForCausalLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, **additional_kwargs)
    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("transformers_lm is set to run on cuda")

    # warm-up generation; this also validates the default config at startup
    example_response = generate_responses(
        ["What is the goal of SpaceX?"],
        model,
        tokenizer,
        "You are a SpaceX Assistant.",
        deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]),
    )
    logger.info(f"example response: {example_response}")
    logger.info("transformers_lm is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e


@app.route("/ping", methods=["POST"])
def ping():
    return "pong"
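# Illustrative healthcheck (port depends on deployment; the route accepts POST only):
#   curl -X POST http://0.0.0.0:<port>/ping   ->  pong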


@app.route("/respond", methods=["POST"])
def respond():
    st_time = time.time()
    contexts = request.json.get("dialog_contexts", [])
    prompts = request.json.get("prompts", [])
    # fill in empty prompts first so that default configs get the matching length
    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]) if el is None else el for el in configs]

    try:
        responses = []
        for context, prompt, config in zip(contexts, prompts, configs):
            curr_responses = []
            outputs = generate_responses(context, model, tokenizer, prompt, config)
            for response in outputs:
                if len(response) >= 2:
                    curr_responses += [response]
                else:
                    curr_responses += [""]
            responses += [curr_responses]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"transformers_lm output: {responses}")
    total_time = time.time() - st_time
    logger.info(f"transformers_lm exec time: {total_time:.3f}s")
    return jsonify(responses)
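# Illustrative request (port depends on deployment; payload shape follows the handler above):
#   curl -X POST http://0.0.0.0:<port>/respond -H "Content-Type: application/json" \
#        -d '{"dialog_contexts": [["What is the goal of SpaceX?"]],
#             "prompts": ["You are a SpaceX Assistant."]}'
# -> [["<one hypothesis per returned sequence>"]]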


@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts
    configs = request.json.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [deepcopy(DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH]) if el is None else el for el in configs]

    try:
        responses = []
        for prompt, config in zip(prompts, configs):
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, model, tokenizer, "", config)[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"transformers_lm generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)
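# Illustrative request (port depends on deployment; returns one goals string per prompt):
#   curl -X POST http://0.0.0.0:<port>/generate_goals -H "Content-Type: application/json" \
#        -d '{"prompts": ["You are a SpaceX Assistant."]}'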