LLM-FineTuning-Large-Language-Models

from transformers import AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from threading import Thread
import gradio as gr
import transformers
import torch

# Run the entire app with `python run_mixtral.py`

""" The messages list should be of the following format:

messages =

[
    {"role": "user", "content": "User's first message"},
    {"role": "assistant", "content": "Assistant's first response"},
    {"role": "user", "content": "User's second message"},
    {"role": "assistant", "content": "Assistant's second response"},
    {"role": "user", "content": "User's third message"}
]

"""
""" The `format_chat_history` function below is designed to format the dialogue history into a prompt that can be fed into the Mixtral model. This will help understand the context of the conversation and generate appropriate responses by the Model.
23
The function takes a history of dialogues as input, which is a list of lists where each sublist represents a pair of user and assistant messages.
24
"""
25

26
def format_chat_history(history) -> str:
27
    messages = [{"role": ("user" if i % 2 == 0 else "assistant"), "content": dialog[i % 2]}
28
                for i, dialog in enumerate(history) for _ in (0, 1) if dialog[i % 2]]
29
    # The conditional `(if dialog[i % 2])` ensures that messages
30
    # that are None (like the latest assistant response in an ongoing
31
    # conversation) are not included.
32
    return pipeline.tokenizer.apply_chat_template(
33
        messages, tokenize=False,
34
        add_generation_prompt=True)
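
# Illustrative example (values assumed, not executed): given a Gradio-style
# history such as
#     [["Hello", "Hi, how can I help?"], ["Explain MoE briefly", None]]
# the comprehension above produces
#     [{"role": "user", "content": "Hello"},
#      {"role": "assistant", "content": "Hi, how can I help?"},
#      {"role": "user", "content": "Explain MoE briefly"}]
# and `apply_chat_template` renders it with Mixtral's instruction template,
# roughly "<s>[INST] Hello [/INST]Hi, how can I help?</s>[INST] Explain MoE briefly [/INST]".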

def model_loading_pipeline():
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # `timeout` (lowercase) is the TextIteratorStreamer keyword; it bounds how long
    # the consuming loop waits for the next generated token.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)

    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        tokenizer=tokenizer,
        model_kwargs={"torch_dtype": torch.float16,
                      # 4-bit quantization is configured once via BitsAndBytesConfig;
                      # passing a separate "load_in_4bit" flag as well is redundant,
                      # and newer transformers versions may reject the combination.
                      "quantization_config": BitsAndBytesConfig(
                          load_in_4bit=True,
                          bnb_4bit_compute_dtype=torch.float16)},
        streamer=streamer
    )
    return pipeline, streamer
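
# Rough sizing note (an estimate, not measured here): Mixtral-8x7B has roughly 47B
# parameters, so even the 4-bit weights occupy on the order of 24 GB of GPU memory.
# If a single GPU is too small, `transformers.pipeline` also accepts
# `device_map="auto"` (via accelerate) to shard the model across available devices,
# e.g. something like:
#
#     transformers.pipeline("text-generation", model=model_id, device_map="auto",
#                           model_kwargs={"quantization_config": BitsAndBytesConfig(load_in_4bit=True)})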

def launch_gradio_app(pipeline, streamer):
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append the new user turn with an empty
            # assistant slot (None) that `bot` fills in below.
            return "", history + [[user_message, None]]

        def bot(history):
            prompt = format_chat_history(history)

            history[-1][1] = ""
            kwargs = dict(text_inputs=prompt, max_new_tokens=2048, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
            # Generation runs in a background thread; TextIteratorStreamer hands the
            # decoded tokens back to this generator as they are produced.
            thread = Thread(target=pipeline, kwargs=kwargs)
            thread.start()

            for token in streamer:
                history[-1][1] += token
                yield history
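
        # Event wiring: submitting the textbox first runs `user` (immediately,
        # since queue=False) to echo the message into the chat, then runs `bot`,
        # whose yielded histories stream token by token into the Chatbot component.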
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
        clear.click(lambda: None, None, chatbot, queue=False)

    # Queuing lets Gradio stream the generator output from `bot`.
    demo.queue()
    # share=True requests a temporary public link; debug=True keeps the call blocking.
    demo.launch(share=True, debug=True)

if __name__ == '__main__':
    # The module-level `pipeline` assigned here is also the name that
    # `format_chat_history` reads via `pipeline.tokenizer`.
    pipeline, streamer = model_loading_pipeline()
    launch_gradio_app(pipeline, streamer)

# Run the entire app with `python run_mixtral.py`
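# Assumed prerequisites (not pinned in this file): recent versions of transformers,
# accelerate, bitsandbytes, gradio, and torch, plus a CUDA GPU large enough for the
# 4-bit model (see the sizing note above).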
