llama-index

Форк
0
119 строк · 4.2 Кб
1
import asyncio
2
import os
3
from typing import Any, List, Optional
4

5
import httpx
6
import requests
7

8
from llama_index.legacy.bridge.pydantic import Field
9
from llama_index.legacy.embeddings.base import BaseEmbedding, Embedding
10

11

12
class TogetherEmbedding(BaseEmbedding):
    """Embedding model that calls the Together ``/embeddings`` HTTP endpoint.

    Each text is sent in its own POST request to ``<api_base>/embeddings``
    with a JSON body of ``{"input": text, "model": model_name}``. The
    embedding is read from ``data[0].embedding`` of the JSON response.

    The synchronous methods use ``requests``; the asynchronous methods use
    ``httpx.AsyncClient``.
    """

    api_base: str = Field(
        default="https://api.together.xyz/v1",
        description="The base URL for the Together API.",
    )
    api_key: str = Field(
        default="",
        description="The API key for the Together API. If not set, will attempt to use the TOGETHER_API_KEY environment variable.",
    )

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
        api_base: str = "https://api.together.xyz/v1",
        **kwargs: Any,
    ) -> None:
        """Create the embedding client.

        Args:
            model_name: API string of the embedding model to use.
            api_key: Together API key. When omitted (or empty), the
                ``TOGETHER_API_KEY`` environment variable is used instead.
            api_base: Base URL of the Together API.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        # Fall back to "" (the declared field default) rather than None:
        # the ``api_key`` field is typed as ``str``, so None would not match
        # the declared type when neither the argument nor the env var is set.
        api_key = api_key or os.environ.get("TOGETHER_API_KEY", "")
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            api_base=api_base,
            **kwargs,
        )

    def _request_headers(self) -> dict:
        """Build the HTTP headers shared by the sync and async requests."""
        return {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _embeddings_url(self) -> str:
        """Return the full URL of the embeddings endpoint."""
        # Only trailing slashes are removed; a plain strip("/") would also
        # eat leading characters of the base URL if it started with "/".
        return self.api_base.rstrip("/") + "/embeddings"

    @staticmethod
    def _parse_embedding_response(response: Any) -> Embedding:
        """Validate an HTTP response and extract the embedding vector.

        Works with both ``requests.Response`` and ``httpx.Response``; only
        ``status_code``, ``text`` and ``json()`` are used.

        Raises:
            ValueError: If the response status code is not 200.
        """
        if response.status_code != 200:
            raise ValueError(
                f"Request failed with status code {response.status_code}: {response.text}"
            )
        return response.json()["data"][0]["embedding"]

    def _generate_embedding(self, text: str, model_api_string: str) -> Embedding:
        """Generate embeddings from Together API.

        Args:
            text: str. An input text sentence or document.
            model_api_string: str. An API string for a specific embedding model of your choice.

        Returns:
            embeddings: a list of float numbers. Embeddings correspond to your given text.

        Raises:
            ValueError: If the API responds with a non-200 status code.
        """
        # Use the session as a context manager so its connection pool is
        # released even when the request or the response parsing raises.
        # NOTE(review): no timeout is passed, so this call can block
        # indefinitely on an unresponsive server.
        with requests.Session() as session:
            response = session.post(
                self._embeddings_url(),
                headers=self._request_headers(),
                json={"input": text, "model": model_api_string},
            )
            return self._parse_embedding_response(response)

    async def _agenerate_embedding(self, text: str, model_api_string: str) -> Embedding:
        """Async generate embeddings from Together API.

        Args:
            text: str. An input text sentence or document.
            model_api_string: str. An API string for a specific embedding model of your choice.

        Returns:
            embeddings: a list of float numbers. Embeddings correspond to your given text.

        Raises:
            ValueError: If the API responds with a non-200 status code.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._embeddings_url(),
                headers=self._request_headers(),
                json={"input": text, "model": model_api_string},
            )
            return self._parse_embedding_response(response)

    def _get_text_embedding(self, text: str) -> Embedding:
        """Get text embedding."""
        return self._generate_embedding(text, self.model_name)

    def _get_query_embedding(self, query: str) -> Embedding:
        """Get query embedding."""
        return self._generate_embedding(query, self.model_name)

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings (one request per text, issued sequentially)."""
        return [self._generate_embedding(text, self.model_name) for text in texts]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Async get text embedding."""
        return await self._agenerate_embedding(text, self.model_name)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Async get query embedding."""
        return await self._agenerate_embedding(query, self.model_name)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Async get text embeddings (one request per text, awaited together)."""
        return await asyncio.gather(
            *[self._agenerate_embedding(text, self.model_name) for text in texts]
        )
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.