cohere-python
308 строк · 13.7 Кб
1import asyncio2import os3import typing4from concurrent.futures import ThreadPoolExecutor5
6import httpx7
8from . import EmbedResponse, EmbedInputType, EmbeddingType, EmbedRequestTruncate9from .base_client import BaseCohere, AsyncBaseCohere, OMIT10from .config import embed_batch_size11from .core import RequestOptions12from .environment import ClientEnvironment13from .overrides import run_overrides14from .utils import wait, async_wait, merge_embed_responses, SyncSdkUtils, AsyncSdkUtils15
# Apply the SDK's runtime overrides once, at import time (see `.overrides`).
run_overrides()

# Alias for annotating placeholder attributes that can never be used
# successfully. `typing.Never` is missing on older Python versions (it was
# added in 3.11), so the long-standing equivalent `typing.NoReturn` is used.
from typing import NoReturn as Never
21
def validate_args(
    obj: typing.Any,
    method_name: str,
    check_fn: typing.Callable[..., typing.Any],
) -> None:
    """Wrap ``obj.<method_name>`` so that ``check_fn`` runs before every call.

    ``check_fn`` receives exactly the positional and keyword arguments passed
    to the method. It rejects a call by raising; its return value is ignored.
    The wrapper is installed on ``obj`` with ``setattr``.

    Args:
        obj: Object whose method is wrapped (in this file, a client instance).
        method_name: Name of the attribute to wrap.
        check_fn: Validator invoked with the method's arguments before the
            original method runs.
    """
    import functools

    method = getattr(obj, method_name)

    # functools.wraps keeps the original __name__, __doc__ and signature
    # (through __wrapped__) visible to help() and inspect.
    @functools.wraps(method)
    def wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        check_fn(*args, **kwargs)
        return method(*args, **kwargs)

    setattr(obj, method_name, wrapped)
31
def throw_if_stream_is_true(*args, **kwargs) -> None:
    """Argument check for ``chat``: reject the pre-5.0 ``stream=True`` keyword.

    Positional arguments are ignored. Only a ``stream`` keyword that *is*
    ``True`` raises (identity check, so other truthy values are let through).

    Raises:
        ValueError: if called with ``stream=True``.
    """
    if kwargs.get("stream") is not True:
        return
    raise ValueError(
        "Since python sdk cohere==5.0.0, you must now use chat_stream(...) instead of chat(stream=True, ...)"
    )
38
def moved_function(fn_name: str, new_fn_name: str) -> typing.Any:
    """Build a placeholder for an API that was relocated in cohere==5.0.0.

    The returned callable accepts any arguments and always raises a
    ``ValueError`` naming the old function and its new location.

    Args:
        fn_name: Name of the removed pre-5.0 function.
        new_fn_name: Where the functionality now lives (e.g. ``.datasets.get``).
    """

    def raiser(*_args, **_kwargs):
        message = (
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been moved to {new_fn_name}(...). "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )
        raise ValueError(message)

    return raiser
52
def deprecated_function(fn_name: str) -> typing.Any:
    """Build a placeholder for an API that was removed in cohere==5.0.0.

    The returned callable accepts any arguments and always raises a
    ``ValueError`` telling the caller that ``fn_name`` is deprecated.

    Args:
        fn_name: Name of the removed pre-5.0 function.
    """

    def raiser(*_args, **_kwargs):
        message = (
            f"Since python sdk cohere==5.0.0, the function {fn_name}(...) has been deprecated. "
            f"Please update your code. Issues may be filed in https://github.com/cohere-ai/cohere-python/issues."
        )
        raise ValueError(message)

    return raiser
66
class Client(BaseCohere):
    """Synchronous Cohere client layered on top of ``BaseCohere``.

    On top of the generated base client it adds:
    - ``CO_API_KEY`` / ``CO_API_URL`` environment-variable fallbacks;
    - context-manager support that closes the underlying httpx client;
    - a guard rejecting the pre-5.0 ``chat(stream=True)`` convention;
    - client-side batching for ``embed``;
    - placeholder attributes that raise for APIs moved or removed in
      cohere==5.0.0.
    """

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = None,
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = 60,
        httpx_client: typing.Optional[httpx.Client] = None,
    ):
        """Create a client.

        Args:
            api_key: API token, or a zero-argument callable returning one.
                Falls back to the ``CO_API_KEY`` environment variable.
            base_url: Override for the API base URL. Falls back to the
                ``CO_API_URL`` environment variable.
            environment: Forwarded to ``BaseCohere``.
            client_name: Forwarded to ``BaseCohere``.
            timeout: Forwarded to ``BaseCohere`` (default 60).
            httpx_client: Optional pre-configured ``httpx.Client``.
        """
        if api_key is None:
            api_key = os.getenv("CO_API_KEY")
        # Read CO_API_URL at construction time, like CO_API_KEY above. The
        # previous default-argument expression was evaluated once at import
        # time, so later changes to the environment were silently ignored.
        if base_url is None:
            base_url = os.getenv("CO_API_URL")

        BaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        # Reject the removed `chat(stream=True, ...)` calling convention.
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = SyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the underlying httpx client. Returns None, so exceptions raised
        # inside the `with` block still propagate.
        self._client_wrapper.httpx_client.httpx_client.close()

    wait = wait

    # Class attribute: one pool of up to 64 worker threads, shared by every
    # Client instance and used for batched `embed` requests.
    _executor = ThreadPoolExecutor(64)

    def embed(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """Embed ``texts``, splitting them into batches of ``embed_batch_size``.

        With batching (anything except ``batching=False``), each slice of
        ``texts`` is sent through ``BaseCohere.embed`` on the shared thread
        pool. The per-batch responses are then combined with
        ``merge_embed_responses``. With ``batching=False``, a single
        ``BaseCohere.embed`` call is made with all texts.

        All other keyword arguments are forwarded unchanged to every request.
        """
        if batching is False:
            return BaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        texts_batches = [
            texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)
        ]

        def embed_batch(text_batch: typing.Sequence[str]) -> EmbedResponse:
            return BaseCohere.embed(
                self,
                texts=text_batch,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        # Executor.map yields results in submission order, so the responses
        # reach merge_embed_responses in the same order as `texts`.
        responses = list(self._executor.map(embed_batch, texts_batches))
        return merge_embed_responses(responses)

    # The following methods have been moved or deprecated in cohere==5.0.0.
    # Please update your usage. Each placeholder raises ValueError when called.
    # Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")
188
class AsyncClient(AsyncBaseCohere):
    """Asynchronous Cohere client layered on top of ``AsyncBaseCohere``.

    On top of the generated base client it adds:
    - ``CO_API_KEY`` / ``CO_API_URL`` environment-variable fallbacks;
    - async context-manager support that closes the underlying httpx client;
    - a guard rejecting the pre-5.0 ``chat(stream=True)`` convention;
    - client-side batching for ``embed``;
    - placeholder attributes that raise for APIs moved or removed in
      cohere==5.0.0.
    """

    def __init__(
        self,
        api_key: typing.Optional[typing.Union[str, typing.Callable[[], str]]] = None,
        *,
        base_url: typing.Optional[str] = None,
        environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
        client_name: typing.Optional[str] = None,
        timeout: typing.Optional[float] = 60,
        httpx_client: typing.Optional[httpx.AsyncClient] = None,
    ):
        """Create an async client.

        Args:
            api_key: API token, or a zero-argument callable returning one.
                Falls back to the ``CO_API_KEY`` environment variable.
            base_url: Override for the API base URL. Falls back to the
                ``CO_API_URL`` environment variable.
            environment: Forwarded to ``AsyncBaseCohere``.
            client_name: Forwarded to ``AsyncBaseCohere``.
            timeout: Forwarded to ``AsyncBaseCohere`` (default 60).
            httpx_client: Optional pre-configured ``httpx.AsyncClient``.
        """
        if api_key is None:
            api_key = os.getenv("CO_API_KEY")
        # Read CO_API_URL at construction time, like CO_API_KEY above. The
        # previous default-argument expression was evaluated once at import
        # time, so later changes to the environment were silently ignored.
        if base_url is None:
            base_url = os.getenv("CO_API_URL")

        AsyncBaseCohere.__init__(
            self,
            base_url=base_url,
            environment=environment,
            client_name=client_name,
            token=api_key,
            timeout=timeout,
            httpx_client=httpx_client,
        )

        # Reject the removed `chat(stream=True, ...)` calling convention. The
        # check runs synchronously when `chat(...)` is called, before the
        # coroutine is created.
        validate_args(self, "chat", throw_if_stream_is_true)

    utils = AsyncSdkUtils()

    # support context manager until Fern upstreams
    # https://linear.app/buildwithfern/issue/FER-1242/expose-a-context-manager-interface-or-the-http-client-easily
    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Close the underlying httpx async client. Returns None, so exceptions
        # raised inside the `async with` block still propagate.
        await self._client_wrapper.httpx_client.httpx_client.aclose()

    wait = async_wait

    # NOTE(review): not used by `embed` below, which relies on asyncio.gather.
    # Kept because it is an existing class attribute; confirm nothing else
    # references it before removing.
    _executor = ThreadPoolExecutor(64)

    async def embed(
        self,
        *,
        texts: typing.Sequence[str],
        model: typing.Optional[str] = OMIT,
        input_type: typing.Optional[EmbedInputType] = OMIT,
        embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
        truncate: typing.Optional[EmbedRequestTruncate] = OMIT,
        request_options: typing.Optional[RequestOptions] = None,
        batching: typing.Optional[bool] = True,
    ) -> EmbedResponse:
        """Embed ``texts``, splitting them into batches of ``embed_batch_size``.

        With batching (anything except ``batching=False``), every slice of
        ``texts`` is sent concurrently via ``asyncio.gather``, with no cap on
        how many requests are in flight. The per-batch responses are then
        combined with ``merge_embed_responses``. With ``batching=False``, a
        single ``AsyncBaseCohere.embed`` call is made with all texts.

        All other keyword arguments are forwarded unchanged to every request.
        """
        if batching is False:
            return await AsyncBaseCohere.embed(
                self,
                texts=texts,
                model=model,
                input_type=input_type,
                embedding_types=embedding_types,
                truncate=truncate,
                request_options=request_options,
            )

        texts_batches = [
            texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)
        ]

        # asyncio.gather returns results in the order the awaitables were
        # passed, so the responses line up with the batches of `texts`.
        responses = typing.cast(
            typing.List[EmbedResponse],
            await asyncio.gather(
                *[
                    AsyncBaseCohere.embed(
                        self,
                        texts=text_batch,
                        model=model,
                        input_type=input_type,
                        embedding_types=embedding_types,
                        truncate=truncate,
                        request_options=request_options,
                    )
                    for text_batch in texts_batches
                ]
            ),
        )

        return merge_embed_responses(responses)

    # The following methods have been moved or deprecated in cohere==5.0.0.
    # Please update your usage. Each placeholder raises ValueError when called.
    # Issues may be filed in https://github.com/cohere-ai/cohere-python/issues.
    check_api_key: Never = deprecated_function("check_api_key")
    loglikelihood: Never = deprecated_function("loglikelihood")
    batch_generate: Never = deprecated_function("batch_generate")
    codebook: Never = deprecated_function("codebook")
    batch_tokenize: Never = deprecated_function("batch_tokenize")
    batch_detokenize: Never = deprecated_function("batch_detokenize")
    detect_language: Never = deprecated_function("detect_language")
    generate_feedback: Never = deprecated_function("generate_feedback")
    generate_preference_feedback: Never = deprecated_function("generate_preference_feedback")
    create_dataset: Never = moved_function("create_dataset", ".datasets.create")
    get_dataset: Never = moved_function("get_dataset", ".datasets.get")
    list_datasets: Never = moved_function("list_datasets", ".datasets.list")
    delete_dataset: Never = moved_function("delete_dataset", ".datasets.delete")
    get_dataset_usage: Never = moved_function("get_dataset_usage", ".datasets.get_usage")
    wait_for_dataset: Never = moved_function("wait_for_dataset", ".wait")
    _check_response: Never = deprecated_function("_check_response")
    _request: Never = deprecated_function("_request")
    create_cluster_job: Never = deprecated_function("create_cluster_job")
    get_cluster_job: Never = deprecated_function("get_cluster_job")
    list_cluster_jobs: Never = deprecated_function("list_cluster_jobs")
    wait_for_cluster_job: Never = deprecated_function("wait_for_cluster_job")
    create_embed_job: Never = moved_function("create_embed_job", ".embed_jobs.create")
    list_embed_jobs: Never = moved_function("list_embed_jobs", ".embed_jobs.list")
    get_embed_job: Never = moved_function("get_embed_job", ".embed_jobs.get")
    cancel_embed_job: Never = moved_function("cancel_embed_job", ".embed_jobs.cancel")
    wait_for_embed_job: Never = moved_function("wait_for_embed_job", ".wait")
    create_custom_model: Never = deprecated_function("create_custom_model")
    wait_for_custom_model: Never = deprecated_function("wait_for_custom_model")
    _upload_dataset: Never = deprecated_function("_upload_dataset")
    _create_signed_url: Never = deprecated_function("_create_signed_url")
    get_custom_model: Never = deprecated_function("get_custom_model")
    get_custom_model_by_name: Never = deprecated_function("get_custom_model_by_name")
    get_custom_model_metrics: Never = deprecated_function("get_custom_model_metrics")
    list_custom_models: Never = deprecated_function("list_custom_models")
    create_connector: Never = moved_function("create_connector", ".connectors.create")
    update_connector: Never = moved_function("update_connector", ".connectors.update")
    get_connector: Never = moved_function("get_connector", ".connectors.get")
    list_connectors: Never = moved_function("list_connectors", ".connectors.list")
    delete_connector: Never = moved_function("delete_connector", ".connectors.delete")
    oauth_authorize_connector: Never = moved_function("oauth_authorize_connector", ".connectors.o_auth_authorize")