weaviate

Форк
0
283 строки · 7.9 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package clients
13

14
import (
15
	"context"
16
	"encoding/json"
17
	"io"
18
	"net/http"
19
	"net/http/httptest"
20
	"testing"
21
	"time"
22

23
	"github.com/pkg/errors"
24
	"github.com/sirupsen/logrus"
25
	"github.com/sirupsen/logrus/hooks/test"
26
	"github.com/stretchr/testify/assert"
27
	"github.com/stretchr/testify/require"
28
	"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
29
)
30

31
func TestClient(t *testing.T) {
32
	t.Run("when all is fine", func(t *testing.T) {
33
		server := httptest.NewServer(&fakeHandler{t: t})
34
		defer server.Close()
35
		c := &vectorizer{
36
			apiKey:     "apiKey",
37
			httpClient: &http.Client{},
38
			logger:     nullLogger(),
39
		}
40
		expected := &ent.VectorizationResult{
41
			Text:       "This is my text",
42
			Vector:     []float32{0.1, 0.2, 0.3},
43
			Dimensions: 3,
44
		}
45
		res, err := c.Vectorize(context.Background(), "This is my text",
46
			ent.VectorizationConfig{
47
				Model:        "sentence-transformers/gtr-t5-xxl",
48
				WaitForModel: false,
49
				UseGPU:       false,
50
				UseCache:     true,
51
				EndpointURL:  server.URL,
52
			})
53

54
		assert.Nil(t, err)
55
		assert.Equal(t, expected, res)
56
	})
57

58
	t.Run("when the context is expired", func(t *testing.T) {
59
		server := httptest.NewServer(&fakeHandler{t: t})
60
		defer server.Close()
61
		c := &vectorizer{
62
			apiKey:     "apiKey",
63
			httpClient: &http.Client{},
64
			logger:     nullLogger(),
65
		}
66
		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
67
		defer cancel()
68

69
		_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{
70
			EndpointURL: server.URL,
71
		})
72

73
		require.NotNil(t, err)
74
		assert.Contains(t, err.Error(), "context deadline exceeded")
75
	})
76

77
	t.Run("when the server returns an error", func(t *testing.T) {
78
		server := httptest.NewServer(&fakeHandler{
79
			t:           t,
80
			serverError: errors.Errorf("nope, not gonna happen"),
81
		})
82
		defer server.Close()
83
		c := &vectorizer{
84
			apiKey:     "apiKey",
85
			httpClient: &http.Client{},
86
			logger:     nullLogger(),
87
		}
88
		_, err := c.Vectorize(context.Background(), "This is my text",
89
			ent.VectorizationConfig{
90
				EndpointURL: server.URL,
91
			})
92

93
		require.NotNil(t, err)
94
		assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: nope, not gonna happen estimated time: 20")
95
	})
96

97
	t.Run("when HuggingFace key is passed using X-Huggingface-Api-Key header", func(t *testing.T) {
98
		server := httptest.NewServer(&fakeHandler{t: t})
99
		defer server.Close()
100
		c := &vectorizer{
101
			apiKey:     "",
102
			httpClient: &http.Client{},
103
			logger:     nullLogger(),
104
		}
105
		ctxWithValue := context.WithValue(context.Background(),
106
			"X-Huggingface-Api-Key", []string{"some-key"})
107

108
		expected := &ent.VectorizationResult{
109
			Text:       "This is my text",
110
			Vector:     []float32{0.1, 0.2, 0.3},
111
			Dimensions: 3,
112
		}
113
		res, err := c.Vectorize(ctxWithValue, "This is my text",
114
			ent.VectorizationConfig{
115
				Model:        "sentence-transformers/gtr-t5-xxl",
116
				WaitForModel: true,
117
				UseGPU:       false,
118
				UseCache:     true,
119
				EndpointURL:  server.URL,
120
			})
121

122
		require.Nil(t, err)
123
		assert.Equal(t, expected, res)
124
	})
125

126
	t.Run("when a request requires an API KEY", func(t *testing.T) {
127
		server := httptest.NewServer(&fakeHandler{
128
			t:           t,
129
			serverError: errors.Errorf("A valid user or organization token is required"),
130
		})
131
		defer server.Close()
132
		c := &vectorizer{
133
			apiKey:     "",
134
			httpClient: &http.Client{},
135
			logger:     nullLogger(),
136
		}
137
		ctxWithValue := context.WithValue(context.Background(),
138
			"X-Huggingface-Api-Key", []string{""})
139

140
		_, err := c.Vectorize(ctxWithValue, "This is my text",
141
			ent.VectorizationConfig{
142
				Model:       "sentence-transformers/gtr-t5-xxl",
143
				EndpointURL: server.URL,
144
			})
145

146
		require.NotNil(t, err)
147
		assert.Equal(t, err.Error(), "failed with status: 401 error: A valid user or organization token is required")
148
	})
149

150
	t.Run("when the server returns an error with warnings", func(t *testing.T) {
151
		server := httptest.NewServer(&fakeHandler{
152
			t:           t,
153
			serverError: errors.Errorf("with warnings"),
154
		})
155
		defer server.Close()
156
		c := &vectorizer{
157
			apiKey:     "apiKey",
158
			httpClient: &http.Client{},
159
			logger:     nullLogger(),
160
		}
161
		_, err := c.Vectorize(context.Background(), "This is my text",
162
			ent.VectorizationConfig{
163
				EndpointURL: server.URL,
164
			})
165

166
		require.NotNil(t, err)
167
		assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: with warnings "+
168
			"warnings: [There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n"+
169
			"CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n"+
170
			"For debugging consider passing CUDA_LAUNCH_BLOCKING=1.]")
171
	})
172
}
173

174
// fakeHandler is an http.Handler that stands in for the HuggingFace
// inference endpoint in TestClient.
type fakeHandler struct {
	// t receives assertion failures raised while serving requests.
	t *testing.T
	// serverError, when non-nil, makes the handler reply with an error
	// payload; its message selects which error response is sent.
	serverError error
}
178

179
func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
180
	assert.Equal(f.t, http.MethodPost, r.Method)
181

182
	if f.serverError != nil {
183
		switch f.serverError.Error() {
184
		case "with warnings":
185
			embeddingError := map[string]interface{}{
186
				"error": f.serverError.Error(),
187
				"warnings": []string{
188
					"There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n" +
189
						"CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n" +
190
						"For debugging consider passing CUDA_LAUNCH_BLOCKING=1.",
191
				},
192
			}
193
			outBytes, err := json.Marshal(embeddingError)
194
			require.Nil(f.t, err)
195

196
			w.WriteHeader(http.StatusInternalServerError)
197
			w.Write(outBytes)
198
			return
199
		case "A valid user or organization token is required":
200
			embeddingError := map[string]interface{}{
201
				"error": "A valid user or organization token is required",
202
			}
203
			outBytes, err := json.Marshal(embeddingError)
204
			require.Nil(f.t, err)
205

206
			w.WriteHeader(http.StatusUnauthorized)
207
			w.Write(outBytes)
208
			return
209
		default:
210
			embeddingError := map[string]interface{}{
211
				"error":          f.serverError.Error(),
212
				"estimated_time": 20.0,
213
			}
214
			outBytes, err := json.Marshal(embeddingError)
215
			require.Nil(f.t, err)
216

217
			w.WriteHeader(http.StatusInternalServerError)
218
			w.Write(outBytes)
219
			return
220
		}
221
	}
222

223
	bodyBytes, err := io.ReadAll(r.Body)
224
	require.Nil(f.t, err)
225
	defer r.Body.Close()
226

227
	var b map[string]interface{}
228
	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
229

230
	textInputs := b["inputs"].([]interface{})
231
	assert.Greater(f.t, len(textInputs), 0)
232
	textInput := textInputs[0].(string)
233
	assert.Greater(f.t, len(textInput), 0)
234

235
	// TODO: fix this
236
	embedding := [][]float32{{0.1, 0.2, 0.3}}
237
	outBytes, err := json.Marshal(embedding)
238
	require.Nil(f.t, err)
239

240
	w.Write(outBytes)
241
}
242

243
// nullLogger returns a logrus logger that swallows all output, so the
// vectorizer under test can log without cluttering test output.
func nullLogger() logrus.FieldLogger {
	// NewNullLogger also returns a hook that records entries; no test here
	// inspects logged entries, so the hook is dropped.
	logger, _ := test.NewNullLogger()
	var fieldLogger logrus.FieldLogger = logger
	return fieldLogger
}
247

248
func Test_getURL(t *testing.T) {
249
	v := &vectorizer{}
250

251
	tests := []struct {
252
		name   string
253
		config ent.VectorizationConfig
254
		want   string
255
	}{
256
		{
257
			name: "Facebook DPR model",
258
			config: ent.VectorizationConfig{
259
				Model: "sentence-transformers/facebook-dpr-ctx_encoder-multiset-base",
260
			},
261
			want: "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/facebook-dpr-ctx_encoder-multiset-base",
262
		},
263
		{
264
			name: "BERT base model (uncased)",
265
			config: ent.VectorizationConfig{
266
				Model: "bert-base-uncased",
267
			},
268
			want: "https://api-inference.huggingface.co/pipeline/feature-extraction/bert-base-uncased",
269
		},
270
		{
271
			name: "BERT base model (uncased)",
272
			config: ent.VectorizationConfig{
273
				EndpointURL: "https://self-hosted-instance.com/bert-base-uncased",
274
			},
275
			want: "https://self-hosted-instance.com/bert-base-uncased",
276
		},
277
	}
278
	for _, tt := range tests {
279
		t.Run(tt.name, func(t *testing.T) {
280
			assert.Equal(t, tt.want, v.getURL(tt.config))
281
		})
282
	}
283
}
284

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.