rag-chatbot-2
import os
from abc import ABC
from pathlib import Path
from typing import Dict

import requests
from exp_lama_cpp.prompts import generate_prompt, generate_summarization_prompt
from llama_cpp import Llama
from tqdm import tqdm


class ModelSettings(ABC):
    url: str
    file_name: str
    system_template: str
    prompt_template: str
    summarization_template: str
    config: Dict


class StableLMZephyrSettings(ModelSettings):
    url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q5_K_M.gguf"
    file_name = "stablelm-zephyr-3b.Q5_K_M.gguf"
    config = {
        "n_ctx": 4096,  # The max sequence length to use; longer sequence lengths require considerably more resources
        "n_threads": 8,  # The number of CPU threads to use; tailor to your system and the resulting performance
        "n_gpu_layers": 35,  # The number of layers to offload to GPU, if GPU acceleration is available
    }

    system_template = "You are a helpful, respectful and honest assistant. "
    prompt_template = """<|user|>Answer the question below:
{question}<|endoftext|>
<|assistant|>
"""
    summarization_template = """<|user|>Create a concise and comprehensive summary of the provided text,
ensuring that key information, concepts, and code snippets are retained.
Do not omit or shorten the code snippets, as they are crucial for a comprehensive understanding of the content.
"{text}"
CONCISE SUMMARY:<|endoftext|>
<|assistant|>
"""
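
# For reference, a prompt rendered from prompt_template (assuming generate_prompt,
# which lives in exp_lama_cpp.prompts, simply substitutes the {question}
# placeholder) would look like:
#
#   <|user|>Answer the question below:
#   What is retrieval-augmented generation?<|endoftext|>
#   <|assistant|>
#
# The <|user|>/<|assistant|>/<|endoftext|> tokens follow StableLM Zephyr's chat format.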


SUPPORTED_MODELS = {"stablelm-zephyr": StableLMZephyrSettings}
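
# Supporting another GGUF model amounts to subclassing ModelSettings and
# registering it here. A minimal sketch (the URL, file name, and templates
# below are hypothetical placeholders, not a real release):
#
#   class MyModelSettings(ModelSettings):
#       url = "https://example.com/my-model.Q5_K_M.gguf"
#       file_name = "my-model.Q5_K_M.gguf"
#       config = {"n_ctx": 4096, "n_threads": 8, "n_gpu_layers": 35}
#       system_template = "You are a helpful assistant. "
#       prompt_template = "<|user|>{question}<|endoftext|>\n<|assistant|>\n"
#       summarization_template = "<|user|>Summarize: {text}<|endoftext|>\n<|assistant|>\n"
#
#   SUPPORTED_MODELS["my-model"] = MyModelSettings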


def get_models():
    """Return the names of all supported models."""
    return list(SUPPORTED_MODELS.keys())


def get_model_setting(model_name: str):
    model_settings = SUPPORTED_MODELS.get(model_name)

    # validate input
    if model_settings is None:
        raise KeyError(model_name + " is not a supported model")

    return model_settings
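
# Illustrative lookups:
#   get_model_setting("stablelm-zephyr")  # -> StableLMZephyrSettings
#   get_model_setting("unknown-model")    # -> raises KeyError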


class Model:
    """
    This class encapsulates the initialization of the language model, as well as
    the generation of prompts and outputs.
    Create an instance of this class and use its methods to handle the specific
    tasks you need.
    """

    def __init__(self, model_folder: Path, model_settings: ModelSettings):
        self.model_settings = model_settings
        self.model_path = model_folder / self.model_settings.file_name
        self.prompt_template = self.model_settings.prompt_template
        self.summarization_template = self.model_settings.summarization_template
        self.system_template = self.model_settings.system_template

        # Fetch the model weights from the configured URL if they are not on disk yet.
        self._auto_download()

        self.llm = Llama(model_path=str(self.model_path), **self.model_settings.config)

    def _auto_download(self) -> None:
        """
        Downloads the model file named in the model settings and saves it to the
        model path, skipping the download when the file already exists.

        Returns:
            None

        Raises:
            Any exceptions raised during the download process are caught and
            printed, but not re-raised.

        The model's URL is taken from the model settings, and the file is
        downloaded from it in chunks, with a progress bar to visualize the
        download process.
        """
        file_name = self.model_settings.file_name
        url = self.model_settings.url

        if not os.path.exists(self.model_path):
            # send a GET request to the URL to download the file,
            # streaming it while downloading, since the file is large
            try:
                response = requests.get(url, stream=True)
                response.raise_for_status()
                # open the file in binary mode and write the contents of the
                # response in chunks
                with open(self.model_path, "wb") as f:
                    for chunk in tqdm(response.iter_content(chunk_size=8192)):
                        if chunk:
                            f.write(chunk)

            except Exception as e:
                print(f"=> Download Failed. Error: {e}")
                return

            print(f"=> Model: {file_name} downloaded successfully 🥳")

    def generate_prompt(self, question):
        return generate_prompt(
            template=self.prompt_template,
            system=self.system_template,
            question=question,
        )

    def generate_summarization_prompt(self, text):
        return generate_summarization_prompt(template=self.summarization_template, text=text)

    def generate_answer(self, prompt: str, max_new_tokens: int = 1024) -> str:
        # echo=True makes llama.cpp return the prompt together with the completion,
        # so keep only the text after the final assistant tag.
        output = self.llm(prompt, max_tokens=max_new_tokens, echo=True)
        return output["choices"][0]["text"].split("<|assistant|>")[-1]

    def start_answer_iterator_streamer(self, prompt: str, max_new_tokens: int = 1024):
        stream = self.llm.create_completion(prompt, max_tokens=max_new_tokens, temperature=0.8, stream=True)
        return stream
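

# A minimal usage sketch, assuming the GGUF weights live (or should be downloaded)
# under a local "models" directory; the directory name and the question are
# placeholders, and the first run will download the model file.
if __name__ == "__main__":
    models_dir = Path("models")
    models_dir.mkdir(exist_ok=True)

    settings = get_model_setting("stablelm-zephyr")
    model = Model(model_folder=models_dir, model_settings=settings)

    prompt = model.generate_prompt("What is retrieval-augmented generation?")

    # Blocking generation: returns the full answer as a single string.
    print(model.generate_answer(prompt, max_new_tokens=256))

    # Streaming generation: llama-cpp-python yields completion chunks shaped like
    # the non-streaming response, with the new text under choices[0]["text"].
    for chunk in model.start_answer_iterator_streamer(prompt, max_new_tokens=256):
        print(chunk["choices"][0]["text"], end="", flush=True)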