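"""vLLM integration for LlamaIndex (legacy).

Provides ``Vllm``, which runs a model in-process through ``vllm.LLM``, and
``VllmServer``, which sends generation requests to a running vLLM API server
over HTTP.
"""
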
import json
from typing import Any, Callable, Dict, List, Optional, Sequence

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.generic_utils import (
    completion_response_to_chat_response,
    messages_to_prompt as generic_messages_to_prompt,
    stream_completion_response_to_chat_response,
)
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.vllm_utils import get_response, post_http_request
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode


class Vllm(LLM):
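    """vLLM LLM.

    Runs a HuggingFace model in-process through ``vllm.LLM``. A minimal usage
    sketch (illustrative only; the model name and token budget are placeholders):

        llm = Vllm(model="facebook/opt-125m", max_new_tokens=64)
        print(llm.complete("Hello, my name is").text)
    """
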
    model: Optional[str] = Field(description="The HuggingFace model to use.")

    temperature: float = Field(description="The temperature to use for sampling.")

    tensor_parallel_size: Optional[int] = Field(
        default=1,
        description="The number of GPUs to use for distributed execution with tensor parallelism.",
    )

    trust_remote_code: Optional[bool] = Field(
        default=True,
        description="Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer.",
    )

    n: int = Field(
        default=1,
        description="Number of output sequences to return for the given prompt.",
    )

    best_of: Optional[int] = Field(
        default=None,
        description="Number of output sequences that are generated from the prompt.",
    )

    presence_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on whether they appear in the generated text so far.",
    )

    frequency_penalty: float = Field(
        default=0.0,
        description="Float that penalizes new tokens based on their frequency in the generated text so far.",
    )

    top_p: float = Field(
        default=1.0,
        description="Float that controls the cumulative probability of the top tokens to consider.",
    )

    top_k: int = Field(
        default=-1,
        description="Integer that controls the number of top tokens to consider.",
    )

    use_beam_search: bool = Field(
        default=False, description="Whether to use beam search instead of sampling."
    )

    stop: Optional[List[str]] = Field(
        default=None,
        description="List of strings that stop the generation when they are generated.",
    )

    ignore_eos: bool = Field(
        default=False,
        description="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.",
    )

    max_new_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate per output sequence.",
    )

    logprobs: Optional[int] = Field(
        default=None,
        description="Number of log probabilities to return per output token.",
    )

    dtype: str = Field(
        default="auto",
        description="The data type for the model weights and activations.",
    )

    download_dir: Optional[str] = Field(
        default=None,
        description="Directory to download and load the weights. (Defaults to the HuggingFace cache directory.)",
    )

    vllm_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description="Holds any model parameters valid for the `vllm.LLM` call not explicitly specified.",
    )

    api_url: str = Field(description="The API URL of the vLLM server.")

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "facebook/opt-125m",
        temperature: float = 1.0,
        tensor_parallel_size: int = 1,
        trust_remote_code: bool = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        vllm_kwargs: Dict[str, Any] = {},
        api_url: Optional[str] = "",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        try:
            from vllm import LLM as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import vllm python package. "
                "Please install it with `pip install vllm`."
            )
        if model != "":
            self._client = VLLModel(
                model=model,
                tensor_parallel_size=tensor_parallel_size,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                download_dir=download_dir,
                **vllm_kwargs,
            )
        else:
            self._client = None
        callback_manager = callback_manager or CallbackManager([])
        super().__init__(
            model=model,
            temperature=temperature,
            n=n,
            best_of=best_of,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            top_k=top_k,
            use_beam_search=use_beam_search,
            stop=stop,
            ignore_eos=ignore_eos,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            dtype=dtype,
            download_dir=download_dir,
            vllm_kwargs=vllm_kwargs,
            api_url=api_url,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Vllm"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(model_name=self.model)

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        base_kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "n": self.n,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "use_beam_search": self.use_beam_search,
            "best_of": self.best_of,
            "ignore_eos": self.ignore_eos,
            "stop": self.stop,
            "logprobs": self.logprobs,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }
        return {**base_kwargs}

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            **self._model_kwargs,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params)
        outputs = self._client.generate([prompt], sampling_params)
        return CompletionResponse(text=outputs[0].outputs[0].text)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        kwargs = kwargs if kwargs else {}
        return self.chat(messages, **kwargs)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise ValueError("Not Implemented")

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise ValueError("Not Implemented")

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise ValueError("Not Implemented")


class VllmServer(Vllm):
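    """Client for a running vLLM API server.

    Builds the same sampling parameters as ``Vllm`` but sends them to
    ``api_url`` over HTTP via ``post_http_request`` instead of loading the
    model in-process. A minimal usage sketch (illustrative only; the URL and
    endpoint path depend on how the server was launched):

        llm = VllmServer(api_url="http://localhost:8000/generate")
        print(llm.complete("Hello, my name is").text)
    """
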
    def __init__(
        self,
        model: str = "facebook/opt-125m",
        api_url: str = "http://localhost:8000",
        temperature: float = 1.0,
        tensor_parallel_size: Optional[int] = 1,
        trust_remote_code: Optional[bool] = True,
        n: int = 1,
        best_of: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        top_p: float = 1.0,
        top_k: int = -1,
        use_beam_search: bool = False,
        stop: Optional[List[str]] = None,
        ignore_eos: bool = False,
        max_new_tokens: int = 512,
        logprobs: Optional[int] = None,
        dtype: str = "auto",
        download_dir: Optional[str] = None,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        vllm_kwargs: Dict[str, Any] = {},
        callback_manager: Optional[CallbackManager] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._client = None
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or (lambda x: x)
        callback_manager = callback_manager or CallbackManager([])

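        # Clear the model name so Vllm.__init__ does not construct a local
        # vllm.LLM engine; requests are sent to the server at api_url instead.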
        model = ""
314
        super().__init__(
315
            model=model,
316
            temperature=temperature,
317
            n=n,
318
            best_of=best_of,
319
            presence_penalty=presence_penalty,
320
            frequency_penalty=frequency_penalty,
321
            top_p=top_p,
322
            top_k=top_k,
323
            use_beam_search=use_beam_search,
324
            stop=stop,
325
            ignore_eos=ignore_eos,
326
            max_new_tokens=max_new_tokens,
327
            logprobs=logprobs,
328
            dtype=dtype,
329
            download_dir=download_dir,
330
            messages_to_prompt=messages_to_prompt,
331
            completion_to_prompt=completion_to_prompt,
332
            vllm_kwargs=vllm_kwargs,
333
            api_url=api_url,
334
            callback_manager=callback_manager,
335
            output_parser=output_parser,
336
        )
337

338
    @classmethod
339
    def class_name(cls) -> str:
340
        return "VllmServer"
341

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=False)
        output = get_response(response)

        return CompletionResponse(text=output[0])

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt
        response = post_http_request(self.api_url, sampling_params, stream=True)

        def gen() -> CompletionResponseGen:
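            # The server streams JSON chunks separated by NUL bytes; each
            # chunk's "text" field holds the generated output.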
            for chunk in response.iter_lines(
                chunk_size=8192, decode_unicode=False, delimiter=b"\0"
            ):
                if chunk:
                    data = json.loads(chunk.decode("utf-8"))

                    yield CompletionResponse(text=data["text"][0])

        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        kwargs = kwargs if kwargs else {}
        return self.complete(prompt, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        kwargs = kwargs if kwargs else {}
        params = {**self._model_kwargs, **kwargs}

        from vllm import SamplingParams

        # build sampling parameters
        sampling_params = SamplingParams(**params).__dict__
        sampling_params["prompt"] = prompt

        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, **kwargs):
                yield message

        return gen()

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        return self.stream_chat(messages, **kwargs)