LLM-FineTuning-Large-Language-Models
CodeLLaMA_34B_Conversation_with_Streamlit.py
# !pip install streamlit transformers
# Run the app with a command like:
# `streamlit run app.py`

import re
from threading import Thread
import streamlit as st
from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM

# Constants
class Config:
    BASE_MODEL = "TheBloke/Phind-CodeLlama-34B-v2-GPTQ"
    MODEL_MAX_LEN = 16384
    SYSTEM_PROMPT = "You are an AI coding assistant."
    GEN_LENGTH = 2048
    DEFAULT_PROMPT_LEN = None

st.set_page_config(page_title="Code Generation conversation", page_icon="🤗")

def load_models():
    """
    Loads the language model and tokenizer.
    Returns:
        tuple: A tuple containing the model and tokenizer.
    """
    try:
        model = AutoModelForCausalLM.from_pretrained(
            Config.BASE_MODEL,
            device_map="auto",
            trust_remote_code=False,
            revision="gptq-4bit-32g-actorder_True"
        )
        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    except Exception as e:
        print(f"Error loading models: {e}")
        return None, None

    return model, tokenizer

model, tokenizer = load_models()

def get_token_length(text):
    """
    Calculates the length of a given text in tokens.
    Args:
        text (str): Text to be tokenized.
    Returns:
        int: Length of the tokenized text.
    """
    # Index "input_ids" explicitly; tokenizer(text)[0] only works with
    # fast tokenizers and is less readable.
    return len(tokenizer(text)["input_ids"])

# Token cost of the fixed prompt skeleton (system prompt plus section
# headers), used when budgeting the conversation history below.
Config.DEFAULT_PROMPT_LEN = get_token_length(f"""### System Prompt:
{Config.SYSTEM_PROMPT}

### User Message:

### Assistant:""")

def create_conversation_pairs():
    """
    Creates conversation pairs from session messages.
    Returns:
        list: List of conversation pairs with token count.
    """
    conversation_history = []
    temp_dict = {}
    turn_token_len = 0
    # Skip the initial assistant greeting, then fold each user/assistant
    # exchange into a single history entry with its total token count.
    for i in st.session_state.messages[1:]:
        role = i['role']
        content = i['content']
        tokenized_content = f"""### {role.capitalize()}:{content}</s>"""
        turn_token_len += get_token_length(tokenized_content)

        if role == "assistant":
            temp_dict["token_count"] = turn_token_len
            temp_dict['content'] += tokenized_content
            conversation_history.append(temp_dict)
            temp_dict = {}
            turn_token_len = 0
        else:
            temp_dict['content'] = tokenized_content

    return conversation_history

def get_prompt_with_history(instruction, max_tokens=Config.MODEL_MAX_LEN, generation_length=Config.GEN_LENGTH):
    """
    Creates a prompt for the model.
    Args:
        instruction (str): User instruction to be included in the prompt.
        max_tokens (int): Maximum token length for the model.
        generation_length (int): Length of the generation.
    Returns:
        str: The created prompt.
    """
    current_instruction_len = get_token_length(instruction)
    max_usable_tokens = max_tokens - generation_length - Config.DEFAULT_PROMPT_LEN - current_instruction_len
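    # Worked example (hypothetical numbers): with max_tokens=16384,
    # generation_length=2048, DEFAULT_PROMPT_LEN=20, and a 100-token
    # instruction, max_usable_tokens = 16384 - 2048 - 20 - 100 = 14216
    # tokens left over for prior conversation turns.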
    conversation_history = create_conversation_pairs()
    conversation_history.reverse()

    # Walk the history newest-first, keeping whole turns until the
    # remaining token budget is exhausted.
    usable_history = []
    history_len = 0
    for pair in conversation_history:
        history_len += pair['token_count']
        if history_len > max_usable_tokens:
            break
        usable_history.append(pair['content'])

    # Restore chronological order before assembling the prompt.
    usable_history = "".join(reversed(usable_history))
    prompt = f"""### System Prompt:
{Config.SYSTEM_PROMPT}

{usable_history}

### User Message: {instruction}

### Assistant:"""
    return prompt

def generate_response(instruction, max_new_tokens=Config.GEN_LENGTH):
    """
    Generates a response from the model.
    Args:
        instruction (str): Instruction for generating the response.
        max_new_tokens (int): Maximum new tokens for the generation.
    Returns:
        str: Generated text.
    """
    # Pass max_new_tokens as the generation length, not as the model's
    # context size (passing it positionally would shrink max_tokens to 2048).
    prompt = get_prompt_with_history(instruction, generation_length=max_new_tokens)
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # Run generation in a background thread so this thread can consume
    # the streamer without blocking.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    # st.empty() is a single placeholder that gets overwritten with the
    # accumulated text on each new chunk, producing a live typing effect.
    with st.empty():
        for new_text in streamer:
            generated_text += new_text
            generated_text = re.sub(r"</s>", "", generated_text)
            st.write(generated_text)
    return generated_text

def main():
    """
    Main function to handle the chat interface and response generation.
    """
    # Initialization
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello, how can I help?"}]

    # Chat Interface
    # Displaying each message in the chat
    for message in st.session_state.messages:
        with st.container():
            with st.chat_message(message["role"]):
                st.write(message["content"])

    # User input outside any container or layout widget
    user_input = st.chat_input()
    if user_input:
        # Append the user message and display it right away; the display
        # loop above has already run for this rerun.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)
        # Generate and append the assistant's response
        generate_and_append_response(user_input)

def generate_and_append_response(user_input):
    """
    Generates a response for the given user input and appends it to the chat.
    Args:
        user_input (str): User's input text.
    """
    with st.chat_message("assistant"):
        with st.spinner("Typing..."):
            response = generate_response(user_input)
            # Remove any end-of-string tokens (`</s>`).
            # These tokens are used by language models to signify the end of a
            # text sequence, but they are not needed in the final output shown
            # to the user.
            response = re.sub("</s>", "", response)

    st.session_state.messages.append({"role": "assistant", "content": response})


# Run the application
if __name__ == "__main__":
    main()


"""
#####################################################
## `st.session_state`
#####################################################

📌 `st.session_state` in Streamlit is a powerful feature that allows you to maintain state across user interactions. When a Streamlit app is running, each user interaction, like clicking a button or entering text, typically causes the whole script to re-run. This can lead to a loss of state: for example, all variables are reset. `st.session_state` solves this problem by providing a way to store and persist values across reruns.

📌 Each `st.session_state` is unique to a user session. It behaves like a Python dictionary and can store any kind of Python object. You can set key-value pairs in this state, read them, and update them. This enables the app to remember information like user inputs, the state of interactions, or any other data that should persist across reruns.

📌 Here in this code, `st.session_state` is used to store the conversation history in a chat application. Every time a user enters a message, it's appended to the `messages` list in `st.session_state`. This list then persists across reruns, allowing the app to maintain the context of the conversation and display the entire chat history.

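A minimal sketch of this persistence (a standalone toy app with a hypothetical `counter` key, rather than this app's `messages` list):

```python
import streamlit as st

# Runs only on the first execution of the script for this session;
# the value survives every later rerun.
if "counter" not in st.session_state:
    st.session_state.counter = 0

# Every click triggers a full rerun of the script, yet the counter
# keeps its value because it lives in st.session_state.
if st.button("Increment"):
    st.session_state.counter += 1

st.write(f"Button clicked {st.session_state.counter} times")
```
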
--------------

##############################################
## `thread = Thread(target=model.generate, kwargs=generation_kwargs)`
##############################################

In Python, multithreading allows multiple parts of a program to run concurrently. Each thread runs independently and can execute different parts of the code simultaneously.

When you create a new thread, you need to specify the function it will execute. This is where the `target` parameter comes in: it names the callable object that the thread's `run()` method will invoke. So `target=model.generate` means that the `generate` method of the `model` object will be run in a separate thread.

When the thread is started using `thread.start()`, the `run()` method of the thread is called. This method, in turn, calls the function specified in the `target` parameter, in this case `model.generate`.

The arguments required by `model.generate` are passed through the `kwargs` parameter, ensuring that the method has all the necessary information to execute properly.

Without multithreading, calling `model.generate` directly in the main thread would block the entire program until text generation is complete, making the Streamlit app unresponsive.

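A minimal, self-contained sketch of the same producer/consumer pattern using a plain `Queue` (a toy stand-in for this app; `TextIteratorStreamer` wraps a queue in much the same way internally):

```python
from threading import Thread
from queue import Queue

def produce(q: Queue) -> None:
    # Stands in for model.generate() pushing decoded text into the streamer.
    for token in ["def ", "add", "(a, b):", " return a + b"]:
        q.put(token)
    q.put(None)  # sentinel: generation is finished

q: Queue = Queue()
thread = Thread(target=produce, kwargs={"q": q})  # same shape as the app's call
thread.start()

# The main thread stays free to consume (and display) tokens as they
# arrive, just like iterating over TextIteratorStreamer above.
while (tok := q.get()) is not None:
    print(tok, end="", flush=True)
thread.join()
```
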

"""
