import os
from abc import ABC
from pathlib import Path
from typing import Dict

import requests
from exp_lama_cpp.prompts import generate_prompt, generate_summarization_prompt
from llama_cpp import Llama
from tqdm import tqdm


class ModelSettings(ABC):
    url: str
    file_name: str
    system_template: str
    prompt_template: str
    summarization_template: str
    config: Dict


class StableLMZephyrSettings(ModelSettings):
    url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q5_K_M.gguf"
    file_name = "stablelm-zephyr-3b.Q5_K_M.gguf"
    config = {
        "n_ctx": 4096,  # The max sequence length to use; longer sequences require significantly more memory
        "n_threads": 8,  # The number of CPU threads to use; tune this to your system for best performance
        "n_gpu_layers": 35,  # The number of layers to offload to GPU, if GPU acceleration is available
    }

    system_template = "You are a helpful, respectful and honest assistant. "
    prompt_template = """<|user|>Answer the question below:
{question}<|endoftext|>
<|assistant|>
"""
    summarization_template = """<|user|>Create a concise and comprehensive summary of the provided text,
ensuring that key information, concepts, and code snippets are retained.
Do not omit or shorten the code snippets, as they are crucial for a comprehensive understanding of the content.
"{text}"
CONCISE SUMMARY:<|endoftext|>
<|assistant|>
"""


SUPPORTED_MODELS = {"stablelm-zephyr": StableLMZephyrSettings}


def get_models():
    return list(SUPPORTED_MODELS.keys())


def get_model_setting(model_name: str):
    model_settings = SUPPORTED_MODELS.get(model_name)

    # Validate the input before using it.
    if model_settings is None:
        raise KeyError(f"{model_name} is not a supported model")

    return model_settings
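

# Example lookup (illustrative, not part of the original module):
#
#   settings = get_model_setting("stablelm-zephyr")
#   settings.file_name                   # "stablelm-zephyr-3b.Q5_K_M.gguf"
#   get_model_setting("does-not-exist")  # raises KeyError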


class Model:
    """
    This Model class encapsulates the initialization of the language model, as well as the generation
    of prompts and outputs.
    You can create an instance of this class and use its methods to handle the specific tasks you need.
    """

    def __init__(self, model_folder: Path, model_settings: ModelSettings):
        self.model_settings = model_settings
        self.model_path = model_folder / self.model_settings.file_name
        self.prompt_template = self.model_settings.prompt_template
        self.summarization_template = self.model_settings.summarization_template
        self.system_template = self.model_settings.system_template

        self._auto_download()

        self.llm = Llama(model_path=str(self.model_path), **self.model_settings.config)

    def _auto_download(self) -> None:
        """
        Downloads the model file from the configured URL and saves it to the model path.

        Returns:
            None

        Raises:
            Any exceptions raised during the download are caught and printed, but not re-raised.

        This function reads the model's file name and URL from the model settings and downloads the
        file only if it is not already present on disk. The download is streamed in chunks, and a
        progress bar visualizes its progress.
        """
        file_name = self.model_settings.file_name
        url = self.model_settings.url

        if not os.path.exists(self.model_path):
            # Send a GET request to the URL and stream the response,
            # since the file is large.
            try:
                response = requests.get(url, stream=True, timeout=60)
                response.raise_for_status()
                # Open the file in binary mode and write the contents of the
                # response in chunks.
                with open(self.model_path, "wb") as f:
                    for chunk in tqdm(response.iter_content(chunk_size=8192)):
                        if chunk:
                            f.write(chunk)

            except Exception as e:
                print(f"=> Download Failed. Error: {e}")
                return

            print(f"=> Model: {file_name} downloaded successfully 🥳")

    def generate_prompt(self, question):
        return generate_prompt(
            template=self.prompt_template,
            system=self.system_template,
            question=question,
        )

    def generate_summarization_prompt(self, text):
        return generate_summarization_prompt(template=self.summarization_template, text=text)

    def generate_answer(self, prompt: str, max_new_tokens: int = 1024) -> str:
        # echo=True returns the prompt together with the completion, so split on
        # the assistant marker to keep only the generated answer.
        output = self.llm(prompt, max_tokens=max_new_tokens, echo=True)
        return output["choices"][0]["text"].split("<|assistant|>")[-1]

    def start_answer_iterator_streamer(self, prompt: str, max_new_tokens: int = 1024):
        # stream=True makes llama.cpp yield completion chunks incrementally
        # instead of returning the full text at once.
        stream = self.llm.create_completion(prompt, max_tokens=max_new_tokens, temperature=0.8, stream=True)
        return stream
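

# Usage sketch (hypothetical, not part of the original module): wiring the
# pieces together end to end. Assumes the GGUF file is already present under
# ./models or can be downloaded; the question text is illustrative.
if __name__ == "__main__":
    settings = get_model_setting("stablelm-zephyr")
    model = Model(model_folder=Path("models"), model_settings=settings)

    # Blocking generation: build the full prompt, then decode the answer.
    prompt = model.generate_prompt("What is retrieval-augmented generation?")
    print(model.generate_answer(prompt))

    # Streaming generation: llama-cpp-python yields OpenAI-style chunks,
    # each carrying a partial completion under choices[0]["text"].
    for chunk in model.start_answer_iterator_streamer(prompt):
        print(chunk["choices"][0]["text"], end="", flush=True)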