weaviate

Форк
0
341 строка · 10.1 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package clients
13

14
import (
15
	"bytes"
16
	"context"
17
	"crypto/sha256"
18
	"encoding/json"
19
	"fmt"
20
	"io"
21
	"net/http"
22
	"net/url"
23
	"strconv"
24
	"strings"
25
	"time"
26

27
	"github.com/weaviate/weaviate/entities/moduletools"
28

29
	"github.com/weaviate/weaviate/usecases/modulecomponents"
30

31
	"github.com/pkg/errors"
32
	"github.com/sirupsen/logrus"
33
	"github.com/weaviate/weaviate/modules/text2vec-openai/ent"
34
)
35

36
// embeddingsRequest is the JSON body posted to the embeddings endpoint.
type embeddingsRequest struct {
	// Input holds the texts that should be embedded.
	Input []string `json:"input"`
	// Model is dropped from the payload when it is empty.
	Model string `json:"model,omitempty"`
	// Dimensions is dropped from the payload when it is nil.
	Dimensions *int64 `json:"dimensions,omitempty"`
}
41

42
type embedding struct {
43
	Object string          `json:"object"`
44
	Data   []embeddingData `json:"data,omitempty"`
45
	Error  *openAIApiError `json:"error,omitempty"`
46
}
47

48
type embeddingData struct {
49
	Object    string          `json:"object"`
50
	Index     int             `json:"index"`
51
	Embedding []float32       `json:"embedding"`
52
	Error     *openAIApiError `json:"error,omitempty"`
53
}
54

55
type openAIApiError struct {
56
	Message string     `json:"message"`
57
	Type    string     `json:"type"`
58
	Param   string     `json:"param"`
59
	Code    openAICode `json:"code"`
60
}
61

62
// openAICode is an API error code kept in string form regardless of whether
// the payload encoded it as a JSON string or as a JSON integer.
type openAICode string

// String returns the code as a plain string; a nil receiver yields "".
func (c *openAICode) String() string {
	if c != nil {
		return string(*c)
	}
	return ""
}

// UnmarshalJSON decodes data into the receiver. Input that parses as a Go
// int is stored in its canonical decimal form; anything else must be a JSON
// string, otherwise the decoding error is returned.
func (c *openAICode) UnmarshalJSON(data []byte) (err error) {
	// Integer form first, e.g. 429.
	number, convErr := strconv.Atoi(string(data))
	if convErr == nil {
		*c = openAICode(strconv.Itoa(number))
		return nil
	}

	// Fall back to the quoted string form, e.g. "invalid_api_key".
	var text string
	if err = json.Unmarshal(data, &text); err != nil {
		return err
	}
	*c = openAICode(text)
	return nil
}
85

86
// buildUrl returns the embeddings endpoint for the given settings.
//
// For Azure the result has the form
//
//	<host>/openai/deployments/<deploymentID>/embeddings?api-version=<apiVersion>
//
// where host is baseURL, or "https://<resourceName>.openai.azure.com" when
// baseURL is empty or still points at the public OpenAI host. Trailing
// slashes on the host are dropped and deploymentID / apiVersion are escaped,
// so a base URL ending in "/" or a value containing reserved characters
// cannot corrupt the path or the query string.
//
// For every other case the result is baseURL joined with "/v1/embeddings";
// the error of url.JoinPath is returned unchanged.
func buildUrl(baseURL, resourceName, deploymentID, apiVersion string, isAzure bool) (string, error) {
	if !isAzure {
		return url.JoinPath(baseURL, "/v1/embeddings")
	}

	host := baseURL
	if host == "" || host == "https://api.openai.com" {
		// Fall back to old assumption
		host = "https://" + resourceName + ".openai.azure.com"
	}
	// Avoid "//openai/..." when the configured base URL ends with a slash.
	host = strings.TrimRight(host, "/")

	path := "openai/deployments/" + url.PathEscape(deploymentID) + "/embeddings"
	queryParam := "api-version=" + url.QueryEscape(apiVersion)
	return fmt.Sprintf("%s/%s?%s", host, path, queryParam), nil
}
103

104
type client struct {
105
	openAIApiKey       string
106
	openAIOrganization string
107
	azureApiKey        string
108
	httpClient         *http.Client
109
	buildUrlFn         func(baseURL, resourceName, deploymentID, apiVersion string, isAzure bool) (string, error)
110
	logger             logrus.FieldLogger
111
}
112

113
func New(openAIApiKey, openAIOrganization, azureApiKey string, timeout time.Duration, logger logrus.FieldLogger) *client {
114
	return &client{
115
		openAIApiKey:       openAIApiKey,
116
		openAIOrganization: openAIOrganization,
117
		azureApiKey:        azureApiKey,
118
		httpClient: &http.Client{
119
			Timeout: timeout,
120
		},
121
		buildUrlFn: buildUrl,
122
		logger:     logger,
123
	}
124
}
125

126
func (v *client) Vectorize(ctx context.Context, input []string,
127
	cfg moduletools.ClassConfig,
128
) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) {
129
	config := v.getVectorizationConfig(cfg)
130
	return v.vectorize(ctx, input, v.getModelString(config.Type, config.Model, "document", config.ModelVersion), config)
131
}
132

133
func (v *client) VectorizeQuery(ctx context.Context, input []string,
134
	cfg moduletools.ClassConfig,
135
) (*modulecomponents.VectorizationResult, error) {
136
	config := v.getVectorizationConfig(cfg)
137
	res, _, err := v.vectorize(ctx, input, v.getModelString(config.Type, config.Model, "query", config.ModelVersion), config)
138
	return res, err
139
}
140

141
func (v *client) vectorize(ctx context.Context, input []string, model string, config ent.VectorizationConfig) (*modulecomponents.VectorizationResult, *modulecomponents.RateLimits, error) {
142
	body, err := json.Marshal(v.getEmbeddingsRequest(input, model, config.IsAzure, config.Dimensions))
143
	if err != nil {
144
		return nil, nil, errors.Wrap(err, "marshal body")
145
	}
146

147
	endpoint, err := v.buildURL(ctx, config)
148
	if err != nil {
149
		return nil, nil, errors.Wrap(err, "join OpenAI API host and path")
150
	}
151

152
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint,
153
		bytes.NewReader(body))
154
	if err != nil {
155
		return nil, nil, errors.Wrap(err, "create POST request")
156
	}
157
	apiKey, err := v.getApiKey(ctx, config.IsAzure)
158
	if err != nil {
159
		return nil, nil, errors.Wrap(err, "API Key")
160
	}
161
	req.Header.Add(v.getApiKeyHeaderAndValue(apiKey, config.IsAzure))
162
	if openAIOrganization := v.getOpenAIOrganization(ctx); openAIOrganization != "" {
163
		req.Header.Add("OpenAI-Organization", openAIOrganization)
164
	}
165
	req.Header.Add("Content-Type", "application/json")
166

167
	res, err := v.httpClient.Do(req)
168
	if err != nil {
169
		return nil, nil, errors.Wrap(err, "send POST request")
170
	}
171
	defer res.Body.Close()
172

173
	bodyBytes, err := io.ReadAll(res.Body)
174
	if err != nil {
175
		return nil, nil, errors.Wrap(err, "read response body")
176
	}
177

178
	var resBody embedding
179
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
180
		return nil, nil, errors.Wrap(err, "unmarshal response body")
181
	}
182

183
	if res.StatusCode != 200 || resBody.Error != nil {
184
		return nil, nil, v.getError(res.StatusCode, resBody.Error, config.IsAzure)
185
	}
186
	rateLimit := ent.GetRateLimitsFromHeader(res.Header)
187

188
	texts := make([]string, len(resBody.Data))
189
	embeddings := make([][]float32, len(resBody.Data))
190
	openAIerror := make([]error, len(resBody.Data))
191
	for i := range resBody.Data {
192
		texts[i] = resBody.Data[i].Object
193
		embeddings[i] = resBody.Data[i].Embedding
194
		if resBody.Data[i].Error != nil {
195
			openAIerror[i] = v.getError(res.StatusCode, resBody.Data[i].Error, config.IsAzure)
196
		}
197
	}
198

199
	return &modulecomponents.VectorizationResult{
200
		Text:       texts,
201
		Dimensions: len(resBody.Data[0].Embedding),
202
		Vector:     embeddings,
203
		Errors:     openAIerror,
204
	}, rateLimit, nil
205
}
206

207
// buildURL resolves the endpoint URL for config through buildUrlFn. A
// non-empty "X-Openai-Baseurl" value found in ctx replaces the configured
// base URL.
func (v *client) buildURL(ctx context.Context, config ent.VectorizationConfig) (string, error) {
	baseURL := config.BaseURL
	if override := modulecomponents.GetValueFromContext(ctx, "X-Openai-Baseurl"); override != "" {
		baseURL = override
	}
	return v.buildUrlFn(baseURL, config.ResourceName, config.DeploymentID, config.ApiVersion, config.IsAzure)
}
214

215
// getError formats a failed API call as an error that names the endpoint
// (OpenAI or Azure OpenAI) and the HTTP status. The API error message is
// appended when resBodyError is not nil.
func (v *client) getError(statusCode int, resBodyError *openAIApiError, isAzure bool) error {
	endpoint := "OpenAI API"
	if isAzure {
		endpoint = "Azure OpenAI API"
	}

	if resBodyError == nil {
		return fmt.Errorf("connection to: %s failed with status: %d", endpoint, statusCode)
	}
	return fmt.Errorf("connection to: %s failed with status: %d error: %v", endpoint, statusCode, resBodyError.Message)
}
225

226
// getEmbeddingsRequest assembles the request payload. For Azure only the
// input is set; otherwise the model and dimensions are included as well.
func (v *client) getEmbeddingsRequest(input []string, model string, isAzure bool, dimensions *int64) embeddingsRequest {
	request := embeddingsRequest{Input: input}
	if !isAzure {
		request.Model = model
		request.Dimensions = dimensions
	}
	return request
}
232

233
// getApiKeyHeaderAndValue returns the name and value of the authentication
// header: "api-key" with the raw key for Azure, "Authorization" with a
// "Bearer " prefixed key otherwise.
func (v *client) getApiKeyHeaderAndValue(apiKey string, isAzure bool) (string, string) {
	if !isAzure {
		return "Authorization", fmt.Sprintf("Bearer %s", apiKey)
	}
	return "api-key", apiKey
}
239

240
// getOpenAIOrganization returns the "X-Openai-Organization" value found in
// ctx, or the organization the client was created with when that value is
// empty.
func (v *client) getOpenAIOrganization(ctx context.Context) string {
	fromContext := modulecomponents.GetValueFromContext(ctx, "X-Openai-Organization")
	if fromContext == "" {
		return v.openAIOrganization
	}
	return fromContext
}
246

247
// GetApiKeyHash returns the SHA-256 digest of the API key that getApiKey
// resolves for cfg. When no key can be resolved the all-zero digest is
// returned instead.
func (v *client) GetApiKeyHash(ctx context.Context, cfg moduletools.ClassConfig) [32]byte {
	isAzure := v.getVectorizationConfig(cfg).IsAzure

	apiKey, err := v.getApiKey(ctx, isAzure)
	if err != nil {
		var zero [32]byte
		return zero
	}
	return sha256.Sum256([]byte(apiKey))
}
256

257
func (v *client) GetVectorizerRateLimit(ctx context.Context) *modulecomponents.RateLimits {
258
	rpm, tpm := modulecomponents.GetRateLimitFromContext(ctx, "Openai", 0, 0)
259
	return &modulecomponents.RateLimits{
260
		RemainingTokens:   tpm,
261
		LimitTokens:       tpm,
262
		ResetTokens:       time.Now().Add(61 * time.Second),
263
		RemainingRequests: rpm,
264
		LimitRequests:     rpm,
265
		ResetRequests:     time.Now().Add(61 * time.Second),
266
	}
267
}
268

269
// getApiKey resolves the API key for the Azure or the OpenAI endpoint. A key
// stored on the client is returned as is; only when that key is empty is the
// matching request header looked up through getApiKeyFromContext.
//
// NOTE(review): the stored key wins over the request header here, whereas
// getOpenAIOrganization lets the header win — confirm this is intended.
func (v *client) getApiKey(ctx context.Context, isAzure bool) (string, error) {
	headerName, envVar, stored := "X-Openai-Api-Key", "OPENAI_APIKEY", v.openAIApiKey
	if isAzure {
		headerName, envVar, stored = "X-Azure-Api-Key", "AZURE_APIKEY", v.azureApiKey
	}

	if stored != "" {
		return stored, nil
	}
	return v.getApiKeyFromContext(ctx, headerName, envVar)
}
288

289
// getApiKeyFromContext returns the value stored in ctx under apiKey. When
// that value is empty an error naming both the header and envVar is returned;
// envVar is used for the message only.
func (v *client) getApiKeyFromContext(ctx context.Context, apiKey, envVar string) (string, error) {
	value := modulecomponents.GetValueFromContext(ctx, apiKey)
	if value == "" {
		return "", fmt.Errorf("no api key found neither in request header: %s nor in environment variable under %s", apiKey, envVar)
	}
	return value, nil
}
295

296
// getModelString picks the model identifier sent to the API: models whose
// name starts with "text-embedding-3" are used verbatim, version "002" is
// expanded by getModel002String and everything else by getModel001String.
func (v *client) getModelString(docType, model, action, version string) string {
	switch {
	case strings.HasPrefix(model, "text-embedding-3"):
		// indicates that we handle v3 models
		return model
	case version == "002":
		return v.getModel002String(model)
	default:
		return v.getModel001String(docType, model, action)
	}
}
306

307
// getModel001String formats a "001" model identifier as
// "<docType>-search-<model>-<suffix>-001". For the "document" action the
// suffix is "code" when docType is "code" and "doc" otherwise; for any other
// action it is "text" when docType is "code" and "query" otherwise.
func (v *client) getModel001String(docType, model, action string) string {
	isCode := docType == "code"

	var suffix string
	switch {
	case action == "document" && isCode:
		suffix = "code"
	case action == "document":
		suffix = "doc"
	case isCode:
		suffix = "text"
	default:
		suffix = "query"
	}
	return fmt.Sprintf("%s-search-%s-%s-001", docType, model, suffix)
}
322

323
// getModel002String formats a "002" model identifier as
// "text-embedding-<model>-002".
func (v *client) getModel002String(model string) string {
	return fmt.Sprintf("text-embedding-%s-002", model)
}
327

328
func (v *client) getVectorizationConfig(cfg moduletools.ClassConfig) ent.VectorizationConfig {
329
	settings := ent.NewClassSettings(cfg)
330
	return ent.VectorizationConfig{
331
		Type:         settings.Type(),
332
		Model:        settings.Model(),
333
		ModelVersion: settings.ModelVersion(),
334
		ResourceName: settings.ResourceName(),
335
		DeploymentID: settings.DeploymentID(),
336
		BaseURL:      settings.BaseURL(),
337
		IsAzure:      settings.IsAzure(),
338
		ApiVersion:   settings.ApiVersion(),
339
		Dimensions:   settings.Dimensions(),
340
	}
341
}
342

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.