llama-index
216 lines · 7.5 KB
1from typing import Any, Callable, Dict, Optional, Sequence2
3from llama_index.legacy.bridge.pydantic import Field, PrivateAttr4from llama_index.legacy.callbacks import CallbackManager5from llama_index.legacy.core.llms.types import (6ChatMessage,7ChatResponse,8ChatResponseAsyncGen,9ChatResponseGen,10CompletionResponse,11CompletionResponseAsyncGen,12CompletionResponseGen,13LLMMetadata,14)
15from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback16from llama_index.legacy.llms.generic_utils import (17completion_to_chat_decorator,18stream_completion_to_chat_decorator,19)
20from llama_index.legacy.llms.llm import LLM21from llama_index.legacy.llms.watsonx_utils import (22WATSONX_MODELS,23get_from_param_or_env_without_error,24watsonx_model_to_context_size,25)
26from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode27
28
class WatsonX(LLM):
    """IBM WatsonX LLM.

    Wraps the `ibm_watson_machine_learning` foundation-model client behind the
    llama-index ``LLM`` interface. Only synchronous calls are supported: the
    underlying IBM package has no async API, so every ``a*`` method raises
    ``NotImplementedError``.
    """

    model_id: str = Field(description="The Model to use.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    temperature: float = Field(description="The temperature to use for sampling.")
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional Kwargs for the WatsonX model"
    )
    model_info: Dict[str, Any] = Field(
        default_factory=dict, description="Details about the selected model"
    )

    # Underlying ibm_watson_machine_learning Model client, created in __init__.
    _model = PrivateAttr()

    def __init__(
        self,
        credentials: Dict[str, Any],
        model_id: Optional[str] = "ibm/mpt-7b-instruct2",
        project_id: Optional[str] = None,
        space_id: Optional[str] = None,
        max_new_tokens: Optional[int] = 512,
        temperature: Optional[float] = 0.1,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params.

        Raises:
            ValueError: If ``model_id`` is not a known WatsonX model, or if
                neither a project id nor a space id can be resolved from the
                arguments or the environment.
            ImportError: If ``ibm_watson_machine_learning`` is not installed.
        """
        if model_id not in WATSONX_MODELS:
            raise ValueError(
                f"Model name {model_id} not found in {WATSONX_MODELS.keys()}"
            )

        # Imported lazily so the module can be loaded without the optional
        # IBM dependency installed.
        try:
            from ibm_watson_machine_learning.foundation_models.model import Model
        except ImportError as e:
            raise ImportError(
                "You must install the `ibm_watson_machine_learning` package to use WatsonX, "
                "please `pip install ibm_watson_machine_learning`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Fall back to environment variables when the ids are not passed
        # explicitly; either one is sufficient to build the client.
        project_id = get_from_param_or_env_without_error(
            project_id, "IBM_WATSONX_PROJECT_ID"
        )
        space_id = get_from_param_or_env_without_error(space_id, "IBM_WATSONX_SPACE_ID")

        if project_id is not None or space_id is not None:
            self._model = Model(
                model_id=model_id,
                credentials=credentials,
                project_id=project_id,
                space_id=space_id,
            )
        else:
            raise ValueError(
                "Did not find `project_id` or `space_id`, Please pass them as named parameters"
                " or as environment variables, `IBM_WATSONX_PROJECT_ID` or `IBM_WATSONX_SPACE_ID`."
            )

        super().__init__(
            model_id=model_id,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            additional_kwargs=additional_kwargs,
            model_info=self._model.get_details(),
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get Class Name."""
        return "WatsonX_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Static metadata: context window, output budget, and model name."""
        return LLMMetadata(
            context_window=watsonx_model_to_context_size(self.model_id),
            num_output=self.max_new_tokens,
            model_name=self.model_id,
        )

    @property
    def sample_model_kwargs(self) -> Dict[str, Any]:
        """Get a sample of Model kwargs that a user can pass to the model.

        Raises:
            ImportError: If ``ibm_watson_machine_learning`` is not installed.
        """
        try:
            from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames
        except ImportError as e:
            raise ImportError(
                "You must install the `ibm_watson_machine_learning` package to use WatsonX, "
                "please `pip install ibm_watson_machine_learning`"
            ) from e

        params = GenTextParamsMetaNames().get_example_values()

        # `return_options` configures response metadata, not sampling; drop it
        # from the sample. Default of None avoids a KeyError if the IBM
        # package ever stops emitting that key.
        params.pop("return_options", None)

        return params

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        # Generation parameters derived from this instance's configuration.
        base_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        return {**base_kwargs, **self.additional_kwargs}

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        # Per-call kwargs override the instance-level defaults.
        return {**self._model_kwargs, **kwargs}

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete `prompt` synchronously via the WatsonX model."""
        all_kwargs = self._get_all_kwargs(**kwargs)

        response = self._model.generate_text(prompt=prompt, params=all_kwargs)

        return CompletionResponse(text=response)

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Stream a completion for `prompt`, yielding cumulative text."""
        all_kwargs = self._get_all_kwargs(**kwargs)

        stream_response = self._model.generate_text_stream(
            prompt=prompt, params=all_kwargs
        )

        def gen() -> CompletionResponseGen:
            # Each yielded response carries the full text so far plus the
            # incremental delta, per the llama-index streaming contract.
            content = ""
            for stream_delta in stream_response:
                content += stream_delta
                yield CompletionResponse(text=content, delta=stream_delta)

        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat by converting messages to a prompt and delegating to `complete`."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        chat_fn = completion_to_chat_decorator(self.complete)

        return chat_fn(messages, **all_kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat, delegating to `stream_complete`."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        chat_stream_fn = stream_completion_to_chat_decorator(self.stream_complete)

        return chat_stream_fn(messages, **all_kwargs)

    # Async Functions
    # IBM Watson Machine Learning Package currently does not have Support for Async calls

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError