llama-index

Форк
0
141 строка · 4.7 Кб
1
from typing import Any, Callable, Dict, Optional, Sequence
2

3
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
4
from llama_index.legacy.callbacks import CallbackManager
5
from llama_index.legacy.core.llms.types import (
6
    ChatMessage,
7
    ChatResponse,
8
    ChatResponseGen,
9
    CompletionResponse,
10
    CompletionResponseGen,
11
    LLMMetadata,
12
)
13
from llama_index.legacy.llms.ai21_utils import ai21_model_to_context_size
14
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
15
from llama_index.legacy.llms.custom import CustomLLM
16
from llama_index.legacy.llms.generic_utils import (
17
    completion_to_chat_decorator,
18
    get_from_param_or_env,
19
)
20
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
21

22

23
class AI21(CustomLLM):
    """AI21 Labs LLM.

    Wraps the ``ai21`` SDK's ``Completion.execute`` call. Only blocking
    (non-streaming) completion is implemented; :meth:`chat` flattens the
    messages into a single prompt and delegates to :meth:`complete`.
    """

    model: str = Field(description="The AI21 model to use.")
    maxTokens: int = Field(description="The maximum number of tokens to generate.")
    temperature: float = Field(description="The temperature to use for sampling.")

    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the AI21 API."
    )

    # API key resolved in ``__init__`` (explicit argument or AI21_API_KEY env var).
    _api_key = PrivateAttr()

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = "j2-mid",
        maxTokens: Optional[int] = 512,
        temperature: Optional[float] = 0.1,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Initialize params.

        Args:
            api_key: AI21 API key. When omitted it is resolved by
                ``get_from_param_or_env`` from the ``AI21_API_KEY``
                environment variable.
            model: AI21 model name sent with every request.
            maxTokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            additional_kwargs: Extra request parameters merged into every
                ``Completion.execute`` call; they override ``model``,
                ``maxTokens`` and ``temperature`` on key collision.
            callback_manager: Callback manager; an empty one is created when
                not given.
            system_prompt: Forwarded unchanged to the base class.
            messages_to_prompt: Forwarded unchanged to the base class.
            completion_to_prompt: Forwarded unchanged to the base class.
            pydantic_program_mode: Forwarded unchanged to the base class.
            output_parser: Forwarded unchanged to the base class.

        Raises:
            ImportError: If the ``ai21`` package is not installed.
        """
        try:
            import ai21 as _  # noqa
        except ImportError as e:
            raise ImportError(
                "You must install the `ai21` package to use AI21. "
                "Please `pip install ai21`"
            ) from e

        additional_kwargs = additional_kwargs or {}
        callback_manager = callback_manager or CallbackManager([])

        # Resolve the key before initialising the model so a missing key
        # fails fast, exactly as before.
        api_key = get_from_param_or_env("api_key", api_key, "AI21_API_KEY")

        super().__init__(
            model=model,
            maxTokens=maxTokens,
            temperature=temperature,
            additional_kwargs=additional_kwargs,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )
        # Assign private state only after the base model has been initialised,
        # so the assignment does not rely on pydantic internals that
        # ``__init__`` is responsible for setting up.
        self._api_key = api_key

    @classmethod
    def class_name(cls) -> str:
        """Get Class Name."""
        return "AI21_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata: context window for ``self.model`` and output budget."""
        return LLMMetadata(
            context_window=ai21_model_to_context_size(self.model),
            num_output=self.maxTokens,
            model_name=self.model,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        """Per-instance request parameters; ``additional_kwargs`` win on collision."""
        base_kwargs = {
            "model": self.model,
            "maxTokens": self.maxTokens,
            "temperature": self.temperature,
        }
        return {**base_kwargs, **self.additional_kwargs}

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Merge per-call ``kwargs`` over the instance's model kwargs."""
        return {**self._model_kwargs, **kwargs}

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Run a single completion request against AI21.

        Args:
            prompt: Prompt text, sent to the API as-is.
            formatted: Accepted for interface compatibility; this
                implementation does not use it.
            **kwargs: Per-call request parameters overriding the model kwargs.

        Returns:
            CompletionResponse with the first completion's text.
        """
        all_kwargs = self._get_all_kwargs(**kwargs)

        import ai21

        # NOTE(review): this sets process-global SDK state; instances using
        # different keys concurrently may race — confirm whether the SDK
        # accepts a per-call key instead.
        ai21.api_key = self._api_key

        response = ai21.Completion.execute(**all_kwargs, prompt=prompt)

        # Assumes the SDK response supports both item access and ``__dict__``
        # (true for the SDK version this was written against — TODO confirm).
        return CompletionResponse(
            text=response["completions"][0]["data"]["text"], raw=response.__dict__
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming completion is not supported by this integration.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "AI21 does not currently support streaming completion."
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat by flattening ``messages`` into a prompt and calling ``complete``.

        Args:
            messages: Conversation to send.
            **kwargs: Per-call request parameters, forwarded to ``complete``.

        Returns:
            ChatResponse wrapping the completion result.
        """
        # ``complete`` already merges the model kwargs, so only the caller's
        # kwargs are forwarded; pre-merging here would duplicate that work and
        # let ``additional_kwargs`` collide with ``complete``'s own parameters.
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat is not supported by this integration.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("AI21 does not Currently Support Streaming Chat.")
142

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.