# llama-index
# 141 lines · 5.1 KB
import logging
import os
from typing import Any, List, Optional

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.legacy.core.embeddings.base import BaseEmbedding

# Logger for this module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)

# Sample Clarifai model URL, interpolated into the `model_url` field
# description of ClarifaiEmbedding.
# NOTE(review): this URL points at a text-completion model (claude-v2), not an
# embedding model; an embedding-model URL would make a clearer example.
EXAMPLE_URL = "https://clarifai.com/anthropic/completion/models/claude-v2"


class ClarifaiEmbedding(BaseEmbedding):
    """Clarifai embeddings class.

    Clarifai uses Personal Access Tokens (PAT) to validate requests.
    You can create and manage PATs under your Clarifai account security settings.
    Export your PAT as an environment variable by running `export CLARIFAI_PAT={PAT}`,
    or pass it explicitly through the ``pat`` constructor argument.

    The model is selected either by ``model_url`` or by ``model_name`` together
    with ``app_id`` and ``user_id``; exactly one of ``model_url`` and
    ``model_name`` must be given.
    """

    model_url: Optional[str] = Field(
        description=f"Full URL of the model. e.g. `{EXAMPLE_URL}`"
    )
    model_id: Optional[str] = Field(description="Model ID.")
    model_version_id: Optional[str] = Field(description="Model Version ID.")
    app_id: Optional[str] = Field(description="Clarifai application ID of the model.")
    user_id: Optional[str] = Field(description="Clarifai user ID of the model.")
    pat: Optional[str] = Field(
        description="Personal Access Tokens(PAT) to validate requests."
    )

    # Clarifai client `Model` instance used for predictions; set in __init__.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: Optional[str] = None,
        model_url: Optional[str] = None,
        model_version_id: Optional[str] = "",
        app_id: Optional[str] = None,
        user_id: Optional[str] = None,
        pat: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        """Create a Clarifai embedding model.

        Args:
            model_name: Clarifai model ID. Requires ``app_id`` and ``user_id``;
                mutually exclusive with ``model_url``.
            model_url: Full Clarifai model URL; mutually exclusive with
                ``model_name``. The model ID parsed by the client becomes
                this embedding's ``model_name``.
            model_version_id: Model version ID, used only with ``model_name``.
            app_id: Clarifai application ID that owns the model.
            user_id: Clarifai user ID that owns the model.
            pat: Personal Access Token. If empty or omitted, the
                ``CLARIFAI_PAT`` environment variable is used instead.
            embed_batch_size: Inputs sent per predict request (capped at 128).
            callback_manager: Optional callback manager for the base class.

        Raises:
            ImportError: If the ``clarifai`` package is not installed.
            ValueError: If no non-empty PAT is available, or if the model
                selection arguments are missing or conflicting.
        """
        try:
            from clarifai.client.model import Model
        except ImportError as err:
            raise ImportError(
                "ClarifaiEmbedding requires `pip install clarifai`."
            ) from err

        # Cap the batch size at 128 inputs per request -- presumably a
        # Clarifai per-request input limit; TODO confirm against API docs.
        embed_batch_size = min(128, embed_batch_size)

        # An explicit non-empty `pat` wins; otherwise fall back to the env var.
        # Empty strings count as "not provided", so an empty token is never
        # forwarded to the client.
        pat = pat or os.environ.get("CLARIFAI_PAT")
        if not pat:
            raise ValueError(
                "Set `CLARIFAI_PAT` as env variable or pass `pat` as constructor argument"
            )

        if model_url is not None and model_name is not None:
            raise ValueError("You can only specify one of model_url or model_name.")
        if model_url is None and model_name is None:
            raise ValueError("You must specify one of model_url or model_name.")

        if model_name is not None:
            if app_id is None or user_id is None:
                raise ValueError(
                    f"Missing one app ID or user ID of the model: {app_id=}, {user_id=}"
                )
            model = Model(
                user_id=user_id,
                app_id=app_id,
                model_id=model_name,
                model_version={"id": model_version_id},
                pat=pat,
            )
        else:
            model = Model(model_url, pat=pat)
            # Report the ID resolved from the URL as this embedding's name.
            model_name = model.id

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
        )
        # Assign private state only after the pydantic base initializer has
        # run: pydantic v2 models have no private-attribute storage before
        # that, so assigning earlier fails there.
        self._model = model

    @classmethod
    def class_name(cls) -> str:
        """Return the registered class name used for serialization."""
        return "ClarifaiEmbedding"

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed ``sentences`` in batches of ``embed_batch_size``.

        Returns:
            One vector per input sentence, in the order of the predict
            responses (assumed to match request order -- TODO confirm).

        Raises:
            ImportError: If the ``clarifai`` package is not installed.
            Exception: Any error from building inputs or from the Clarifai
                predict call is logged and re-raised. Callers therefore never
                receive a partial list that no longer lines up with
                ``sentences``.
        """
        try:
            from clarifai.client.input import Inputs
        except ImportError as err:
            raise ImportError(
                "ClarifaiEmbedding requires `pip install clarifai`."
            ) from err

        embeddings: List[List[float]] = []
        try:
            for start in range(0, len(sentences), self.embed_batch_size):
                batch = sentences[start : start + self.embed_batch_size]
                # Input IDs restart at 0 for every batch; presumably they only
                # need to be unique within a single request.
                input_batch = [
                    Inputs.get_text_input(input_id=str(idx), raw_text=text)
                    for idx, text in enumerate(batch)
                ]
                predict_response = self._model.predict(input_batch)
                embeddings.extend(
                    list(output.data.embeddings[0].vector)
                    for output in predict_response.outputs
                )
        except Exception as e:
            logger.error("Predict failed, exception: %s", e)
            raise

        return embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding of a single query string."""
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async query embedding; delegates to the blocking sync call."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async text embedding; delegates to the blocking sync call."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding of a single text string."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per text, batched via ``_embed``."""
        return self._embed(texts)