weaviate

Форк
0
228 строк · 5.8 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package clients
13

14
import (
15
	"bytes"
16
	"context"
17
	"encoding/json"
18
	"fmt"
19
	"io"
20
	"net/http"
21
	"time"
22

23
	"github.com/weaviate/weaviate/usecases/modulecomponents"
24

25
	"github.com/pkg/errors"
26
	"github.com/sirupsen/logrus"
27
	"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
28
)
29

30
// DefaultOrigin is the base URL of the hosted HuggingFace Inference API. It is
// used to build the request URL whenever no custom endpoint URL is configured.
const DefaultOrigin = "https://api-inference.huggingface.co"

// DefaultPath is the Inference API route of the feature-extraction pipeline;
// the configured model name is appended to it to form the request URL.
const DefaultPath = "pipeline/feature-extraction"
34

35
// embeddingsRequest is the JSON body POSTed to the feature-extraction endpoint:
// the texts to embed plus optional inference options.
type embeddingsRequest struct {
	Inputs  []string `json:"inputs"`
	Options *options `json:"options,omitempty"`
}

// options carries the HuggingFace Inference API request options.
//
// use_cache is deliberately NOT omitempty. According to the Inference API
// documentation, a missing use_cache is treated as true. With omitempty, a
// configured UseCache=false would vanish from the request and caching could
// never be turned off. wait_for_model and use_gpu default to false on the API
// side, so omitting them when false is equivalent to sending false.
type options struct {
	WaitForModel bool `json:"wait_for_model,omitempty"`
	UseGPU       bool `json:"use_gpu,omitempty"`
	UseCache     bool `json:"use_cache"`
}
45

46
// Response payload shapes understood by decodeVector, plus the JSON error body
// returned on failed requests.
type (
	// embedding is a bare 2-D array response; decodeVector accepts it only when
	// it holds exactly one row, which it returns as the vector.
	embedding [][]float32

	// embeddingBert is a 4-D array response; decodeVector accepts it only when
	// the first two dimensions have length 1, and hands [0][0] to the BERT
	// embeddings decoder to reduce it to a single vector.
	embeddingBert [][][][]float32

	// embeddingObject is an object response that wraps the 2-D array under an
	// "embeddings" key.
	embeddingObject struct {
		Embeddings embedding `json:"embeddings"`
	}

	// huggingFaceApiError is the JSON body of an unsuccessful response.
	// checkResponse folds its fields into the returned error message.
	huggingFaceApiError struct {
		Error         string   `json:"error"`
		EstimatedTime *float32 `json:"estimated_time,omitempty"`
		Warnings      []string `json:"warnings"`
	}
)
59

60
type vectorizer struct {
61
	apiKey                string
62
	httpClient            *http.Client
63
	bertEmbeddingsDecoder *bertEmbeddingsDecoder
64
	logger                logrus.FieldLogger
65
}
66

67
func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
68
	return &vectorizer{
69
		apiKey: apiKey,
70
		httpClient: &http.Client{
71
			Timeout: timeout,
72
		},
73
		bertEmbeddingsDecoder: newBertEmbeddingsDecoder(),
74
		logger:                logger,
75
	}
76
}
77

78
func (v *vectorizer) Vectorize(ctx context.Context, input string,
79
	config ent.VectorizationConfig,
80
) (*ent.VectorizationResult, error) {
81
	return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
82
}
83

84
func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
85
	config ent.VectorizationConfig,
86
) (*ent.VectorizationResult, error) {
87
	return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
88
}
89

90
func (v *vectorizer) vectorize(ctx context.Context, url string,
91
	input string, options options,
92
) (*ent.VectorizationResult, error) {
93
	body, err := json.Marshal(embeddingsRequest{
94
		Inputs:  []string{input},
95
		Options: &options,
96
	})
97
	if err != nil {
98
		return nil, errors.Wrapf(err, "marshal body")
99
	}
100

101
	req, err := http.NewRequestWithContext(ctx, "POST", url,
102
		bytes.NewReader(body))
103
	if err != nil {
104
		return nil, errors.Wrap(err, "create POST request")
105
	}
106
	if apiKey := v.getApiKey(ctx); apiKey != "" {
107
		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
108
	}
109
	req.Header.Add("Content-Type", "application/json")
110

111
	res, err := v.httpClient.Do(req)
112
	if err != nil {
113
		return nil, errors.Wrap(err, "send POST request")
114
	}
115
	defer res.Body.Close()
116

117
	bodyBytes, err := io.ReadAll(res.Body)
118
	if err != nil {
119
		return nil, errors.Wrap(err, "read response body")
120
	}
121

122
	if err := checkResponse(res, bodyBytes); err != nil {
123
		return nil, err
124
	}
125

126
	vector, err := v.decodeVector(bodyBytes)
127
	if err != nil {
128
		return nil, errors.Wrap(err, "cannot decode vector")
129
	}
130

131
	return &ent.VectorizationResult{
132
		Text:       input,
133
		Dimensions: len(vector),
134
		Vector:     vector,
135
	}, nil
136
}
137

138
func checkResponse(res *http.Response, bodyBytes []byte) error {
139
	if res.StatusCode < 400 {
140
		return nil
141
	}
142

143
	var resBody huggingFaceApiError
144
	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
145
		return fmt.Errorf("unmarshal error response body: %v", string(bodyBytes))
146
	}
147

148
	message := fmt.Sprintf("failed with status: %d", res.StatusCode)
149
	if resBody.Error != "" {
150
		message = fmt.Sprintf("%s error: %v", message, resBody.Error)
151
		if resBody.EstimatedTime != nil {
152
			message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime)
153
		}
154
		if len(resBody.Warnings) > 0 {
155
			message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings)
156
		}
157
	}
158

159
	if res.StatusCode == http.StatusInternalServerError {
160
		message = fmt.Sprintf("connection to HuggingFace %v", message)
161
	}
162

163
	return errors.New(message)
164
}
165

166
// decodeVector extracts the single embedding vector from a successful
// response body. It tries three payload shapes in order, and the first one
// that parses decides the outcome:
//
//  1. a bare 2-D array (embedding) — must contain exactly one row;
//  2. an object with an "embeddings" 2-D array (embeddingObject) — must
//     contain exactly one row;
//  3. a 4-D array (embeddingBert) — its first two dimensions must both be 1,
//     and [0][0] is reduced to a vector by the BERT embeddings decoder.
//
// A shape that parses but has the wrong cardinality yields
// "unprocessable response body". If no shape parses, the error from the last
// (4-D) attempt is wrapped as "unmarshal response body".
func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) {
	var rows embedding
	if json.Unmarshal(bodyBytes, &rows) == nil {
		if len(rows) != 1 {
			return nil, errors.New("unprocessable response body")
		}
		return rows[0], nil
	}

	var wrapped embeddingObject
	if json.Unmarshal(bodyBytes, &wrapped) == nil {
		if len(wrapped.Embeddings) != 1 {
			return nil, errors.New("unprocessable response body")
		}
		return wrapped.Embeddings[0], nil
	}

	var nested embeddingBert
	if err := json.Unmarshal(bodyBytes, &nested); err != nil {
		return nil, errors.Wrap(err, "unmarshal response body")
	}
	if len(nested) != 1 || len(nested[0]) != 1 {
		return nil, errors.New("unprocessable response body")
	}
	return v.bertEmbeddingsDecoder.calculateVector(nested[0][0])
}
195

196
func (v *vectorizer) getApiKey(ctx context.Context) string {
197
	if len(v.apiKey) > 0 {
198
		return v.apiKey
199
	}
200
	key := "X-Huggingface-Api-Key"
201
	apiKey := ctx.Value(key)
202
	// try getting header from GRPC if not successful
203
	if apiKey == nil {
204
		apiKey = modulecomponents.GetValueFromGRPC(ctx, key)
205
	}
206

207
	if apiKeyHeader, ok := apiKey.([]string); ok &&
208
		len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 {
209
		return apiKeyHeader[0]
210
	}
211
	return ""
212
}
213

214
func (v *vectorizer) getOptions(config ent.VectorizationConfig) options {
215
	return options{
216
		WaitForModel: config.WaitForModel,
217
		UseGPU:       config.UseGPU,
218
		UseCache:     config.UseCache,
219
	}
220
}
221

222
// getURL returns config.EndpointURL when it is set. Otherwise it returns the
// hosted Inference API URL for config.Model:
// DefaultOrigin + "/" + DefaultPath + "/" + model.
func (v *vectorizer) getURL(config ent.VectorizationConfig) string {
	if custom := config.EndpointURL; custom != "" {
		return custom
	}
	return DefaultOrigin + "/" + DefaultPath + "/" + config.Model
}
229

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.