weaviate
283 строки · 7.9 Кб
//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
11
package clients

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/sirupsen/logrus/hooks/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/weaviate/weaviate/modules/text2vec-huggingface/ent"
)
30
31func TestClient(t *testing.T) {32t.Run("when all is fine", func(t *testing.T) {33server := httptest.NewServer(&fakeHandler{t: t})34defer server.Close()35c := &vectorizer{36apiKey: "apiKey",37httpClient: &http.Client{},38logger: nullLogger(),39}40expected := &ent.VectorizationResult{41Text: "This is my text",42Vector: []float32{0.1, 0.2, 0.3},43Dimensions: 3,44}45res, err := c.Vectorize(context.Background(), "This is my text",46ent.VectorizationConfig{47Model: "sentence-transformers/gtr-t5-xxl",48WaitForModel: false,49UseGPU: false,50UseCache: true,51EndpointURL: server.URL,52})53
54assert.Nil(t, err)55assert.Equal(t, expected, res)56})57
58t.Run("when the context is expired", func(t *testing.T) {59server := httptest.NewServer(&fakeHandler{t: t})60defer server.Close()61c := &vectorizer{62apiKey: "apiKey",63httpClient: &http.Client{},64logger: nullLogger(),65}66ctx, cancel := context.WithDeadline(context.Background(), time.Now())67defer cancel()68
69_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{70EndpointURL: server.URL,71})72
73require.NotNil(t, err)74assert.Contains(t, err.Error(), "context deadline exceeded")75})76
77t.Run("when the server returns an error", func(t *testing.T) {78server := httptest.NewServer(&fakeHandler{79t: t,80serverError: errors.Errorf("nope, not gonna happen"),81})82defer server.Close()83c := &vectorizer{84apiKey: "apiKey",85httpClient: &http.Client{},86logger: nullLogger(),87}88_, err := c.Vectorize(context.Background(), "This is my text",89ent.VectorizationConfig{90EndpointURL: server.URL,91})92
93require.NotNil(t, err)94assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: nope, not gonna happen estimated time: 20")95})96
97t.Run("when HuggingFace key is passed using X-Huggingface-Api-Key header", func(t *testing.T) {98server := httptest.NewServer(&fakeHandler{t: t})99defer server.Close()100c := &vectorizer{101apiKey: "",102httpClient: &http.Client{},103logger: nullLogger(),104}105ctxWithValue := context.WithValue(context.Background(),106"X-Huggingface-Api-Key", []string{"some-key"})107
108expected := &ent.VectorizationResult{109Text: "This is my text",110Vector: []float32{0.1, 0.2, 0.3},111Dimensions: 3,112}113res, err := c.Vectorize(ctxWithValue, "This is my text",114ent.VectorizationConfig{115Model: "sentence-transformers/gtr-t5-xxl",116WaitForModel: true,117UseGPU: false,118UseCache: true,119EndpointURL: server.URL,120})121
122require.Nil(t, err)123assert.Equal(t, expected, res)124})125
126t.Run("when a request requires an API KEY", func(t *testing.T) {127server := httptest.NewServer(&fakeHandler{128t: t,129serverError: errors.Errorf("A valid user or organization token is required"),130})131defer server.Close()132c := &vectorizer{133apiKey: "",134httpClient: &http.Client{},135logger: nullLogger(),136}137ctxWithValue := context.WithValue(context.Background(),138"X-Huggingface-Api-Key", []string{""})139
140_, err := c.Vectorize(ctxWithValue, "This is my text",141ent.VectorizationConfig{142Model: "sentence-transformers/gtr-t5-xxl",143EndpointURL: server.URL,144})145
146require.NotNil(t, err)147assert.Equal(t, err.Error(), "failed with status: 401 error: A valid user or organization token is required")148})149
150t.Run("when the server returns an error with warnings", func(t *testing.T) {151server := httptest.NewServer(&fakeHandler{152t: t,153serverError: errors.Errorf("with warnings"),154})155defer server.Close()156c := &vectorizer{157apiKey: "apiKey",158httpClient: &http.Client{},159logger: nullLogger(),160}161_, err := c.Vectorize(context.Background(), "This is my text",162ent.VectorizationConfig{163EndpointURL: server.URL,164})165
166require.NotNil(t, err)167assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: with warnings "+168"warnings: [There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n"+169"CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n"+170"For debugging consider passing CUDA_LAUNCH_BLOCKING=1.]")171})172}
173
// fakeHandler stands in for the HuggingFace inference API in these tests.
// Its ServeHTTP method either answers with a canned embedding or, when
// serverError is set, with a canned error payload.
type fakeHandler struct {
	// t is the test that assertion failures raised while serving are reported to.
	t *testing.T

	// serverError, when non-nil, selects an error response by its message
	// text instead of returning an embedding.
	serverError error
}
178
179func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {180assert.Equal(f.t, http.MethodPost, r.Method)181
182if f.serverError != nil {183switch f.serverError.Error() {184case "with warnings":185embeddingError := map[string]interface{}{186"error": f.serverError.Error(),187"warnings": []string{188"There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n" +189"CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n" +190"For debugging consider passing CUDA_LAUNCH_BLOCKING=1.",191},192}193outBytes, err := json.Marshal(embeddingError)194require.Nil(f.t, err)195
196w.WriteHeader(http.StatusInternalServerError)197w.Write(outBytes)198return199case "A valid user or organization token is required":200embeddingError := map[string]interface{}{201"error": "A valid user or organization token is required",202}203outBytes, err := json.Marshal(embeddingError)204require.Nil(f.t, err)205
206w.WriteHeader(http.StatusUnauthorized)207w.Write(outBytes)208return209default:210embeddingError := map[string]interface{}{211"error": f.serverError.Error(),212"estimated_time": 20.0,213}214outBytes, err := json.Marshal(embeddingError)215require.Nil(f.t, err)216
217w.WriteHeader(http.StatusInternalServerError)218w.Write(outBytes)219return220}221}222
223bodyBytes, err := io.ReadAll(r.Body)224require.Nil(f.t, err)225defer r.Body.Close()226
227var b map[string]interface{}228require.Nil(f.t, json.Unmarshal(bodyBytes, &b))229
230textInputs := b["inputs"].([]interface{})231assert.Greater(f.t, len(textInputs), 0)232textInput := textInputs[0].(string)233assert.Greater(f.t, len(textInput), 0)234
235// TODO: fix this236embedding := [][]float32{{0.1, 0.2, 0.3}}237outBytes, err := json.Marshal(embedding)238require.Nil(f.t, err)239
240w.Write(outBytes)241}
242
243func nullLogger() logrus.FieldLogger {244l, _ := test.NewNullLogger()245return l246}
247
248func Test_getURL(t *testing.T) {249v := &vectorizer{}250
251tests := []struct {252name string253config ent.VectorizationConfig254want string255}{256{257name: "Facebook DPR model",258config: ent.VectorizationConfig{259Model: "sentence-transformers/facebook-dpr-ctx_encoder-multiset-base",260},261want: "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/facebook-dpr-ctx_encoder-multiset-base",262},263{264name: "BERT base model (uncased)",265config: ent.VectorizationConfig{266Model: "bert-base-uncased",267},268want: "https://api-inference.huggingface.co/pipeline/feature-extraction/bert-base-uncased",269},270{271name: "BERT base model (uncased)",272config: ent.VectorizationConfig{273EndpointURL: "https://self-hosted-instance.com/bert-base-uncased",274},275want: "https://self-hosted-instance.com/bert-base-uncased",276},277}278for _, tt := range tests {279t.Run(tt.name, func(t *testing.T) {280assert.Equal(t, tt.want, v.getURL(tt.config))281})282}283}
284