weaviate
228 строк · 5.8 Кб
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package clients
13
14import (
15"bytes"
16"context"
17"encoding/json"
18"fmt"
19"io"
20"net/http"
21"time"
22
23"github.com/weaviate/weaviate/usecases/modulecomponents"
24
25"github.com/pkg/errors"
26"github.com/sirupsen/logrus"
27"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
28)
29
// Location of the hosted HuggingFace inference API. getURL joins both values
// with the configured model name whenever no custom endpoint URL is set.
const (
	// DefaultOrigin is the scheme and host of the hosted inference API.
	DefaultOrigin = "https://api-inference.huggingface.co"

	// DefaultPath is the path segment placed between DefaultOrigin and the
	// model name.
	DefaultPath = "pipeline/feature-extraction"
)
34
35type embeddingsRequest struct {
36Inputs []string `json:"inputs"`
37Options *options `json:"options,omitempty"`
38}
39
// options holds the per-request flags sent in the "options" object of an
// embeddingsRequest.
type options struct {
	// WaitForModel is serialized as "wait_for_model"; a false value is
	// omitted from the JSON document.
	WaitForModel bool `json:"wait_for_model,omitempty"`
	// UseGPU is serialized as "use_gpu"; a false value is omitted from the
	// JSON document.
	UseGPU bool `json:"use_gpu,omitempty"`
	// UseCache is serialized as "use_cache" and is always written, including
	// when false. With omitempty a false value would never reach the server,
	// so a caller could not switch caching off when the server-side default
	// is to cache.
	UseCache bool `json:"use_cache"`
}
45
// Response shapes that decodeVector tries, in this order: embedding,
// embeddingObject, embeddingBert.
type (
	// embedding is a JSON array of float vectors. decodeVector accepts it
	// only when it holds exactly one vector.
	embedding [][]float32

	// embeddingBert is a four-level nested JSON array of floats. decodeVector
	// requires the two outer levels to have length one and hands the
	// remaining two-level matrix to the bert embeddings decoder.
	embeddingBert [][][][]float32

	// embeddingObject is a JSON object that wraps an embedding under the
	// "embeddings" key.
	embeddingObject struct {
		Embeddings embedding `json:"embeddings"`
	}
)
53
// huggingFaceApiError is the JSON error document that checkResponse decodes
// from responses with a status code of 400 or above.
type huggingFaceApiError struct {
	// Error is the error text; checkResponse adds the remaining fields to its
	// message only when this one is non-empty.
	Error string `json:"error"`
	// EstimatedTime is optional and stays nil when the key is absent.
	EstimatedTime *float32 `json:"estimated_time,omitempty"`
	// Warnings lists additional messages attached to the error.
	Warnings []string `json:"warnings"`
}
59
60type vectorizer struct {
61apiKey string
62httpClient *http.Client
63bertEmbeddingsDecoder *bertEmbeddingsDecoder
64logger logrus.FieldLogger
65}
66
67func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
68return &vectorizer{
69apiKey: apiKey,
70httpClient: &http.Client{
71Timeout: timeout,
72},
73bertEmbeddingsDecoder: newBertEmbeddingsDecoder(),
74logger: logger,
75}
76}
77
78func (v *vectorizer) Vectorize(ctx context.Context, input string,
79config ent.VectorizationConfig,
80) (*ent.VectorizationResult, error) {
81return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
82}
83
84func (v *vectorizer) VectorizeQuery(ctx context.Context, input string,
85config ent.VectorizationConfig,
86) (*ent.VectorizationResult, error) {
87return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config))
88}
89
90func (v *vectorizer) vectorize(ctx context.Context, url string,
91input string, options options,
92) (*ent.VectorizationResult, error) {
93body, err := json.Marshal(embeddingsRequest{
94Inputs: []string{input},
95Options: &options,
96})
97if err != nil {
98return nil, errors.Wrapf(err, "marshal body")
99}
100
101req, err := http.NewRequestWithContext(ctx, "POST", url,
102bytes.NewReader(body))
103if err != nil {
104return nil, errors.Wrap(err, "create POST request")
105}
106if apiKey := v.getApiKey(ctx); apiKey != "" {
107req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
108}
109req.Header.Add("Content-Type", "application/json")
110
111res, err := v.httpClient.Do(req)
112if err != nil {
113return nil, errors.Wrap(err, "send POST request")
114}
115defer res.Body.Close()
116
117bodyBytes, err := io.ReadAll(res.Body)
118if err != nil {
119return nil, errors.Wrap(err, "read response body")
120}
121
122if err := checkResponse(res, bodyBytes); err != nil {
123return nil, err
124}
125
126vector, err := v.decodeVector(bodyBytes)
127if err != nil {
128return nil, errors.Wrap(err, "cannot decode vector")
129}
130
131return &ent.VectorizationResult{
132Text: input,
133Dimensions: len(vector),
134Vector: vector,
135}, nil
136}
137
138func checkResponse(res *http.Response, bodyBytes []byte) error {
139if res.StatusCode < 400 {
140return nil
141}
142
143var resBody huggingFaceApiError
144if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
145return fmt.Errorf("unmarshal error response body: %v", string(bodyBytes))
146}
147
148message := fmt.Sprintf("failed with status: %d", res.StatusCode)
149if resBody.Error != "" {
150message = fmt.Sprintf("%s error: %v", message, resBody.Error)
151if resBody.EstimatedTime != nil {
152message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime)
153}
154if len(resBody.Warnings) > 0 {
155message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings)
156}
157}
158
159if res.StatusCode == http.StatusInternalServerError {
160message = fmt.Sprintf("connection to HuggingFace %v", message)
161}
162
163return errors.New(message)
164}
165
166func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) {
167var emb embedding
168if err := json.Unmarshal(bodyBytes, &emb); err != nil {
169var embObject embeddingObject
170if err := json.Unmarshal(bodyBytes, &embObject); err != nil {
171var embBert embeddingBert
172if err := json.Unmarshal(bodyBytes, &embBert); err != nil {
173return nil, errors.Wrap(err, "unmarshal response body")
174}
175
176if len(embBert) == 1 && len(embBert[0]) == 1 {
177return v.bertEmbeddingsDecoder.calculateVector(embBert[0][0])
178}
179
180return nil, errors.New("unprocessable response body")
181}
182if len(embObject.Embeddings) == 1 {
183return embObject.Embeddings[0], nil
184}
185
186return nil, errors.New("unprocessable response body")
187}
188
189if len(emb) == 1 {
190return emb[0], nil
191}
192
193return nil, errors.New("unprocessable response body")
194}
195
196func (v *vectorizer) getApiKey(ctx context.Context) string {
197if len(v.apiKey) > 0 {
198return v.apiKey
199}
200key := "X-Huggingface-Api-Key"
201apiKey := ctx.Value(key)
202// try getting header from GRPC if not successful
203if apiKey == nil {
204apiKey = modulecomponents.GetValueFromGRPC(ctx, key)
205}
206
207if apiKeyHeader, ok := apiKey.([]string); ok &&
208len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 {
209return apiKeyHeader[0]
210}
211return ""
212}
213
214func (v *vectorizer) getOptions(config ent.VectorizationConfig) options {
215return options{
216WaitForModel: config.WaitForModel,
217UseGPU: config.UseGPU,
218UseCache: config.UseCache,
219}
220}
221
222func (v *vectorizer) getURL(config ent.VectorizationConfig) string {
223if config.EndpointURL != "" {
224return config.EndpointURL
225}
226
227return fmt.Sprintf("%s/%s/%s", DefaultOrigin, DefaultPath, config.Model)
228}
229