weaviate

Форк
0
114 строк · 3.0 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package clients
13

14
import (
15
	"context"
16
	"encoding/json"
17
	"fmt"
18
	"io"
19
	"net/http"
20
	"net/http/httptest"
21
	"testing"
22
	"time"
23

24
	"github.com/pkg/errors"
25
	"github.com/stretchr/testify/assert"
26
	"github.com/stretchr/testify/require"
27
	"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
28
)
29

30
func TestClient(t *testing.T) {
31
	t.Run("when all is fine", func(t *testing.T) {
32
		server := httptest.NewServer(&fakeHandler{t: t})
33
		defer server.Close()
34
		c := New(server.URL, server.URL, 0, nullLogger())
35
		expected := &ent.VectorizationResult{
36
			Text:       "This is my text",
37
			Vector:     []float32{0.1, 0.2, 0.3},
38
			Dimensions: 3,
39
		}
40
		res, err := c.VectorizeObject(context.Background(), "This is my text",
41
			ent.VectorizationConfig{
42
				PoolingStrategy: "masked_mean",
43
			})
44

45
		assert.Nil(t, err)
46
		assert.Equal(t, expected, res)
47
	})
48

49
	t.Run("when the context is expired", func(t *testing.T) {
50
		server := httptest.NewServer(&fakeHandler{t: t})
51
		defer server.Close()
52
		c := New(server.URL, server.URL, 0, nullLogger())
53
		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
54
		defer cancel()
55

56
		_, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{})
57

58
		require.NotNil(t, err)
59
		assert.Contains(t, err.Error(), "context deadline exceeded")
60
	})
61

62
	t.Run("when the server returns an error", func(t *testing.T) {
63
		server := httptest.NewServer(&fakeHandler{
64
			t:           t,
65
			serverError: errors.Errorf("nope, not gonna happen"),
66
		})
67
		defer server.Close()
68
		c := New(server.URL, server.URL, 0, nullLogger())
69
		_, err := c.VectorizeObject(context.Background(), "This is my text",
70
			ent.VectorizationConfig{})
71

72
		require.NotNil(t, err)
73
		assert.Contains(t, err.Error(), "nope, not gonna happen")
74
	})
75
}
76

77
type fakeHandler struct {
78
	t           *testing.T
79
	serverError error
80
}
81

82
func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
83
	assert.Equal(f.t, "/vectors", r.URL.String())
84
	assert.Equal(f.t, http.MethodPost, r.Method)
85

86
	if f.serverError != nil {
87
		w.WriteHeader(http.StatusInternalServerError)
88
		w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error())))
89
		return
90
	}
91

92
	bodyBytes, err := io.ReadAll(r.Body)
93
	require.Nil(f.t, err)
94
	defer r.Body.Close()
95

96
	var b map[string]interface{}
97
	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
98

99
	textInput := b["text"].(string)
100
	assert.Greater(f.t, len(textInput), 0)
101

102
	pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string)
103
	assert.Equal(f.t, "masked_mean", pooling)
104

105
	out := map[string]interface{}{
106
		"text":   textInput,
107
		"dims":   3,
108
		"vector": []float32{0.1, 0.2, 0.3},
109
	}
110
	outBytes, err := json.Marshal(out)
111
	require.Nil(f.t, err)
112

113
	w.Write(outBytes)
114
}
115

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.