dream
159 строк · 6.4 Кб
import json
import logging
import os
import time

import sentry_sdk
from flask import Flask, request, jsonify
from openai import OpenAI
from sentry_sdk.integrations.flask import FlaskIntegration

from common.prompts import META_GOALS_PROMPT
from common.universal_templates import GENERATIVE_ROBOT_TEMPLATE
13
sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()])


logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Model identifier, passed verbatim to the OpenAI client as ``model=``.
PRETRAINED_MODEL_NAME_OR_PATH = os.environ.get("PRETRAINED_MODEL_NAME_OR_PATH")
logger.info(f"PRETRAINED_MODEL_NAME_OR_PATH = {PRETRAINED_MODEL_NAME_OR_PATH}")
# Speaker labels for plain-completion prompts and the matching chat roles.
# Index 0 is the bot side ("AI"/"assistant"), index 1 the human side.
NAMING = ["AI", "Human"]
CHATGPT_ROLES = ["assistant", "user"]

app = Flask(__name__)
logging.getLogger("werkzeug").setLevel("WARNING")

_CHATGPT_CONFIG_PATH = "common/generative_configs/openai-chatgpt.json"


def _load_generative_config(path):
    """Return the parsed JSON generation config stored at ``path``.

    The file is read inside a ``with`` block so the handle is closed
    deterministically instead of being left for the garbage collector.
    """
    with open(path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)


# Default generation parameters per supported model name. Each entry is loaded
# separately, so the dicts are independent objects even when they share a file.
DEFAULT_CONFIGS = {
    "text-davinci-003": _load_generative_config("common/generative_configs/openai-text-davinci-003.json"),
    "gpt-3.5-turbo": _load_generative_config(_CHATGPT_CONFIG_PATH),
    "gpt-3.5-turbo-16k": _load_generative_config(_CHATGPT_CONFIG_PATH),
    "gpt-4": _load_generative_config(_CHATGPT_CONFIG_PATH),
    "gpt-4-32k": _load_generative_config(_CHATGPT_CONFIG_PATH),
    "gpt-4-1106-preview": _load_generative_config(_CHATGPT_CONFIG_PATH),
}
# Models served through the chat-completions endpoint; any other model name
# goes through the plain completions endpoint in generate_responses.
CHAT_COMPLETION_MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k", "gpt-4-1106-preview"]
37
def _build_chat_messages(context, prompt):
    """Build the chat-completion ``messages`` list: a system prompt followed by the dialog turns.

    Roles alternate over CHATGPT_ROLES so that the last utterance of ``context``
    always gets role "user" and the one before it "assistant".
    """
    offset = len(context) % 2
    messages = [{"role": "system", "content": prompt}]
    messages += [
        {"role": CHATGPT_ROLES[(offset + uttr_id) % 2], "content": uttr}
        for uttr_id, uttr in enumerate(context)
    ]
    return messages


def _build_completion_prompt(context, prompt, continue_last_uttr):
    """Render prompt + dialog as "AI: ..." / "Human: ..." lines for the plain completion endpoint.

    The last utterance is labelled "Human". Unless ``continue_last_uttr`` is set,
    a trailing "AI:" line is appended so the text ends at the start of a new bot turn;
    otherwise the text ends with the last utterance itself.
    """
    dialog_context = prompt + "\n" if prompt else ""
    offset = len(context) % 2
    labelled = [f"{NAMING[(offset + uttr_id) % 2]}: {uttr}" for uttr_id, uttr in enumerate(context)]
    dialog_context += "\n".join(labelled)
    if not continue_last_uttr:
        dialog_context += f"\n{NAMING[0]}:"
    return dialog_context


def _extract_texts(response):
    """Return the stripped generated text of every choice in an OpenAI response object."""
    texts = []
    for choice in response.model_dump()["choices"]:
        if "message" in choice:
            # Chat choice. The API allows message content to be null, so guard before strip().
            texts.append((choice["message"]["content"] or "").strip())
        else:
            # Plain completion choice.
            texts.append((choice.get("text") or "").strip())
    return texts


def generate_responses(context, openai_api_key, openai_org, prompt, generation_params, continue_last_uttr=False):
    """Generate bot replies for one dialog through the OpenAI API.

    Args:
        context: list of utterance strings, oldest first; the last one is sent as the
            human/user turn.
        openai_api_key: API key for the OpenAI client; must be non-empty.
        openai_org: optional organization id; any falsy value means "no organization".
        prompt: instruction text (system message for chat models, prefix line otherwise).
        generation_params: dict of extra keyword arguments forwarded to ``create()``.
        continue_last_uttr: completion models only; when True no fresh "AI:" turn is appended.

    Returns:
        List of stripped reply strings, one per returned choice. For non-chat models,
        GENERATIVE_ROBOT_TEMPLATE matches are additionally replaced by newlines.

    Raises:
        ValueError: if ``openai_api_key`` is empty.
    """
    if not openai_api_key:
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
        logger.error("Error: OpenAI API key is not specified in env")
        raise ValueError("OpenAI API key is not specified")
    client = OpenAI(api_key=openai_api_key, organization=openai_org or None)

    is_chat_model = PRETRAINED_MODEL_NAME_OR_PATH in CHAT_COMPLETION_MODELS
    if is_chat_model:
        logger.info("Use special chat completion endpoint")
        messages = _build_chat_messages(context, prompt)
        logger.info(f"context inside generate_responses seen as: {messages}")
        response = client.chat.completions.create(
            model=PRETRAINED_MODEL_NAME_OR_PATH, messages=messages, **generation_params
        )
    else:
        dialog_context = _build_completion_prompt(context, prompt, continue_last_uttr)
        logger.info(f"context inside generate_responses seen as: {dialog_context}")
        response = client.completions.create(
            model=PRETRAINED_MODEL_NAME_OR_PATH, prompt=dialog_context, **generation_params
        )

    outputs = _extract_texts(response)
    if not is_chat_model:
        # post-processing of the responses by all models except the chat-completion ones
        outputs = [GENERATIVE_ROBOT_TEMPLATE.sub("\n", resp).strip() for resp in outputs]
    return outputs
87
@app.route("/ping", methods=["POST"])
def ping():
    """Health-check endpoint: answers every POST with the plain string "pong"."""
    reply = "pong"
    return reply
92
@app.route("/respond", methods=["POST"])
def respond():
    """Generate replies for a batch of dialogs.

    Request JSON fields (lists aligned by index):
        dialog_contexts: list of dialogs, each a list of utterance strings.
        prompts: per-dialog prompt; defaults to "" for every dialog when omitted/empty.
        configs: per-dialog generation params; None (or a missing list) falls back to
            DEFAULT_CONFIGS for the served model.
        openai_api_keys: per-dialog OpenAI API keys.
        openai_api_organizations: per-dialog organization ids (optional).

    Returns:
        JSON list with one list of candidate replies per processed dialog (``zip``
        stops at the shortest input list). Replies shorter than two characters are
        replaced by "". If any generation raises, every dialog gets [""].
    """
    st_time = time.time()
    payload = request.json
    contexts = payload.get("dialog_contexts", [])
    prompts = payload.get("prompts", [])
    # Pad prompts BEFORE deriving default configs: previously configs were sized from the
    # still-empty prompts list, so zip() silently produced no responses at all.
    if len(contexts) > 0 and len(prompts) == 0:
        prompts = [""] * len(contexts)
    configs = payload.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]
    openai_api_keys = payload.get("openai_api_keys", [])
    openai_orgs = payload.get("openai_api_organizations", None)
    openai_orgs = [None] * len(contexts) if openai_orgs is None else openai_orgs

    try:
        responses = []
        for context, openai_api_key, openai_org, prompt, config in zip(
            contexts, openai_api_keys, openai_orgs, prompts, configs
        ):
            outputs = generate_responses(context, openai_api_key, openai_org, prompt, config)
            # Treat near-empty generations (0 or 1 characters) as no reply.
            responses.append([response if len(response) >= 2 else "" for response in outputs])

    except Exception as exc:
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        # Fresh inner list per dialog, rather than one list object aliased N times.
        responses = [[""] for _ in contexts]

    logger.info(f"openai-api result: {responses}")
    total_time = time.time() - st_time
    logger.info(f"openai-api exec time: {total_time:.3f}s")
    return jsonify(responses)
131
@app.route("/generate_goals", methods=["POST"])
def generate_goals():
    """Ask the model to describe the goals implied by each given prompt.

    Request JSON fields (lists aligned by index): prompts, configs, openai_api_keys,
    openai_api_organizations. A None/missing configs entry falls back to
    DEFAULT_CONFIGS for the served model.

    Each prompt is embedded into META_GOALS_PROMPT and sent as a two-turn dialog;
    only the first generated choice is kept.

    Returns:
        JSON list with one goals string per processed prompt, or "" for every prompt
        if any generation raises.
    """
    st_time = time.time()

    payload = request.json
    prompts = payload.get("prompts", None)
    prompts = [] if prompts is None else prompts
    configs = payload.get("configs", None)
    configs = [None] * len(prompts) if configs is None else configs
    configs = [DEFAULT_CONFIGS[PRETRAINED_MODEL_NAME_OR_PATH] if el is None else el for el in configs]
    openai_api_keys = payload.get("openai_api_keys", [])
    openai_orgs = payload.get("openai_api_organizations", None)
    openai_orgs = [None] * len(prompts) if openai_orgs is None else openai_orgs
    try:
        responses = []
        for openai_api_key, openai_org, prompt, config in zip(openai_api_keys, openai_orgs, prompts, configs):
            # With two turns, "hi" gets the bot role and the meta prompt the user/Human role.
            context = ["hi", META_GOALS_PROMPT + f"\nPrompt: '''{prompt}'''\nResult:"]
            goals_for_prompt = generate_responses(context, openai_api_key, openai_org, "", config)[0]
            logger.info(f"Generated goals: `{goals_for_prompt}` for prompt: `{prompt}`")
            responses.append(goals_for_prompt)

    except Exception as exc:
        # logger.exception keeps the traceback (logger.info dropped it), matching /respond.
        logger.exception(exc)
        sentry_sdk.capture_exception(exc)
        responses = [""] * len(prompts)

    total_time = time.time() - st_time
    logger.info(f"openai-api generate_goals exec time: {total_time:.3f}s")
    return jsonify(responses)