dream
161 строка · 5.5 Кб
1import json2import logging3import os4import time5
6import sentry_sdk7import torch8from flask import Flask, request, jsonify9from peft import PeftModel, PeftConfig10from sentry_sdk.integrations.flask import FlaskIntegration11from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig12
13from common.prompts import META_GOALS_PROMPT14from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE15
16
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
LANGUAGE = os.getenv("LANGUAGE", "EN")
# Speaker labels ([bot, human]) used to serialize a dialog into a plain-text prompt.
NAMING = {
    "EN": ["AI", "Human"],
    "RU": ["Чат-бот", "Человек"],
}

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")

# All service variants share the same default generative config. Read the file
# ONCE inside a `with` block (the original opened the same file three times via
# bare open() calls and never closed them), then parse it per key so each entry
# stays an independent dict, exactly as three separate json.load calls produced.
with open("common/generative_configs/default_generative_config.json", "r") as _cfg_file:
    _DEFAULT_CONFIG_TEXT = _cfg_file.read()
DEFAULT_CONFIGS = {
    name: json.loads(_DEFAULT_CONFIG_TEXT)
    for name in ("transformers-lm-bloomz7b", "transformers-lm-gptj", "transformers-lm-oasst12b")
}
38
39
def generate_responses(context, model, tokenizer, prompt, continue_last_uttr=False, generation_config=None):
    """Generate response hypotheses for a dialog context with the loaded LM.

    Args:
        context: list of utterance strings, oldest first. Speaker labels are
            assigned by alternation such that the LAST utterance is always
            labeled as the human turn (``NAMING[LANGUAGE][1]``).
        model: causal LM (possibly PEFT-wrapped) exposing ``.generate``/``.device``.
        tokenizer: tokenizer matching ``model``.
        prompt: optional system prompt prepended to the dialog text.
        continue_last_uttr: if True, the model continues the last utterance
            instead of opening a fresh bot turn.
        generation_config: optional ``GenerationConfig``; when None, falls back
            to the module-level ``default_config`` loaded at startup (keeps the
            original behavior while letting callers override it).

    Returns:
        list of generated hypothesis strings, one per returned sequence.
    """
    if generation_config is None:
        generation_config = default_config
    outputs = []
    dialog_context = ""
    if prompt:
        dialog_context += prompt + "\n"
    # Offset `s` aligns the alternating labels so that the last utterance gets
    # index 1 (human) regardless of whether the context length is odd or even.
    s = len(context) % 2
    context = [f"{NAMING[LANGUAGE][(s + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
    if continue_last_uttr:
        dialog_context += "\n".join(context)
    else:
        # Open a new bot turn for the model to complete.
        dialog_context += "\n".join(context) + f"\n{NAMING[LANGUAGE][0]}:"

    logger.info(f"context inside generate_responses seen as: {dialog_context}")
    data = tokenizer([dialog_context], return_tensors="pt")
    # Keep only the tensors model.generate expects, moved to the model's device.
    data = {k: v.to(model.device) for k, v in data.items() if k in ("input_ids", "attention_mask")}

    with torch.no_grad():
        chat_history_ids = model.generate(
            **data,
            generation_config=generation_config,
        )
    if torch.cuda.is_available():
        chat_history_ids = chat_history_ids.cpu()
    for result in chat_history_ids:
        output = tokenizer.decode(result, skip_special_tokens=True)
        # Strip the echoed prompt so only the continuation remains.
        result_cut = output.replace(dialog_context + " ", "")
        # Keep only the first robot turn; GENERATIVE_ROBOT_TEMPLATE splits off
        # any follow-up "Human:"/"AI:" turns the model may have hallucinated.
        # Guard against an empty split result (the original indexed [0]
        # unconditionally and raised IndexError on a fully-stripped output).
        parts = [x.strip() for x in GENERATIVE_ROBOT_TEMPLATE.split(result_cut) if x.strip()]
        result_cut = parts[0] if parts else ""
        logger.info(f"hypothesis: {result_cut}")
        outputs.append(result_cut)

    return outputs
72
# Load tokenizer, generation config, and the PEFT-adapted base model at import
# time; any failure is reported to Sentry and re-raised so the container dies
# loudly instead of serving a broken model.
try:
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    default_config = GenerationConfig.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

    config = PeftConfig.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float16,
        # load_in_8bit=True,
        # device_map="auto"
    )
    model = PeftModel.from_pretrained(model, PRETRAINED_MODEL_NAME_OR_PATH)
    model.eval()

    if torch.cuda.is_available():
        model.to("cuda")
        logger.info("transformers_peft_lm is set to run on cuda")

    # Startup smoke test so misconfiguration fails fast.
    # BUGFIX: the original passed `default_config` as the 5th positional
    # argument of generate_responses, where it bound to `continue_last_uttr`
    # (truthy -> silently changed prompt construction). generate_responses
    # already reads the module-level `default_config`, so no extra argument
    # is needed.
    example_response = generate_responses(
        ["What is the goal of SpaceX?"], model, tokenizer, "You are a SpaceX Assistant."
    )
    logger.info(f"example response: {example_response}")
    logger.info("transformers_peft_lm is ready")
except Exception as e:
    sentry_sdk.capture_exception(e)
    logger.exception(e)
    raise e
102
@app.route("/ping", methods=["POST"])
def ping():
    """Liveness probe: always answers with the literal string "pong"."""
    reply = "pong"
    return reply
107
@app.route("/respond", methods=["POST"])
def respond():
    """Produce hypotheses for each (dialog context, prompt) pair in the request.

    Expects JSON with "dialog_contexts" (list of utterance lists) and,
    optionally, "prompts" (one per context; empty prompts are used when
    the list is missing). Returns a JSON list of hypothesis lists.
    """
    st_time = time.time()
    payload = request.json
    contexts = payload.get("dialog_contexts", [])
    prompts = payload.get("prompts", [])
    # No prompts supplied: pair every context with an empty prompt.
    if contexts and not prompts:
        prompts = [""] * len(contexts)

    try:
        responses = []
        for context, prompt in zip(contexts, prompts):
            hypotheses = generate_responses(context, model, tokenizer, prompt)
            # Degenerate hypotheses (fewer than 2 characters) are blanked out.
            responses.append([hyp if len(hyp) >= 2 else "" for hyp in hypotheses])
    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [[""]] * len(contexts)

    logger.info(f"transformers_peft_lm output: {responses}")
    total_time = time.time() - st_time
    logger.info(f"transformers_peft_lm exec time: {total_time:.3f}s")
    return jsonify(responses)
138
@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    """For each prompt, ask the LM to formulate the goals that prompt implies.

    Request JSON: {"prompts": [str, ...]} (optional; treated as [] when absent).
    Response JSON: list of generated goal strings, "" per prompt on failure.
    """
    st_time = time.time()

    prompts = request.json.get("prompts", None)
    prompts = [] if prompts is None else prompts

    try:
        responses = []
        for prompt in prompts:
            # "hi" pads the context so that, under generate_responses' speaker
            # alternation, the meta-goals instruction lands on the human turn.
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, model, tokenizer, "")[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses += [goals_for_prompt]

    except Exception as exc:
        # BUGFIX: use logger.exception (keeps the traceback) for parity with
        # /respond; the original used logger.info and lost the stack trace.
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    # BUGFIX: the log line previously said "openai-api" — a copy-paste error
    # in this transformers_peft_lm service.
    logger.info(f"transformers_peft_lm generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)