llama-index

Форк
0
118 строк · 3.9 Кб
1
"""Jina embeddings file."""
2

3
from typing import Any, List, Optional
4

5
import requests
6

7
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
8
from llama_index.legacy.callbacks.base import CallbackManager
9
from llama_index.legacy.core.embeddings.base import (
10
    DEFAULT_EMBED_BATCH_SIZE,
11
    BaseEmbedding,
12
)
13
from llama_index.legacy.llms.generic_utils import get_from_param_or_env
14

15
# NOTE(review): not referenced anywhere in the visible code; presumably the
# upper bound on the number of texts per embedding request -- confirm.
MAX_BATCH_SIZE = 2048
16

17
# Endpoint that both the synchronous and asynchronous embedding calls POST to.
API_URL = "https://api.jina.ai/v1/embeddings"
18

19

20
class JinaEmbedding(BaseEmbedding):
    """JinaAI class for embeddings.

    Texts are embedded by POSTing them to ``API_URL`` together with the
    configured model name.

    Args:
        model (str): Model for embedding.
            Defaults to `jina-embeddings-v2-base-en`
        embed_batch_size (int): Batch size forwarded to ``BaseEmbedding``.
        api_key (Optional[str]): JinaAI API key. When omitted, the value of
            the ``JINAAI_API_KEY`` environment variable is used instead.
        callback_manager (Optional[CallbackManager]): Forwarded to
            ``BaseEmbedding``.
    """

    api_key: str = Field(default=None, description="The JinaAI API key.")
    model: str = Field(
        default="jina-embeddings-v2-base-en",
        description="The model to use when calling Jina AI API",
    )

    # Reusable `requests` session carrying the auth headers for sync calls.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "jina-embeddings-v2-base-en",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Resolve the API key and prepare the HTTP session for sync calls."""
        # Resolve the key once, up front, so that the stored field and the
        # session header always agree (explicit argument wins over the
        # environment variable; empty string when neither is set).
        resolved_key = get_from_param_or_env(
            "api_key", api_key, "JINAAI_API_KEY", ""
        )
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            api_key=resolved_key,
            **kwargs,
        )
        self._session = requests.Session()
        # Use the resolved key here: interpolating the raw `api_key` argument
        # would send "Bearer None" whenever the key comes from the environment.
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self.api_key}",
                "Accept-Encoding": "identity",
            }
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "JinaAIEmbedding"

    @staticmethod
    def _embeddings_from_response(resp: Any) -> List[List[float]]:
        """Extract the embedding vectors from a decoded API response.

        Args:
            resp: Decoded JSON body returned by the embeddings endpoint.

        Returns:
            The ``"embedding"`` value of every entry in ``resp["data"]``,
            ordered by each entry's ``"index"``.

        Raises:
            RuntimeError: If the body has no ``"data"`` key. The message is
                the body's ``"detail"`` value when present, otherwise the
                whole body.
        """
        if not isinstance(resp, dict) or "data" not in resp:
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)

        # Sort by index so the output lines up with the input texts.
        ordered = sorted(resp["data"], key=lambda e: e["index"])  # type: ignore
        return [item["embedding"] for item in ordered]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Raises:
            RuntimeError: If the API response carries no ``"data"`` key.
        """
        # Call Jina AI Embedding API
        resp = self._session.post(  # type: ignore
            API_URL, json={"input": texts, "model": self.model}
        ).json()
        return self._embeddings_from_response(resp)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Raises:
            aiohttp.ClientResponseError: If the API answers with an HTTP
                error status.
            RuntimeError: If a successful response carries no ``"data"`` key.
        """
        import aiohttp

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept-Encoding": "identity",
        }
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(
                API_URL,
                json={"input": texts, "model": self.model},
                headers=headers,
            ) as response:
                # Check the status before decoding so a non-JSON error body
                # surfaces as the HTTP error rather than a decoding failure.
                response.raise_for_status()
                resp = await response.json()
                return self._embeddings_from_response(resp)
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.