llama-index

Форк
0
216 строк · 7.5 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.core.llms.types import (
6
    ChatMessage,
7
    ChatResponse,
8
    ChatResponseAsyncGen,
9
    ChatResponseGen,
10
    CompletionResponse,
11
    CompletionResponseAsyncGen,
12
    CompletionResponseGen,
13
    LLMMetadata,
14
)
15
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
16
from llama_index.legacy.llms.generic_utils import (
17
    completion_to_chat_decorator,
18
    stream_completion_to_chat_decorator,
19
)
20
from llama_index.legacy.llms.llm import LLM
21
from llama_index.legacy.llms.watsonx_utils import (
22
    WATSONX_MODELS,
23
    get_from_param_or_env_without_error,
24
    watsonx_model_to_context_size,
25
)
26
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
27

28

29
class WatsonX(LLM):
    """IBM WatsonX LLM.

    Adapter that exposes an ``ibm_watson_machine_learning`` foundation-model
    ``Model`` through the llama-index ``LLM`` interface. Only the synchronous
    ``complete`` / ``stream_complete`` / ``chat`` / ``stream_chat`` methods are
    implemented; every async method raises ``NotImplementedError``.
    """

    model_id: str = Field(description="The Model to use.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    temperature: float = Field(description="The temperature to use for sampling.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional Kwargs for the WatsonX model"
    )
    model_info: Dict[str, Any] = Field(
        default_factory=dict, description="Details about the selected model"
    )

    # The ibm_watson_machine_learning ``Model`` client constructed in ``__init__``.
    _model = PrivateAttr()

    def __init__(
        self,
        credentials: Dict[str, Any],
        model_id: Optional[str] = "ibm/mpt-7b-instruct2",
        project_id: Optional[str] = None,
        space_id: Optional[str] = None,
        max_new_tokens: Optional[int] = 512,
        temperature: Optional[float] = 0.1,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params.

        Args:
            credentials: Credentials dict passed unchanged to the WatsonX ``Model``.
            model_id: Model identifier; must be a key of ``WATSONX_MODELS``.
            project_id: WatsonX project id; if not given, read from the
                ``IBM_WATSONX_PROJECT_ID`` environment variable.
            space_id: WatsonX space id; if not given, read from the
                ``IBM_WATSONX_SPACE_ID`` environment variable.
            max_new_tokens: Default ``max_new_tokens`` generation parameter.
            temperature: Default ``temperature`` generation parameter.
            additional_kwargs: Extra generation parameters; on key collision
                they override ``max_new_tokens`` / ``temperature``.
            callback_manager: Callback manager; an empty one is used if None.
            system_prompt: Forwarded to the ``LLM`` base class.
            messages_to_prompt: Forwarded to the ``LLM`` base class.
            completion_to_prompt: Forwarded to the ``LLM`` base class.
            pydantic_program_mode: Forwarded to the ``LLM`` base class.
            output_parser: Forwarded to the ``LLM`` base class.

        Raises:
            ValueError: If ``model_id`` is not in ``WATSONX_MODELS``, or if
                neither a project id nor a space id could be resolved.
            ImportError: If ``ibm_watson_machine_learning`` is not installed.
        """
        if model_id not in WATSONX_MODELS:
            raise ValueError(
                f"Model name {model_id} not found in {WATSONX_MODELS.keys()}"
            )

        try:
            from ibm_watson_machine_learning.foundation_models.model import Model
        except ImportError as e:
            raise ImportError(
                "You must install the `ibm_watson_machine_learning` package to use "
                "WatsonX. Please `pip install ibm_watson_machine_learning`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Explicit arguments take priority; otherwise fall back to env vars.
        project_id = get_from_param_or_env_without_error(
            project_id, "IBM_WATSONX_PROJECT_ID"
        )
        space_id = get_from_param_or_env_without_error(space_id, "IBM_WATSONX_SPACE_ID")

        # At least one of the two identifiers is required to build the client.
        if project_id is None and space_id is None:
            raise ValueError(
                "Did not find `project_id` or `space_id`, Please pass them as named parameters"
                " or as environment variables, `IBM_WATSONX_PROJECT_ID` or `IBM_WATSONX_SPACE_ID`."
            )

        # Built before ``super().__init__`` because ``get_details()`` is needed
        # to populate the ``model_info`` field below.
        self._model = Model(
            model_id=model_id,
            credentials=credentials,
            project_id=project_id,
            space_id=space_id,
        )

        super().__init__(
            model_id=model_id,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            additional_kwargs=additional_kwargs,
            # NOTE(review): presumably a call to the WatsonX service — confirm.
            model_info=self._model.get_details(),
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get Class Name."""
        return "WatsonX_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Return context window, output budget and model name for this LLM."""
        return LLMMetadata(
            context_window=watsonx_model_to_context_size(self.model_id),
            num_output=self.max_new_tokens,
            model_name=self.model_id,
        )

    @property
    def sample_model_kwargs(self) -> Dict[str, Any]:
        """Get a sample of Model kwargs that a user can pass to the model.

        Uses the SDK's ``GenTextParamsMetaNames`` example values, minus
        ``return_options``.

        Raises:
            ImportError: If ``ibm_watson_machine_learning`` is not installed.
        """
        try:
            from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames
        except ImportError as e:
            raise ImportError(
                "You must install the `ibm_watson_machine_learning` package to use "
                "WatsonX. Please `pip install ibm_watson_machine_learning`"
            ) from e

        params = GenTextParamsMetaNames().get_example_values()

        # Tolerate SDK versions whose example values lack this key.
        params.pop("return_options", None)

        return params

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Default generation params; ``additional_kwargs`` override the base two."""
        base_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        return {**base_kwargs, **self.additional_kwargs}

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call ``kwargs`` over the model's default generation params."""
        return {**self._model_kwargs, **kwargs}

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Generate a completion for ``prompt`` in a single (blocking) call.

        Extra ``kwargs`` are sent as generation params, overriding defaults.
        """
        # NOTE(review): ``formatted`` is accepted but unused, and this class never
        # applies ``completion_to_prompt`` itself — confirm the base class does.
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._model.generate_text(prompt=prompt, params=all_kwargs)

        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion for ``prompt``.

        Each yielded response carries the text accumulated so far in ``text``
        and the newly received chunk in ``delta``.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)

        stream_response = self._model.generate_text_stream(
            prompt=prompt, params=all_kwargs
        )

        def gen() -> CompletionResponseGen:
            content = ""
            for stream_delta in stream_response:
                content += stream_delta
                yield CompletionResponse(text=content, delta=stream_delta)

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat by flattening ``messages`` into a prompt and calling ``complete``."""
        # NOTE(review): the prompt is built by ``completion_to_chat_decorator``;
        # the ``messages_to_prompt`` given to __init__ is not referenced here —
        # confirm whether that is intended.
        chat_fn = completion_to_chat_decorator(self.complete)

        # ``complete`` merges the default generation params itself, so the
        # caller's kwargs are forwarded unchanged (no double merge).
        return chat_fn(messages, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat built on top of ``stream_complete``."""
        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)

        # ``stream_complete`` merges the default generation params itself.
        return chat_stream_fn(messages, **kwargs)

    # Async Functions
    # IBM Watson Machine Learning Package currently does not have Support for Async calls

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError
217

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.