llama-index

Форк
0
78 строк · 2.5 Кб
1
from typing import Any, Callable, Optional, Sequence
2

3
from llama_index.legacy.callbacks import CallbackManager
4
from llama_index.legacy.core.llms.types import (
5
    ChatMessage,
6
    CompletionResponse,
7
    CompletionResponseGen,
8
    LLMMetadata,
9
)
10
from llama_index.legacy.llms.base import llm_completion_callback
11
from llama_index.legacy.llms.custom import CustomLLM
12
from llama_index.legacy.types import PydanticProgramMode
13

14

15
class MockLLM(CustomLLM):
16
    max_tokens: Optional[int]
17

18
    def __init__(
19
        self,
20
        max_tokens: Optional[int] = None,
21
        callback_manager: Optional[CallbackManager] = None,
22
        system_prompt: Optional[str] = None,
23
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
24
        completion_to_prompt: Optional[Callable[[str], str]] = None,
25
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
26
    ) -> None:
27
        super().__init__(
28
            max_tokens=max_tokens,
29
            callback_manager=callback_manager,
30
            system_prompt=system_prompt,
31
            messages_to_prompt=messages_to_prompt,
32
            completion_to_prompt=completion_to_prompt,
33
            pydantic_program_mode=pydantic_program_mode,
34
        )
35

36
    @classmethod
37
    def class_name(cls) -> str:
38
        return "MockLLM"
39

40
    @property
41
    def metadata(self) -> LLMMetadata:
42
        return LLMMetadata(num_output=self.max_tokens or -1)
43

44
    def _generate_text(self, length: int) -> str:
45
        return " ".join(["text" for _ in range(length)])
46

47
    @llm_completion_callback()
48
    def complete(
49
        self, prompt: str, formatted: bool = False, **kwargs: Any
50
    ) -> CompletionResponse:
51
        response_text = (
52
            self._generate_text(self.max_tokens) if self.max_tokens else prompt
53
        )
54

55
        return CompletionResponse(
56
            text=response_text,
57
        )
58

59
    @llm_completion_callback()
60
    def stream_complete(
61
        self, prompt: str, formatted: bool = False, **kwargs: Any
62
    ) -> CompletionResponseGen:
63
        def gen_prompt() -> CompletionResponseGen:
64
            for ch in prompt:
65
                yield CompletionResponse(
66
                    text=prompt,
67
                    delta=ch,
68
                )
69

70
        def gen_response(max_tokens: int) -> CompletionResponseGen:
71
            for i in range(max_tokens):
72
                response_text = self._generate_text(i)
73
                yield CompletionResponse(
74
                    text=response_text,
75
                    delta="text ",
76
                )
77

78
        return gen_response(self.max_tokens) if self.max_tokens else gen_prompt()
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.