# openai-cookbook — deprecated fine-tuned Q&A answers example
1"""
2TODO: This example is deprecated.
3Note: To answer questions based on text documents, we recommend the procedure in
4[Question Answering using Embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).
5Some of the code below may rely on [deprecated API endpoints](https://github.com/openai/openai-cookbook/tree/main/transition_guides_for_deprecated_API_endpoints).
6"""
7
8import argparse9
10from openai import OpenAI11import os12
13client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))14
15
def create_context(
    question, search_file_id, max_len=1800, search_model="ada", max_rerank=10
):
    """
    Create a context for a question by finding the most similar context from the
    search file.

    :param question: The question
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of reranking
    :return: The context
    """
    # TODO: openai.Engine(search_model) is deprecated
    results = client.Engine(search_model).search(
        search_model=search_model,
        query=question,
        max_rerank=max_rerank,
        file=search_file_id,
        return_metadata=True,
    )
    # Accumulate result texts until the token budget (per the search metadata,
    # plus 4 tokens for the separator) would be exceeded.
    chunks = []
    used_tokens = 0
    for hit in results["data"]:
        used_tokens += int(hit["metadata"]) + 4
        if used_tokens > max_len:
            break
        chunks.append(hit["text"])
    return "\n\n###\n\n".join(chunks)
def answer_question(
    search_file_id="<SEARCH_FILE_ID>",
    fine_tuned_qa_model="<FT_QA_MODEL_ID>",
    question="Which country won the European Football championship in 2021?",
    max_len=1800,
    search_model="ada",
    max_rerank=10,
    debug=False,
    stop_sequence=None,
    max_tokens=100,
):
    """
    Answer a question based on the most similar context from the search file,
    using your fine-tuned model.

    :param question: The question
    :param fine_tuned_qa_model: The fine tuned QA model
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of reranking
    :param debug: Whether to output debug information
    :param stop_sequence: The stop sequence for Q&A model (defaults to ["\\n", "."])
    :param max_tokens: The maximum number of tokens to return
    :return: The answer, or "" if the API call fails
    """
    # Avoid a shared mutable default argument; the effective default is unchanged.
    if stop_sequence is None:
        stop_sequence = ["\n", "."]
    context = create_context(
        question,
        search_file_id,
        max_len=max_len,
        search_model=search_model,
        max_rerank=max_rerank,
    )
    if debug:
        print("Context:\n" + context)
        print("\n\n")
    try:
        # The legacy Completions endpoint takes a plain-text `prompt` (the Chat
        # Completions endpoint used previously takes `messages` instead, so the
        # old call could never succeed).  The current SDK also accepts only
        # `model` — the deprecated `engine` parameter (and the ":ft" model-id
        # sniffing that selected between them) is gone.
        response = client.completions.create(
            model=fine_tuned_qa_model,
            prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
            temperature=0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop_sequence,
        )
        # New-style SDK responses are pydantic objects, not dicts.
        return response.choices[0].text
    except Exception as e:
        # Best-effort behavior preserved from the original example: report the
        # error and return an empty answer rather than propagating.
        print(e)
        return ""
if __name__ == "__main__":
    # CLI driver: collect the question/search parameters, run the Q&A pipeline,
    # and print the answer.
    arg_parser = argparse.ArgumentParser(
        description="Rudimentary functionality of the answers endpoint with a fine-tuned Q&A model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required identifiers and the question itself.
    arg_parser.add_argument("--search_file_id", help="Search file id", required=True, type=str)
    arg_parser.add_argument("--fine_tuned_qa_model", help="Fine-tuned QA model id", required=True, type=str)
    arg_parser.add_argument("--question", help="Question to answer", required=True, type=str)
    # Tunable search/generation knobs with sensible defaults.
    arg_parser.add_argument(
        "--max_len",
        help="Maximum length of the returned context (in tokens)",
        default=1800,
        type=int,
    )
    arg_parser.add_argument("--search_model", help="Search model to use", default="ada", type=str)
    arg_parser.add_argument(
        "--max_rerank",
        help="Maximum number of reranking for the search",
        default=10,
        type=int,
    )
    arg_parser.add_argument("--debug", help="Print debug information (context used)", action="store_true")
    arg_parser.add_argument(
        "--stop_sequence",
        help="Stop sequences for the Q&A model",
        default=["\n", "."],
        nargs="+",
        type=str,
    )
    arg_parser.add_argument(
        "--max_tokens",
        help="Maximum number of tokens to return",
        default=100,
        type=int,
    )
    args = arg_parser.parse_args()

    answer = answer_question(
        search_file_id=args.search_file_id,
        fine_tuned_qa_model=args.fine_tuned_qa_model,
        question=args.question,
        max_len=args.max_len,
        search_model=args.search_model,
        max_rerank=args.max_rerank,
        debug=args.debug,
        stop_sequence=args.stop_sequence,
        max_tokens=args.max_tokens,
    )
    print(f"Answer:{answer}")