cohere-python

Форк
0
308 строк · 13.7 Кб
1
import asyncio
2
import os
3
import typing
4
from concurrent.futures import ThreadPoolExecutor
5

6
import httpx
7

8
from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate
9
from .base_client import BaseCohere, AsyncBaseCohere, OMIT
10
from .config import embed_batch_size
11
from .core import RequestOptions
12
from .environment import ClientEnvironment
13
from .overrides import run_overrides
14
from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils
15

16
# Invoke the package's `run_overrides()` hook (imported from `.overrides`) once,
# when this module is first imported.
run_overrides()

# `typing.Never` is not available on older Python versions, so `typing.NoReturn`
# is aliased under that name for compatibility. It is used below to annotate the
# stubs that replace APIs removed or relocated in cohere==5.0.0.
Never = typing.NoReturn
20

21

22
def validate_args(obj: typing.Any, method_name: str, check_fn: typing.Callable[..., typing.Any]) -> None:
    """Replace ``obj.<method_name>`` with a wrapper that runs ``check_fn`` first.

    The wrapper is stored on ``obj`` itself via ``setattr``, so only this object is
    affected. ``check_fn`` receives exactly the positional and keyword arguments of
    each call. If it raises, the original method is never invoked. Otherwise the
    original method's return value is passed through unchanged.

    :param obj: the object whose attribute is replaced.
    :param method_name: name of the method to wrap; it must already exist on ``obj``
        (``getattr`` raises ``AttributeError`` otherwise).
    :param check_fn: validator called with the call's arguments; its return value is ignored.
    """
    import functools  # local import keeps this helper self-contained

    method = getattr(obj, method_name)

    # functools.wraps copies __name__/__qualname__/__doc__ and sets __wrapped__, so
    # help() and inspect.signature() on the patched attribute still describe the
    # original method instead of a generic (*args, **kwargs) wrapper.
    @functools.wraps(method)
    def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    setattr(obj, method_name, wrapped)
30

31

32
def throw_if_stream_is_true(*args, **kwargs) -> None:
    """Reject ``stream=True`` calls, which cohere>=5.0.0 no longer supports on ``chat``.

    Only the literal ``True`` (identity check) triggers the error. Any other value,
    or omitting ``stream``, returns ``None``. Positional arguments are ignored.

    :raises ValueError: when called with ``stream=True``.
    """
    streaming_requested = kwargs.get("stream") is True
    if not streaming_requested:
        return
    raise ValueError(
        "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
    )
37

38

39
def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """Build a placeholder for an SDK function that was relocated in cohere==5.0.0.

    The returned callable accepts any arguments and always raises ``ValueError``
    telling the caller where ``fn_name`` now lives. Its ``__name__``, ``__qualname__``
    and ``__doc__`` are set from ``fn_name`` and the migration message, so that
    tracebacks and ``help()`` identify the removed API rather than a generic ``fn``.

    :param fn_name: the old (removed) function name.
    :param new_fn_name: the replacement location, e.g. ``".datasets.create"``.
    :return: a callable that raises ``ValueError`` whenever it is invoked.
    """
    message = (
        f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
        "Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
    )

    def fn(*args, **kwargs):
        raise ValueError(message)

    fn.__name__ = fn_name
    fn.__qualname__ = fn_name
    fn.__doc__ = message
    return fn
51

52

53
def deprecated_function(fn_name: str) -> typing.Any:
    """Build a placeholder for an SDK function that was removed in cohere==5.0.0.

    The returned callable accepts any arguments and always raises ``ValueError``
    stating that ``fn_name`` is deprecated. Its ``__name__``, ``__qualname__`` and
    ``__doc__`` are set from ``fn_name`` and the message, so that tracebacks and
    ``help()`` identify the removed API rather than a generic ``fn``.

    :param fn_name: the removed function name.
    :return: a callable that raises ``ValueError`` whenever it is invoked.
    """
    message = (
        f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
        "Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
    )

    def fn(*args, **kwargs):
        raise ValueError(message)

    fn.__name__ = fn_name
    fn.__qualname__ = fn_name
    fn.__doc__ = message
    return fn
65

66

67
class Client(BaseCohere):
    """Synchronous Cohere client built on the generated ``BaseCohere``.

    On top of the base client this class:

    * falls back to the ``CO_API_KEY`` / ``CO_API_URL`` environment variables when
      ``api_key`` / ``base_url`` are not given;
    * rejects ``chat(stream=True, ...)`` with a ``ValueError`` pointing at ``chat_stream``;
    * works as a context manager that closes the underlying httpx client on exit;
    * splits ``embed`` requests into chunks of ``embed_batch_size`` texts, sends them
      on a shared thread pool and merges the results into one response;
    * exposes stubs for pre-5.0.0 methods that raise ``ValueError`` with migration hints.
    """

    def __init__(
            self,
            api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
            *,
            base_url: typing.Optional[str] = None,
            environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
            client_name: typing.Optional[str] = None,
            timeout: typing.Optional[float] = 60,
            httpx_client: typing.Optional[httpx.Client] = None,
    ):
        """Create a client.

        :param api_key: API key or key-returning callable; defaults to ``$CO_API_KEY``.
        :param base_url: API base URL; defaults to ``$CO_API_URL`` (``None`` if unset).
        :param environment: forwarded to ``BaseCohere``.
        :param client_name: forwarded to ``BaseCohere``.
        :param timeout: request timeout forwarded to ``BaseCohere`` (default 60).
        :param httpx_client: optional pre-built ``httpx.Client`` forwarded to ``BaseCohere``.
        """
        if api_key is None:
            api_key = os.getenv("CO_API_KEY")
        # Resolve CO_API_URL per instance. A signature default of os.getenv(...) is
        # evaluated only once, at import time, so environment changes made after
        # importing the SDK would be silently ignored.
        if base_url is None:
            base_url = os.getenv("CO_API_URL")

        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        # Patch this instance's `chat` so that stream=True fails fast with a migration hint.
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = SyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the wrapped httpx client; returning None lets any exception propagate.
        self._client_wrapper.httpx_client.httpx_client.close()

    wait = wait

    # Class attribute: one pool of up to 64 worker threads shared by every Client
    # instance, used to send embed batches concurrently.
    _executor = ThreadPoolExecutor(64)

    def embed(
            self,
            *,
            texts: typing.Sequence[str],
            model: typing.Optional[str] = OMIT,
            input_type: typing.Optional[EmbedInputType] = OMIT,
            embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
            truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
            request_options: typing.Optional[RequestOptions] = None,
            batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """Embed ``texts``, batching the request unless ``batching is False``.

        With ``batching=False`` this is a single ``BaseCohere.embed`` call. Otherwise
        ``texts`` is split into consecutive chunks of ``embed_batch_size``, each chunk is
        sent with ``BaseCohere.embed`` on the shared thread pool, and the per-chunk
        responses are combined with ``merge_embed_responses`` in input order.
        """
        if batching is False:
            return BaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        def embed_batch(text_batch: typing.Sequence[str]) -> EmbedResponse:
            # One unbatched request for a single chunk, with all other options unchanged.
            return BaseCohere.embed(
                self,
                texts=text_batch,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        texts_batches = [texts[i: i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]

        # Executor.map yields results in submission order, so the merged embeddings
        # line up with `texts`. A failed batch re-raises its exception here.
        # NOTE(review): an empty `texts` makes no request and passes [] to
        # merge_embed_responses; how that helper handles [] is not visible here.
        responses = list(self._executor.map(embed_batch, texts_batches))

        return merge_embed_responses(responses)

    # The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    # Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")
187

188

189
class AsyncClient(AsyncBaseCohere):
    """Asynchronous Cohere client built on the generated ``AsyncBaseCohere``.

    On top of the base client this class:

    * falls back to the ``CO_API_KEY`` / ``CO_API_URL`` environment variables when
      ``api_key`` / ``base_url`` are not given;
    * rejects ``chat(stream=True, ...)`` with a ``ValueError`` pointing at ``chat_stream``;
    * works as an async context manager that closes the underlying httpx client on exit;
    * splits ``embed`` requests into chunks of ``embed_batch_size`` texts, awaits them
      concurrently with ``asyncio.gather`` and merges the results into one response;
    * exposes stubs for pre-5.0.0 methods that raise ``ValueError`` with migration hints.
    """

    def __init__(
            self,
            api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
            *,
            base_url: typing.Optional[str] = None,
            environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
            client_name: typing.Optional[str] = None,
            timeout: typing.Optional[float] = 60,
            httpx_client: typing.Optional[httpx.AsyncClient] = None,
    ):
        """Create an async client.

        :param api_key: API key or key-returning callable; defaults to ``$CO_API_KEY``.
        :param base_url: API base URL; defaults to ``$CO_API_URL`` (``None`` if unset).
        :param environment: forwarded to ``AsyncBaseCohere``.
        :param client_name: forwarded to ``AsyncBaseCohere``.
        :param timeout: request timeout forwarded to ``AsyncBaseCohere`` (default 60).
        :param httpx_client: optional pre-built ``httpx.AsyncClient`` forwarded to ``AsyncBaseCohere``.
        """
        if api_key is None:
            api_key = os.getenv("CO_API_KEY")
        # Resolve CO_API_URL per instance. A signature default of os.getenv(...) is
        # evaluated only once, at import time, so environment changes made after
        # importing the SDK would be silently ignored.
        if base_url is None:
            base_url = os.getenv("CO_API_URL")

        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        # Patch this instance's `chat` so that stream=True fails fast with a migration hint.
        # The check runs synchronously before the coroutine from `chat` is created.
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = AsyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Close the wrapped httpx client; returning None lets any exception propagate.
        await self._client_wrapper.httpx_client.httpx_client.aclose()

    wait = async_wait

    # Not used by this class's own methods (batching below uses asyncio.gather).
    # Kept so existing references to AsyncClient._executor keep working.
    # NOTE(review): confirm nothing relies on it before removing.
    _executor = ThreadPoolExecutor(64)

    async def embed(
            self,
            *,
            texts: typing.Sequence[str],
            model: typing.Optional[str] = OMIT,
            input_type: typing.Optional[EmbedInputType] = OMIT,
            embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
            truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
            request_options: typing.Optional[RequestOptions] = None,
            batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """Embed ``texts``, batching the request unless ``batching is False``.

        With ``batching=False`` this awaits a single ``AsyncBaseCohere.embed`` call.
        Otherwise ``texts`` is split into consecutive chunks of ``embed_batch_size``,
        every chunk is requested concurrently via ``asyncio.gather``, and the
        per-chunk responses are combined with ``merge_embed_responses`` in input order.
        """
        if batching is False:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        async def embed_batch(text_batch: typing.Sequence[str]) -> EmbedResponse:
            # One unbatched request for a single chunk, with all other options unchanged.
            return await AsyncBaseCohere.embed(
                self,
                texts=text_batch,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        texts_batches = [texts[i: i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]

        # gather returns results in argument order, so the merged embeddings line up
        # with `texts`. All chunks are in flight at once; there is no concurrency cap.
        # NOTE(review): an empty `texts` makes no request and passes [] to
        # merge_embed_responses; how that helper handles [] is not visible here.
        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(*(embed_batch(text_batch) for text_batch in texts_batches)),
        )

        return merge_embed_responses(responses)

    # The following methods have been moved or deprecated in cohere==5.0.0. Please update your usage.
    # Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")
309

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.