# LLM-FineTuning-Large-Language-Models/CodeLLaMA_34B_Conversation_with_Streamlit.py
# !pip install streamlit transformers
# Run the whole app with a command like:
# `streamlit run app.py`

5import re
6from threading import Thread
7import streamlit as st
8from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
9
# Constants
class Config:
    """Application-wide configuration constants."""

    # Hugging Face model id of the quantized CodeLlama model.
    BASE_MODEL = "TheBloke/Phind-CodeLlama-34B-v2-GPTQ"
    # Maximum context length (in tokens) the model supports.
    # BUG FIX: the original `Config.MODEL_MAX_LEN = 16384` raised NameError at
    # class-creation time — the name `Config` is not bound until the class
    # body finishes executing. Plain attribute assignment is correct here.
    MODEL_MAX_LEN = 16384
    # System prompt prepended to every conversation.
    SYSTEM_PROMPT = "You are an AI coding assistant."
    # Maximum number of new tokens to generate per response.
    GEN_LENGTH = 2048
    # Token length of the bare prompt scaffold; filled in at import time.
    DEFAULT_PROMPT_LEN = None
17
# Configure the browser tab title and icon; Streamlit requires this to be the
# first Streamlit call in the script.
st.set_page_config(page_title="Code Generation conversation", page_icon="🤗")
19
def load_models():
    """Load the causal language model and its tokenizer from the HF Hub.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` if
        loading fails for any reason.
    """
    try:
        lm = AutoModelForCausalLM.from_pretrained(
            Config.BASE_MODEL,
            device_map="auto",
            trust_remote_code=False,
            revision="gptq-4bit-32g-actorder_True",
        )
        tok = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    except Exception as e:
        # Best-effort load: report the failure and let the caller deal with
        # the (None, None) sentinel.
        print(f"Error loading models: {e}")
        return None, None
    else:
        return lm, tok
39
# Load the model/tokenizer once at import time; both are None if loading failed.
model, tokenizer = load_models()
41
def get_token_length(text):
    """Return the number of tokens *text* encodes to.

    Args:
        text (str): Text to be tokenized.

    Returns:
        int: Length of the tokenized text in tokens.
    """
    # Index "input_ids" explicitly: integer indexing (`tokenizer(text)[0]`)
    # only works with *fast* tokenizers, while the "input_ids" key is
    # available for both fast and slow tokenizers.
    return len(tokenizer(text)["input_ids"])
51
# Token length of the empty prompt scaffold (system prompt + section headers);
# used to budget how much conversation history fits in the context window.
Config.DEFAULT_PROMPT_LEN = get_token_length(f"""### System Prompt:
{Config.SYSTEM_PROMPT}

### User Message:

### Assistant:""")
58
def create_conversation_pairs():
    """Group session messages into user/assistant turn pairs.

    Skips the initial greeting message and concatenates each user message
    with the assistant reply that follows it.

    Returns:
        list: Dicts with keys ``content`` (formatted turn text) and
        ``token_count`` (token length of the whole pair).
    """
    conversation_history = []
    temp_dict = {}
    turn_token_len = 0
    for msg in st.session_state.messages[1:]:
        role = msg['role']
        content = msg['content']
        tokenized_content = f"""### {role.capitalize()}:{content}</s>"""
        turn_token_len += get_token_length(tokenized_content)

        if role == "assistant":
            temp_dict["token_count"] = turn_token_len
            # ROBUSTNESS FIX: use .get() so an assistant message without a
            # preceding user message no longer raises KeyError on `+=`.
            temp_dict['content'] = temp_dict.get('content', '') + tokenized_content
            conversation_history.append(temp_dict)
            temp_dict = {}
            turn_token_len = 0
        else:
            temp_dict['content'] = tokenized_content

    return conversation_history
84
def get_prompt_with_history(instruction, max_tokens=Config.MODEL_MAX_LEN, generation_length=Config.GEN_LENGTH):
    """Build the full model prompt: system prompt, history, new instruction.

    Args:
        instruction (str): User instruction to be included in the prompt.
        max_tokens (int): Maximum token length for the model context.
        generation_length (int): Token budget reserved for the generation.

    Returns:
        str: The assembled prompt.
    """
    # Budget left for history after the fixed scaffold, the new instruction,
    # and the space reserved for the model's answer.
    instruction_len = get_token_length(instruction)
    max_usable_tokens = max_tokens - generation_length - Config.DEFAULT_PROMPT_LEN - instruction_len

    # Walk the history newest-first, keeping whole pairs until the budget
    # would be exceeded.
    usable_history = []
    history_len = 0
    for pair in reversed(create_conversation_pairs()):
        history_len += pair['token_count']
        if history_len > max_usable_tokens:
            break
        usable_history.append(pair['content'])

    # Pairs were collected newest-first; restore chronological order.
    usable_history = "".join(reversed(usable_history))
    prompt = f"""### System Prompt:
{Config.SYSTEM_PROMPT}

{usable_history}

### User Message: {instruction}

### Assistant:"""
    return prompt
118
def generate_response(instruction, max_new_tokens=Config.GEN_LENGTH):
    """Stream a model response for *instruction* into the Streamlit UI.

    Args:
        instruction (str): Instruction for generating the response.
        max_new_tokens (int): Maximum number of new tokens to generate.

    Returns:
        str: The fully generated text with ``</s>`` markers removed.
    """
    # BUG FIX: the second positional parameter of get_prompt_with_history is
    # `max_tokens` (the model context size), not the generation length.
    # Passing `max_new_tokens` positionally set max_tokens=2048, which made
    # the history budget negative and silently dropped all prior turns.
    prompt = get_prompt_with_history(instruction, generation_length=max_new_tokens)
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs=inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # Run generation in a background thread so this thread can consume the
    # streamer incrementally and update the UI as tokens arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    with st.empty():
        for new_text in streamer:
            generated_text += new_text
            # Strip end-of-sequence markers before displaying partial output.
            generated_text = re.sub(r"</s>", "", generated_text)
            st.write(generated_text)
    return generated_text
141
def main():
    """Render the chat history, accept user input, and generate replies."""
    # Seed the session with an initial assistant greeting on the first run;
    # st.session_state persists this list across Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello, how can I help?"}]

    # Replay the stored conversation so the full chat survives each rerun.
    for message in st.session_state.messages:
        with st.container(), st.chat_message(message["role"]):
            st.write(message["content"])

    # The chat input lives outside any container or layout widget.
    user_input = st.chat_input()
    if user_input:
        # Record the user's message, then produce the assistant's reply.
        st.session_state.messages.append({"role": "user", "content": user_input})
        generate_and_append_response(user_input)
164
def generate_and_append_response(user_input):
    """Generate the assistant's reply to *user_input* and store it in the chat.

    Args:
        user_input (str): User's input text.
    """
    with st.chat_message("assistant"):
        with st.spinner("Typing..."):
            response = generate_response(user_input)
    # Drop any leftover end-of-string tokens (`</s>`). Language models emit
    # them to mark the end of a sequence, but they must not be shown to the
    # user or stored in the visible history.
    response = re.sub("</s>", "", response)

    st.session_state.messages.append({"role": "assistant", "content": response})
180
181
# Run the application only when executed as a script (e.g. `streamlit run`).
if __name__ == "__main__":
    main()
185
186
187"""
188#####################################################
189## `st.session_state`
190#####################################################
191
192📌 `st.session_state` in Streamlit is a powerful feature that allows you to maintain state across user interactions. When a Streamlit app is running, each user interaction, like clicking a button or entering text, typically causes the whole script to re-run. This can lead to a loss of state - for example, all variables are reset. `st.session_state` solves this problem by providing a way to store and persist values across reruns.
193
194📌 Each `st.session_state` is unique to a user session. It behaves like a Python dictionary and can store any kind of Python object. You can set key-value pairs in this state, read them, and update them. This enables the app to remember information like user inputs, the state of interactions, or any other data that should persist across reruns.
195
196📌 Here in this code, `st.session_state` is used to store the conversation history in a chat application. Every time a user enters a message, it's appended to the `messages` list in `st.session_state`. This list then persists across reruns, allowing the app to maintain the context of the conversation and display the entire chat history.
197
198--------------
199
200##############################################
201## `thread = Thread(target=model.generate, kwargs=generation_kwargs)`
202##############################################
203
204The `target` parameter specifies the callable object to be invoked by the `run()` method of the thread.
205
206In Python, multithreading allows multiple parts of a program to run concurrently. Each thread runs independently and can execute different parts of the code simultaneously.
207
208When you create a new thread, you need to specify the function it will execute. This is where the `target` parameter is used.
209
210The `target` parameter in the `Thread` class specifies the function that the thread will execute. So, `target=model.generate` means that the `generate` method of the `model` object will be run in a separate thread.
211
212When the thread is started using `thread.start()`, the `run()` method of the thread is called. This method, in turn, calls the function specified in the `target` parameter, in this case, `model.generate`.
213
214The arguments required by `model.generate` are passed through the `kwargs` parameter, ensuring that the method has all the necessary information to execute properly.
215
216Without multithreading, calling `model.generate` directly in the main thread would block the entire execution of the program until the text generation is complete. This would make the Streamlit app unresponsive.
217
218
219"""