weaviate
341 строка · 10.1 Кб
//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
11
12package clients
13
14import (
15"bytes"
16"context"
17"crypto/sha256"
18"encoding/json"
19"fmt"
20"io"
21"net/http"
22"net/url"
23"strconv"
24"strings"
25"time"
26
27"github.com/weaviate/weaviate/entities/moduletools"
28
29"github.com/weaviate/weaviate/usecases/modulecomponents"
30
31"github.com/pkg/errors"
32"github.com/sirupsen/logrus"
33"github.com/weaviate/weaviate/modules/text2vec-openai/ent"
34)
35
// JSON wire types exchanged with the OpenAI / Azure OpenAI embeddings endpoint.
type (
	// embeddingsRequest is the body POSTed to the embeddings endpoint.
	// Model and Dimensions are left out of the JSON when empty / nil.
	embeddingsRequest struct {
		Input      []string `json:"input"`
		Model      string   `json:"model,omitempty"`
		Dimensions *int64   `json:"dimensions,omitempty"`
	}

	// embedding is the top-level response object. Error carries the API's
	// error object, if the response contains one.
	embedding struct {
		Object string          `json:"object"`
		Data   []embeddingData `json:"data,omitempty"`
		Error  *openAIApiError `json:"error,omitempty"`
	}

	// embeddingData is a single entry of the response's "data" array.
	embeddingData struct {
		Object    string          `json:"object"`
		Index     int             `json:"index"`
		Embedding []float32       `json:"embedding"`
		Error     *openAIApiError `json:"error,omitempty"`
	}

	// openAIApiError is the error object embedded in a response.
	openAIApiError struct {
		Message string     `json:"message"`
		Type    string     `json:"type"`
		Param   string     `json:"param"`
		Code    openAICode `json:"code"`
	}

	// openAICode is an error code that may be sent either as a JSON integer
	// or as a JSON string; both forms are kept as a string.
	openAICode string
)

// String returns the code as a plain string; a nil receiver yields "".
func (c *openAICode) String() string {
	if c != nil {
		return string(*c)
	}
	return ""
}

// UnmarshalJSON stores an integer literal in its decimal string form, or a
// JSON string verbatim. JSON null sets the code to "". Any other token returns
// the error produced by decoding it as a string.
func (c *openAICode) UnmarshalJSON(data []byte) error {
	if n, convErr := strconv.Atoi(string(data)); convErr == nil {
		*c = openAICode(strconv.Itoa(n))
		return nil
	}

	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	*c = openAICode(s)
	return nil
}
85
// buildUrl returns the embeddings endpoint URL.
//
// For Azure (isAzure == true) the result is
//
//	{host}/openai/deployments/{deploymentID}/embeddings?api-version={apiVersion}
//
// where host is baseURL with trailing slashes removed. If that is empty or is
// the public OpenAI host, https://{resourceName}.openai.azure.com is used
// instead. deploymentID is path-escaped and apiVersion is query-escaped.
//
// Otherwise the result is baseURL joined with /v1/embeddings. The error from
// url.JoinPath, which fails if baseURL cannot be parsed, is returned as-is.
func buildUrl(baseURL, resourceName, deploymentID, apiVersion string, isAzure bool) (string, error) {
	if isAzure {
		// Trim first so "https://host/" neither yields "https://host//openai/..."
		// nor slips past the api.openai.com fallback check below.
		host := strings.TrimRight(baseURL, "/")
		if host == "" || host == "https://api.openai.com" {
			// Fall back to old assumption
			host = "https://" + resourceName + ".openai.azure.com"
		}

		path := "openai/deployments/" + url.PathEscape(deploymentID) + "/embeddings"
		query := url.Values{"api-version": {apiVersion}}.Encode()
		return host + "/" + path + "?" + query, nil
	}

	return url.JoinPath(baseURL, "/v1/embeddings")
}
103
104type client struct {
105openAIApiKey string
106openAIOrganization string
107azureApiKey string
108httpClient *http.Client
109buildUrlFn func(baseURL, resourceName, deploymentID, apiVersion string, isAzure bool) (string, error)
110logger logrus.FieldLogger
111}
112
113func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *client {
114return &client{
115openAIApiKey: openAIApiKey,
116openAIOrganization: openAIOrganization,
117azureApiKey: azureApiKey,
118httpClient: &http.Client{
119Timeout: timeout,
120},
121buildUrlFn: buildUrl,
122logger: logger,
123}
124}
125
126func (v *client) Vectorize(ctx context.Context, input []string,
127cfg moduletools.ClassConfig,
128) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) {
129config := v.getVectorizationConfig(cfg)
130return v.vectorize(ctx, input, v.getModelString(config.Type, config.Model, "document", config.ModelVersion), config)
131}
132
133func (v *client) VectorizeQuery(ctx context.Context, input []string,
134cfg moduletools.ClassConfig,
135) (*modulecomponents.VectorizationResult, error) {
136config := v.getVectorizationConfig(cfg)
137res, _, err := v.vectorize(ctx, input, v.getModelString(config.Type, config.Model, "query", config.ModelVersion), config)
138return res, err
139}
140
141func (v *client) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) {
142body, err := json.Marshal(v.getEmbeddingsRequest(input, model, config.IsAzure, config.Dimensions))
143if err != nil {
144return nil, nil, errors.Wrap(err, "marshal body")
145}
146
147endpoint, err := v.buildURL(ctx, config)
148if err != nil {
149return nil, nil, errors.Wrap(err, "join OpenAI API host and path")
150}
151
152req, err := http.NewRequestWithContext(ctx, "POST", endpoint,
153bytes.NewReader(body))
154if err != nil {
155return nil, nil, errors.Wrap(err, "create POST request")
156}
157apiKey, err := v.getApiKey(ctx, config.IsAzure)
158if err != nil {
159return nil, nil, errors.Wrap(err, "API Key")
160}
161req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, config.IsAzure))
162if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
163req.Header.Add("OpenAI-Organization", openAIOrganization)
164}
165req.Header.Add("Content-Type", "application/json")
166
167res, err := v.httpClient.Do(req)
168if err != nil {
169return nil, nil, errors.Wrap(err, "send POST request")
170}
171defer res.Body.Close()
172
173bodyBytes, err := io.ReadAll(res.Body)
174if err != nil {
175return nil, nil, errors.Wrap(err, "read response body")
176}
177
178var resBody embedding
179if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
180return nil, nil, errors.Wrap(err, "unmarshal response body")
181}
182
183if res.StatusCode != 200 || resBody.Error != nil {
184return nil, nil, v.getError(res.StatusCode, resBody.Error, config.IsAzure)
185}
186rateLimit := ent.GetRateLimitsFromHeader(res.Header)
187
188texts := make([]string, len(resBody.Data))
189embeddings := make([][]float32, len(resBody.Data))
190openAIerror := make([]error, len(resBody.Data))
191for i := range resBody.Data {
192texts[i] = resBody.Data[i].Object
193embeddings[i] = resBody.Data[i].Embedding
194if resBody.Data[i].Error != nil {
195openAIerror[i] = v.getError(res.StatusCode, resBody.Data[i].Error, config.IsAzure)
196}
197}
198
199return &modulecomponents.VectorizationResult{
200Text: texts,
201Dimensions: len(resBody.Data[0].Embedding),
202Vector: embeddings,
203Errors: openAIerror,
204}, rateLimit, nil
205}
206
207func (v *client) buildURL(ctx context.Context, config ent.VectorizationConfig) (string, error) {
208baseURL, resourceName, deploymentID, apiVersion, isAzure := config.BaseURL, config.ResourceName, config.DeploymentID, config.ApiVersion, config.IsAzure
209if headerBaseURL := modulecomponents.GetValueFromContext(ctx, "X-Openai-Baseurl"); headerBaseURL != "" {
210baseURL = headerBaseURL
211}
212return v.buildUrlFn(baseURL, resourceName, deploymentID, apiVersion, isAzure)
213}
214
215func (v *client) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
216endpoint := "OpenAI API"
217if isAzure {
218endpoint = "Azure OpenAI API"
219}
220if resBodyError != nil {
221return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
222}
223return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
224}
225
226func (v *client) getEmbeddingsRequest(input []string, model string, isAzure bool, dimensions *int64) embeddingsRequest {
227if isAzure {
228return embeddingsRequest{Input: input}
229}
230return embeddingsRequest{Input: input, Model: model, Dimensions: dimensions}
231}
232
233func (v *client) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
234if isAzure {
235return "api-key", apiKey
236}
237return "Authorization", fmt.Sprintf("Bearer %s", apiKey)
238}
239
240func (v *client) getOpenAIOrganization(ctx context.Context) string {
241if value := modulecomponents.GetValueFromContext(ctx, "X-Openai-Organization"); value != "" {
242return value
243}
244return v.openAIOrganization
245}
246
247func (v *client) GetApiKeyHash(ctx context.Context, cfg moduletools.ClassConfig) [32]byte {
248config := v.getVectorizationConfig(cfg)
249
250key, err := v.getApiKey(ctx, config.IsAzure)
251if err != nil {
252return [32]byte{}
253}
254return sha256.Sum256([]byte(key))
255}
256
257func (v *client) GetVectorizerRateLimit(ctx context.Context) *modulecomponents.RateLimits {
258rpm, tpm := modulecomponents.GetRateLimitFromContext(ctx, "Openai", 0, 0)
259return &modulecomponents.RateLimits{
260RemainingTokens: tpm,
261LimitTokens: tpm,
262ResetTokens: time.Now().Add(61 * time.Second),
263RemainingRequests: rpm,
264LimitRequests: rpm,
265ResetRequests: time.Now().Add(61 * time.Second),
266}
267}
268
269func (v *client) getApiKey(ctx context.Context, isAzure bool) (string, error) {
270var apiKey, envVar string
271
272if isAzure {
273apiKey = "X-Azure-Api-Key"
274envVar = "AZURE_APIKEY"
275if len(v.azureApiKey) > 0 {
276return v.azureApiKey, nil
277}
278} else {
279apiKey = "X-Openai-Api-Key"
280envVar = "OPENAI_APIKEY"
281if len(v.openAIApiKey) > 0 {
282return v.openAIApiKey, nil
283}
284}
285
286return v.getApiKeyFromContext(ctx, apiKey, envVar)
287}
288
289func (v *client) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
290if apiKeyValue := modulecomponents.GetValueFromContext(ctx, apiKey); apiKeyValue != "" {
291return apiKeyValue, nil
292}
293return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
294}
295
296func (v *client) getModelString(docType, model, action, version string) string {
297if strings.HasPrefix(model, "text-embedding-3") {
298// indicates that we handle v3 models
299return model
300}
301if version == "002" {
302return v.getModel002String(model)
303}
304return v.getModel001String(docType, model, action)
305}
306
307func (v *client) getModel001String(docType, model, action string) string {
308modelBaseString := "%s-search-%s-%s-001"
309if action == "document" {
310if docType == "code" {
311return fmt.Sprintf(modelBaseString, docType, model, "code")
312}
313return fmt.Sprintf(modelBaseString, docType, model, "doc")
314
315} else {
316if docType == "code" {
317return fmt.Sprintf(modelBaseString, docType, model, "text")
318}
319return fmt.Sprintf(modelBaseString, docType, model, "query")
320}
321}
322
323func (v *client) getModel002String(model string) string {
324modelBaseString := "text-embedding-%s-002"
325return fmt.Sprintf(modelBaseString, model)
326}
327
328func (v *client) getVectorizationConfig(cfg moduletools.ClassConfig) ent.VectorizationConfig {
329settings := ent.NewClassSettings(cfg)
330return ent.VectorizationConfig{
331Type: settings.Type(),
332Model: settings.Model(),
333ModelVersion: settings.ModelVersion(),
334ResourceName: settings.ResourceName(),
335DeploymentID: settings.DeploymentID(),
336BaseURL: settings.BaseURL(),
337IsAzure: settings.IsAzure(),
338ApiVersion: settings.ApiVersion(),
339Dimensions: settings.Dimensions(),
340}
341}
342