llama-index

Форк
0
134 строки · 4.5 Кб
1
from typing import Any, Dict, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field
4
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
5
from llama_index.legacy.core.llms.types import (
6
    ChatMessage,
7
    ChatResponse,
8
    ChatResponseGen,
9
    CompletionResponse,
10
    CompletionResponseGen,
11
    LLMMetadata,
12
)
13
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
14
from llama_index.legacy.llms.custom import CustomLLM
15
from llama_index.legacy.llms.generic_utils import (
16
    completion_response_to_chat_response,
17
    stream_completion_response_to_chat_response,
18
)
19

20
DEFAULT_REPLICATE_TEMP = 0.75
21

22

23
class Replicate(CustomLLM):
24
    model: str = Field(description="The Replicate model to use.")
25
    temperature: float = Field(
26
        default=DEFAULT_REPLICATE_TEMP,
27
        description="The temperature to use for sampling.",
28
        gte=0.01,
29
        lte=1.0,
30
    )
31
    image: str = Field(
32
        default="", description="The image file for multimodal model to use. (optional)"
33
    )
34
    context_window: int = Field(
35
        default=DEFAULT_CONTEXT_WINDOW,
36
        description="The maximum number of context tokens for the model.",
37
        gt=0,
38
    )
39
    prompt_key: str = Field(
40
        default="prompt", description="The key to use for the prompt in API calls."
41
    )
42
    additional_kwargs: Dict[str, Any] = Field(
43
        default_factory=dict, description="Additional kwargs for the Replicate API."
44
    )
45
    is_chat_model: bool = Field(
46
        default=False, description="Whether the model is a chat model."
47
    )
48

49
    @classmethod
50
    def class_name(cls) -> str:
51
        return "Replicate_llm"
52

53
    @property
54
    def metadata(self) -> LLMMetadata:
55
        """LLM metadata."""
56
        return LLMMetadata(
57
            context_window=self.context_window,
58
            num_output=DEFAULT_NUM_OUTPUTS,
59
            model_name=self.model,
60
            is_chat_model=self.is_chat_model,
61
        )
62

63
    @property
64
    def _model_kwargs(self) -> Dict[str, Any]:
65
        base_kwargs: Dict[str, Any] = {
66
            "temperature": self.temperature,
67
            "max_length": self.context_window,
68
        }
69
        if self.image != "":
70
            try:
71
                base_kwargs["image"] = open(self.image, "rb")
72
            except FileNotFoundError:
73
                raise FileNotFoundError(
74
                    "Could not load image file. Please check whether the file exists"
75
                )
76
        return {
77
            **base_kwargs,
78
            **self.additional_kwargs,
79
        }
80

81
    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
82
        return {self.prompt_key: prompt, **self._model_kwargs, **kwargs}
83

84
    @llm_chat_callback()
85
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
86
        prompt = self.messages_to_prompt(messages)
87
        completion_response = self.complete(prompt, formatted=True, **kwargs)
88
        return completion_response_to_chat_response(completion_response)
89

90
    @llm_chat_callback()
91
    def stream_chat(
92
        self, messages: Sequence[ChatMessage], **kwargs: Any
93
    ) -> ChatResponseGen:
94
        prompt = self.messages_to_prompt(messages)
95
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
96
        return stream_completion_response_to_chat_response(completion_response)
97

98
    @llm_completion_callback()
99
    def complete(
100
        self, prompt: str, formatted: bool = False, **kwargs: Any
101
    ) -> CompletionResponse:
102
        response_gen = self.stream_complete(prompt, formatted=formatted, **kwargs)
103
        response_list = list(response_gen)
104
        final_response = response_list[-1]
105
        final_response.delta = None
106
        return final_response
107

108
    @llm_completion_callback()
109
    def stream_complete(
110
        self, prompt: str, formatted: bool = False, **kwargs: Any
111
    ) -> CompletionResponseGen:
112
        try:
113
            import replicate
114
        except ImportError:
115
            raise ImportError(
116
                "Could not import replicate library."
117
                "Please install replicate with `pip install replicate`"
118
            )
119

120
        if not formatted:
121
            prompt = self.completion_to_prompt(prompt)
122
        input_dict = self._get_input_dict(prompt, **kwargs)
123
        response_iter = replicate.run(self.model, input=input_dict)
124

125
        def gen() -> CompletionResponseGen:
126
            text = ""
127
            for delta in response_iter:
128
                text += delta
129
                yield CompletionResponse(
130
                    delta=delta,
131
                    text=text,
132
                )
133

134
        return gen()
135

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.