import os
from typing import Any, Callable, Dict, Optional, Sequence

import requests
from tqdm import tqdm

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import (
    DEFAULT_CONTEXT_WINDOW,
    DEFAULT_NUM_OUTPUTS,
    DEFAULT_TEMPERATURE,
)
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
from llama_index.legacy.utils import get_cache_dir

DEFAULT_LLAMA_CPP_GGML_MODEL = (
    "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML/resolve"
    "/main/llama-2-13b-chat.ggmlv3.q4_0.bin"
)
DEFAULT_LLAMA_CPP_GGUF_MODEL = (
    "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve"
    "/main/llama-2-13b-chat.Q4_0.gguf"
)
DEFAULT_LLAMA_CPP_MODEL_VERBOSITY = True


class LlamaCPP(CustomLLM):
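    """llama.cpp LLM.

    Runs a GGML/GGUF model locally through the ``llama_cpp`` Python bindings,
    loading it either from a local ``model_path`` or by downloading
    ``model_url`` into the llama-index cache directory.

    Example (illustrative sketch; ``n_gpu_layers`` is an optional
    ``llama_cpp.Llama`` kwarg shown here only to demonstrate ``model_kwargs``):

        llm = LlamaCPP(
            model_url=DEFAULT_LLAMA_CPP_GGUF_MODEL,
            temperature=0.1,
            max_new_tokens=256,
            model_kwargs={"n_gpu_layers": 1},
        )
        print(llm.complete("Hello, world!").text)
    """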
    model_url: Optional[str] = Field(
        description="The URL of the llama-cpp model to download and use."
    )
    model_path: Optional[str] = Field(
        description="The path to the llama-cpp model to use."
    )
    temperature: float = Field(
        default=DEFAULT_TEMPERATURE,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    generate_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(
        default=DEFAULT_LLAMA_CPP_MODEL_VERBOSITY,
        description="Whether to print verbose output.",
    )

    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_url: Optional[str] = None,
        model_path: Optional[str] = None,
        temperature: float = DEFAULT_TEMPERATURE,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        generate_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        verbose: bool = DEFAULT_LLAMA_CPP_MODEL_VERBOSITY,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
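        """Initialize the llama.cpp model.

        If ``model_path`` is given, the model is loaded from disk; otherwise
        ``model_url`` (or a default chosen for the installed llama-cpp-python
        version) is downloaded into the llama-index cache directory first.
        ``model_kwargs`` are forwarded to ``llama_cpp.Llama`` and can override
        the ``n_ctx``/``verbose`` defaults set here.
        """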
        try:
            from llama_cpp import Llama
        except ImportError:
            raise ImportError(
                "Could not import llama_cpp library. "
                "Please install llama_cpp with `pip install llama-cpp-python`. "
                "See the full installation guide for GPU support at "
                "`https://github.com/abetlen/llama-cpp-python`."
            )

        model_kwargs = {
            **{"n_ctx": context_window, "verbose": verbose},
            **(model_kwargs or {}),  # Override defaults via model_kwargs
        }

        # Use a local model file if a path was provided; otherwise download
        # the model (or a version-appropriate default) into the cache dir.
        if model_path is not None:
            if not os.path.exists(model_path):
                raise ValueError(
                    "Provided model path does not exist. "
                    "Please check the path or provide a model_url to download."
                )
            else:
                self._model = Llama(model_path=model_path, **model_kwargs)
        else:
            cache_dir = get_cache_dir()
            model_url = model_url or self._get_model_path_for_version()
            model_name = os.path.basename(model_url)
            model_path = os.path.join(cache_dir, "models", model_name)
            if not os.path.exists(model_path):
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                self._download_url(model_url, model_path)
                assert os.path.exists(model_path)

            self._model = Llama(model_path=model_path, **model_kwargs)

        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            model_path=model_path,
            model_url=model_url,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "LlamaCPP_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self._model.context_params.n_ctx,
            num_output=self.max_new_tokens,
            model_name=self.model_path,
        )

    def _get_model_path_for_version(self) -> str:
        """Get the default model URL for the installed llama-cpp-python version."""
        import pkg_resources

        version = pkg_resources.get_distribution("llama-cpp-python").version
        major, minor, patch = version.split(".")

        # NOTE: llama-cpp-python<=0.1.78 supports GGML, newer versions support GGUF
        if (int(major), int(minor), int(patch)) <= (0, 1, 78):
            return DEFAULT_LLAMA_CPP_GGML_MODEL
        else:
            return DEFAULT_LLAMA_CPP_GGUF_MODEL

    def _download_url(self, model_url: str, model_path: str) -> None:
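        """Stream ``model_url`` to ``model_path`` with a progress bar.

        Partially downloaded files are removed and a ``ValueError`` is raised
        if the download does not complete.
        """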
        completed = False
        try:
            print("Downloading url", model_url, "to path", model_path)
            with requests.get(model_url, stream=True) as r:
                with open(model_path, "wb") as file:
                    total_size = int(r.headers.get("Content-Length") or "0")
                    if total_size < 1000 * 1000:
                        raise ValueError(
                            "Content should be at least 1 MB, but is only "
                            f"{r.headers.get('Content-Length')} bytes"
                        )
                    print("total size (MB):", round(total_size / 1000 / 1000, 2))
                    chunk_size = 1024 * 1024  # 1 MB
                    for chunk in tqdm(
                        r.iter_content(chunk_size=chunk_size),
                        total=int(total_size / chunk_size),
                    ):
                        file.write(chunk)
            completed = True
        except Exception as e:
            print("Error downloading model:", e)
        finally:
            if not completed:
                print("Download incomplete.", "Removing partially downloaded file.")
                if os.path.exists(model_path):
                    os.remove(model_path)
                raise ValueError("Download incomplete.")

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
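        """Chat by formatting ``messages`` into a prompt and completing it."""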
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
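        """Stream a chat response built on top of ``stream_complete``."""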
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
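        """Run a (non-streaming) completion for ``prompt``.

        If ``formatted`` is False, ``prompt`` is first passed through
        ``completion_to_prompt`` to apply the model's prompt template.
        """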
        self.generate_kwargs.update({"stream": False})

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        response = self._model(prompt=prompt, **self.generate_kwargs)

        return CompletionResponse(text=response["choices"][0]["text"], raw=response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
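        """Stream a completion for ``prompt``, yielding incremental deltas.

        If ``formatted`` is False, ``prompt`` is first passed through
        ``completion_to_prompt`` to apply the model's prompt template.
        """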
        self.generate_kwargs.update({"stream": True})

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        response_iter = self._model(prompt=prompt, **self.generate_kwargs)

        def gen() -> CompletionResponseGen:
            text = ""
            for response in response_iter:
                delta = response["choices"][0]["text"]
                text += delta
                yield CompletionResponse(delta=delta, text=text, raw=response)

        return gen()
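

# Example usage (illustrative sketch; the model path and prompts are placeholders):
#
#     from llama_index.legacy.core.llms.types import ChatMessage
#
#     llm = LlamaCPP(model_path="/path/to/llama-2-13b-chat.Q4_0.gguf")
#     chat_response = llm.chat([ChatMessage(role="user", content="Hello!")])
#     for chunk in llm.stream_complete("Tell me a short story."):
#         print(chunk.delta, end="", flush=True)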
