# llama-index
# (source listing: 119 lines, 4.2 KB)
import asyncio
import os
from typing import Any, Dict, List, Optional

import httpx
import requests

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.embeddings.base import BaseEmbedding, Embedding

class TogetherEmbedding(BaseEmbedding):
    """Embedding model backed by the Together API ``/embeddings`` endpoint.

    Every text is sent in its own HTTP request whose JSON body is
    ``{"input": <text>, "model": <model_name>}``; the first entry of the
    response's ``data`` list supplies the embedding.
    """

    api_base: str = Field(
        default="https://api.together.xyz/v1",
        description="The base URL for the Together API.",
    )
    api_key: str = Field(
        default="",
        description="The API key for the Together API. If not set, will attempt to use the TOGETHER_API_KEY environment variable.",
    )

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        api_base: str = "https://api.together.xyz/v1",
        **kwargs: Any,
    ) -> None:
        """Create the embedding model.

        Args:
            model_name: Together model identifier, sent as ``model`` in every
                request.
            api_key: API key. When falsy, the ``TOGETHER_API_KEY`` environment
                variable is used instead.
            api_base: Base URL of the Together API.
            **kwargs: Forwarded unchanged to ``BaseEmbedding``.
        """
        # NOTE(review): if neither an argument nor the environment variable
        # provides a key, ``None`` reaches the ``str`` field — confirm whether
        # pydantic validation rejects it in this bridge version.
        api_key = api_key or os.environ.get("TOGETHER_API_KEY", None)
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs,
        )

    def _headers(self) -> Dict[str, str]:
        """Build the JSON + bearer-auth headers shared by sync and async calls."""
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _embeddings_url(self) -> str:
        """Return the ``/embeddings`` endpoint URL derived from ``api_base``."""
        # Only a trailing slash must go before appending the path segment;
        # ``strip`` would also eat leading characters if they were slashes.
        return self.api_base.rstrip("/") + "/embeddings"

    @staticmethod
    def _extract_embedding(response: Any) -> Embedding:
        """Validate an HTTP response and return the first embedding in it.

        Only ``status_code``, ``text`` and ``json()`` are used, so this accepts
        both ``requests`` and ``httpx`` responses.

        Raises:
            ValueError: If the status code is not 200.
        """
        if response.status_code != 200:
            raise ValueError(
                f"Request failed with status code {response.status_code}: {response.text}"
            )
        return response.json()["data"][0]["embedding"]

    def _generate_embedding(
        self,
        text: str,
        model_api_string: str,
        session: Optional[requests.Session] = None,
    ) -> Embedding:
        """Generate an embedding for one text via the Together API.

        Args:
            text: An input text sentence or document.
            model_api_string: API string of the embedding model to use.
            session: Optional session to reuse connections across calls. The
                caller owns it and is responsible for closing it.

        Returns:
            The embedding taken from the first item of the response ``data``.

        Raises:
            ValueError: If the API responds with a non-200 status code.
        """
        # ``requests.post`` opens and closes its own Session internally, so a
        # one-off call no longer leaves an unclosed connection pool behind.
        post = session.post if session is not None else requests.post
        response = post(
            self._embeddings_url(),
            headers=self._headers(),
            json={"input": text, "model": model_api_string},
        )
        return self._extract_embedding(response)

    async def _agenerate_embedding(self, text: str, model_api_string: str) -> Embedding:
        """Asynchronously generate an embedding for one text.

        Args:
            text: An input text sentence or document.
            model_api_string: API string of the embedding model to use.

        Returns:
            The embedding taken from the first item of the response ``data``.

        Raises:
            ValueError: If the API responds with a non-200 status code.
        """
        # A client per request keeps concurrent ``gather`` calls independent
        # of any shared connection-pool limits; ``async with`` closes it.
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._embeddings_url(),
                headers=self._headers(),
                json={"input": text, "model": model_api_string},
            )
        return self._extract_embedding(response)

    def _get_text_embedding(self, text: str) -> Embedding:
        """Get text embedding."""
        return self._generate_embedding(text, self.model_name)

    def _get_query_embedding(self, query: str) -> Embedding:
        """Get query embedding."""
        return self._generate_embedding(query, self.model_name)

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings, one request per text, in input order.

        A single session is shared across the batch so connections are reused,
        and it is closed once the batch finishes (or fails).
        """
        with requests.Session() as session:
            return [
                self._generate_embedding(text, self.model_name, session=session)
                for text in texts
            ]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Async get text embedding."""
        return await self._agenerate_embedding(text, self.model_name)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Async get query embedding."""
        return await self._agenerate_embedding(query, self.model_name)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Async get text embeddings; requests run concurrently, results keep input order."""
        return await asyncio.gather(
            *[self._agenerate_embedding(text, self.model_name) for text in texts]
        )