weaviate
114 строк · 3.0 Кб
// _ _
// __ _____ __ ___ ___ __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
// \ V V / __/ (_| |\ V /| | (_| | || __/
// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
// CONTACT: hello@weaviate.io
//
11
12package clients
13
14import (
15"context"
16"encoding/json"
17"fmt"
18"io"
19"net/http"
20"net/http/httptest"
21"testing"
22"time"
23
24"github.com/pkg/errors"
25"github.com/stretchr/testify/assert"
26"github.com/stretchr/testify/require"
27"github.com/weaviate/weaviate/modules/text2vec-transformers/ent"
28)
29
30func TestClient(t *testing.T) {
31t.Run("when all is fine", func(t *testing.T) {
32server := httptest.NewServer(&fakeHandler{t: t})
33defer server.Close()
34c := New(server.URL, server.URL, 0, nullLogger())
35expected := &ent.VectorizationResult{
36Text: "This is my text",
37Vector: []float32{0.1, 0.2, 0.3},
38Dimensions: 3,
39}
40res, err := c.VectorizeObject(context.Background(), "This is my text",
41ent.VectorizationConfig{
42PoolingStrategy: "masked_mean",
43})
44
45assert.Nil(t, err)
46assert.Equal(t, expected, res)
47})
48
49t.Run("when the context is expired", func(t *testing.T) {
50server := httptest.NewServer(&fakeHandler{t: t})
51defer server.Close()
52c := New(server.URL, server.URL, 0, nullLogger())
53ctx, cancel := context.WithDeadline(context.Background(), time.Now())
54defer cancel()
55
56_, err := c.VectorizeObject(ctx, "This is my text", ent.VectorizationConfig{})
57
58require.NotNil(t, err)
59assert.Contains(t, err.Error(), "context deadline exceeded")
60})
61
62t.Run("when the server returns an error", func(t *testing.T) {
63server := httptest.NewServer(&fakeHandler{
64t: t,
65serverError: errors.Errorf("nope, not gonna happen"),
66})
67defer server.Close()
68c := New(server.URL, server.URL, 0, nullLogger())
69_, err := c.VectorizeObject(context.Background(), "This is my text",
70ent.VectorizationConfig{})
71
72require.NotNil(t, err)
73assert.Contains(t, err.Error(), "nope, not gonna happen")
74})
75}
76
77type fakeHandler struct {
78t *testing.T
79serverError error
80}
81
82func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
83assert.Equal(f.t, "/vectors", r.URL.String())
84assert.Equal(f.t, http.MethodPost, r.Method)
85
86if f.serverError != nil {
87w.WriteHeader(http.StatusInternalServerError)
88w.Write([]byte(fmt.Sprintf(`{"error":"%s"}`, f.serverError.Error())))
89return
90}
91
92bodyBytes, err := io.ReadAll(r.Body)
93require.Nil(f.t, err)
94defer r.Body.Close()
95
96var b map[string]interface{}
97require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
98
99textInput := b["text"].(string)
100assert.Greater(f.t, len(textInput), 0)
101
102pooling := b["config"].(map[string]interface{})["pooling_strategy"].(string)
103assert.Equal(f.t, "masked_mean", pooling)
104
105out := map[string]interface{}{
106"text": textInput,
107"dims": 3,
108"vector": []float32{0.1, 0.2, 0.3},
109}
110outBytes, err := json.Marshal(out)
111require.Nil(f.t, err)
112
113w.Write(outBytes)
114}
115