LLM-FineTuning-Large-Language-Models

from transformers import AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import gradio as gr
import transformers
import torch

# Run the entire app with `python run_mixtral.py`

9""" The messages list should be of the following format:
10
11messages =
12
13[
14{"role": "user", "content": "User's first message"},
15{"role": "assistant", "content": "Assistant's first response"},
16{"role": "user", "content": "User's second message"},
17{"role": "assistant", "content": "Assistant's second response"},
18{"role": "user", "content": "User's third message"}
19]
20
21"""
22""" The `format_chat_history` function below is designed to format the dialogue history into a prompt that can be fed into the Mixtral model. This will help understand the context of the conversation and generate appropriate responses by the Model.
23The function takes a history of dialogues as input, which is a list of lists where each sublist represents a pair of user and assistant messages.
24"""
25
def format_chat_history(history) -> str:
    # Flatten the [user, assistant] pairs into the flat `messages` format
    # shown above. The `if dialog[j]` conditional ensures that messages that
    # are None (like the latest assistant response in an ongoing
    # conversation) are not included.
    messages = [{"role": ("user" if j == 0 else "assistant"), "content": dialog[j]}
                for dialog in history for j in (0, 1) if dialog[j]]
    # `pipeline` is the module-level name bound in the `__main__` block below.
    return pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
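
# Example with hypothetical values: a Gradio history of
#   [["Hi", "Hello!"], ["How are you?", None]]
# becomes
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "How are you?"}]
# before being rendered through the chat template.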

def model_loading_pipeline():
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Stream generated tokens as they are produced; wait at most 5 s for
    # each next token before giving up.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={
            "torch_dtype": torch.float16,
            # Load the weights in 4-bit so Mixtral-8x7B fits on a single GPU.
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16),
        },
        streamer=streamer,
    )
    return pipeline, streamer
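
# If you want finer control over the quantization, BitsAndBytesConfig exposes
# a few more knobs; a sketch of a more explicit setup (optional, not used above):
#
#   BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",            # NormalFloat4 weight format
#       bnb_4bit_use_double_quant=True,       # also quantize the quant constants
#       bnb_4bit_compute_dtype=torch.float16,
#   )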

def launch_gradio_app(pipeline, streamer):
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Append the new user turn with a None placeholder for the
            # assistant's reply, and clear the textbox.
            return "", history + [[user_message, None]]

        def bot(history):
            prompt = format_chat_history(history)

            history[-1][1] = ""
            kwargs = dict(text_inputs=prompt, max_new_tokens=2048, do_sample=True,
                          temperature=0.7, top_k=50, top_p=0.95)
            thread = Thread(target=pipeline, kwargs=kwargs)
            thread.start()

            for token in streamer:
                history[-1][1] += token
                yield history
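
        # Design note: `pipeline(...)` blocks until generation finishes, so it
        # runs on a worker thread. `TextIteratorStreamer` is the bridge: the
        # generating thread pushes decoded tokens into the streamer's internal
        # queue, and the `for token in streamer` loop above pops them, letting
        # this generator stream partial replies into the chatbot as they arrive.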

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot)
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch(share=True, debug=True)

if __name__ == '__main__':
    pipeline, streamer = model_loading_pipeline()
    launch_gradio_app(pipeline, streamer)
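
# Note: even in 4-bit, Mixtral-8x7B-Instruct's ~47B parameters take roughly
# 24-30 GB of GPU memory, so this script assumes an A100/H100-class card
# (or equivalent) rather than a consumer GPU.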