//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

package clients

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

23func TestGetMeta(t *testing.T) {24t.Run("when common server is providing meta", func(t *testing.T) {25server := httptest.NewServer(&testMetaHandler{t: t})26defer server.Close()27v := New(server.URL, server.URL, 0, nullLogger())28meta, err := v.MetaInfo()29
30assert.Nil(t, err)31assert.NotNil(t, meta)32
33model := extractChildMap(t, meta, "model")34assert.NotNil(t, model["_name_or_path"])35assert.NotNil(t, model["architectures"])36assert.Contains(t, model["architectures"], "DistilBertModel")37ID2Label := extractChildMap(t, model, "id2label")38assert.NotNil(t, ID2Label["0"])39assert.NotNil(t, ID2Label["1"])40})41
42t.Run("when passage and query servers are providing meta", func(t *testing.T) {43serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"})44serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"})45defer serverPassage.Close()46defer serverQuery.Close()47v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())48meta, err := v.MetaInfo()49
50assert.Nil(t, err)51assert.NotNil(t, meta)52
53passage := extractChildMap(t, meta, "passage")54passageModel := extractChildMap(t, passage, "model")55assert.NotNil(t, passageModel["_name_or_path"])56assert.NotNil(t, passageModel["architectures"])57assert.Contains(t, passageModel["architectures"], "DPRContextEncoder")58passageID2Label := extractChildMap(t, passageModel, "id2label")59assert.NotNil(t, passageID2Label["0"])60assert.NotNil(t, passageID2Label["1"])61
62query := extractChildMap(t, meta, "query")63queryModel := extractChildMap(t, query, "model")64assert.NotNil(t, queryModel["_name_or_path"])65assert.NotNil(t, queryModel["architectures"])66assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder")67queryID2Label := extractChildMap(t, queryModel, "id2label")68assert.NotNil(t, queryID2Label["0"])69assert.NotNil(t, queryID2Label["1"])70})71
72t.Run("when passage and query servers are unavailable", func(t *testing.T) {73rt := time.Now().Add(time.Hour)74serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt})75serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt})76defer serverPassage.Close()77defer serverQuery.Close()78v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())79meta, err := v.MetaInfo()80
81assert.NotNil(t, err)82assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request")83assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request")84assert.Nil(t, meta)85})86}
87
// testMetaHandler is an http.Handler test double that serves canned meta
// payloads for the client tests.
type testMetaHandler struct {
	// t receives the request assertions made inside ServeHTTP.
	t *testing.T
	// the test handler will report as not ready (HTTP 503) before this time
	// has passed; the zero value means it is ready immediately.
	readyTime time.Time
	// modelType selects the canned payload returned by metaInfo: "passage",
	// "query", or any other value (including empty) for the default model.
	modelType string
}

95func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {96assert.Equal(h.t, "/meta", r.URL.String())97assert.Equal(h.t, http.MethodGet, r.Method)98
99if time.Since(h.readyTime) < 0 {100w.WriteHeader(http.StatusServiceUnavailable)101return102}103
104w.Write([]byte(h.metaInfo()))105}
106
107func (h *testMetaHandler) metaInfo() string {108switch h.modelType {109case "passage":110return `{111"model": {
112"return_dict": true,
113"output_hidden_states": false,
114"output_attentions": false,
115"torchscript": false,
116"torch_dtype": "float32",
117"use_bfloat16": false,
118"pruned_heads": {},
119"tie_word_embeddings": true,
120"is_encoder_decoder": false,
121"is_decoder": false,
122"cross_attention_hidden_size": null,
123"add_cross_attention": false,
124"tie_encoder_decoder": false,
125"max_length": 20,
126"min_length": 0,
127"do_sample": false,
128"early_stopping": false,
129"num_beams": 1,
130"num_beam_groups": 1,
131"diversity_penalty": 0,
132"temperature": 1,
133"top_k": 50,
134"top_p": 1,
135"repetition_penalty": 1,
136"length_penalty": 1,
137"no_repeat_ngram_size": 0,
138"encoder_no_repeat_ngram_size": 0,
139"bad_words_ids": null,
140"num_return_sequences": 1,
141"chunk_size_feed_forward": 0,
142"output_scores": false,
143"return_dict_in_generate": false,
144"forced_bos_token_id": null,
145"forced_eos_token_id": null,
146"remove_invalid_values": false,
147"architectures": [
148"DPRContextEncoder"
149],
150"finetuning_task": null,
151"id2label": {
152"0": "LABEL_0",
153"1": "LABEL_1"
154},
155"label2id": {
156"LABEL_0": 0,
157"LABEL_1": 1
158},
159"tokenizer_class": null,
160"prefix": null,
161"bos_token_id": null,
162"pad_token_id": 0,
163"eos_token_id": null,
164"sep_token_id": null,
165"decoder_start_token_id": null,
166"task_specific_params": null,
167"problem_type": null,
168"_name_or_path": "./models/model",
169"transformers_version": "4.16.2",
170"gradient_checkpointing": false,
171"model_type": "dpr",
172"vocab_size": 30522,
173"hidden_size": 768,
174"num_hidden_layers": 12,
175"num_attention_heads": 12,
176"hidden_act": "gelu",
177"intermediate_size": 3072,
178"hidden_dropout_prob": 0.1,
179"attention_probs_dropout_prob": 0.1,
180"max_position_embeddings": 512,
181"type_vocab_size": 2,
182"initializer_range": 0.02,
183"layer_norm_eps": 1e-12,
184"projection_dim": 0,
185"position_embedding_type": "absolute"
186}
187}`
188case "query":189return `{190"model": {
191"return_dict": true,
192"output_hidden_states": false,
193"output_attentions": false,
194"torchscript": false,
195"torch_dtype": "float32",
196"use_bfloat16": false,
197"pruned_heads": {},
198"tie_word_embeddings": true,
199"is_encoder_decoder": false,
200"is_decoder": false,
201"cross_attention_hidden_size": null,
202"add_cross_attention": false,
203"tie_encoder_decoder": false,
204"max_length": 20,
205"min_length": 0,
206"do_sample": false,
207"early_stopping": false,
208"num_beams": 1,
209"num_beam_groups": 1,
210"diversity_penalty": 0,
211"temperature": 1,
212"top_k": 50,
213"top_p": 1,
214"repetition_penalty": 1,
215"length_penalty": 1,
216"no_repeat_ngram_size": 0,
217"encoder_no_repeat_ngram_size": 0,
218"bad_words_ids": null,
219"num_return_sequences": 1,
220"chunk_size_feed_forward": 0,
221"output_scores": false,
222"return_dict_in_generate": false,
223"forced_bos_token_id": null,
224"forced_eos_token_id": null,
225"remove_invalid_values": false,
226"architectures": [
227"DPRQuestionEncoder"
228],
229"finetuning_task": null,
230"id2label": {
231"0": "LABEL_0",
232"1": "LABEL_1"
233},
234"label2id": {
235"LABEL_0": 0,
236"LABEL_1": 1
237},
238"tokenizer_class": null,
239"prefix": null,
240"bos_token_id": null,
241"pad_token_id": 0,
242"eos_token_id": null,
243"sep_token_id": null,
244"decoder_start_token_id": null,
245"task_specific_params": null,
246"problem_type": null,
247"_name_or_path": "./models/model",
248"transformers_version": "4.16.2",
249"gradient_checkpointing": false,
250"model_type": "dpr",
251"vocab_size": 30522,
252"hidden_size": 768,
253"num_hidden_layers": 12,
254"num_attention_heads": 12,
255"hidden_act": "gelu",
256"intermediate_size": 3072,
257"hidden_dropout_prob": 0.1,
258"attention_probs_dropout_prob": 0.1,
259"max_position_embeddings": 512,
260"type_vocab_size": 2,
261"initializer_range": 0.02,
262"layer_norm_eps": 1e-12,
263"projection_dim": 0,
264"position_embedding_type": "absolute"
265}
266}`
267default:268return `{269"model": {
270"_name_or_path": "distilbert-base-uncased",
271"activation": "gelu",
272"add_cross_attention": false,
273"architectures": [
274"DistilBertModel"
275],
276"attention_dropout": 0.1,
277"bad_words_ids": null,
278"bos_token_id": null,
279"chunk_size_feed_forward": 0,
280"decoder_start_token_id": null,
281"dim": 768,
282"diversity_penalty": 0,
283"do_sample": false,
284"dropout": 0.1,
285"early_stopping": false,
286"encoder_no_repeat_ngram_size": 0,
287"eos_token_id": null,
288"finetuning_task": null,
289"hidden_dim": 3072,
290"id2label": {
291"0": "LABEL_0",
292"1": "LABEL_1"
293},
294"initializer_range": 0.02,
295"is_decoder": false,
296"is_encoder_decoder": false,
297"label2id": {
298"LABEL_0": 0,
299"LABEL_1": 1
300},
301"length_penalty": 1,
302"max_length": 20,
303"max_position_embeddings": 512,
304"min_length": 0,
305"model_type": "distilbert",
306"n_heads": 12,
307"n_layers": 6,
308"no_repeat_ngram_size": 0,
309"num_beam_groups": 1,
310"num_beams": 1,
311"num_return_sequences": 1,
312"output_attentions": false,
313"output_hidden_states": false,
314"output_scores": false,
315"pad_token_id": 0,
316"prefix": null,
317"pruned_heads": {},
318"qa_dropout": 0.1,
319"repetition_penalty": 1,
320"return_dict": true,
321"return_dict_in_generate": false,
322"sep_token_id": null,
323"seq_classif_dropout": 0.2,
324"sinusoidal_pos_embds": false,
325"task_specific_params": null,
326"temperature": 1,
327"tie_encoder_decoder": false,
328"tie_weights_": true,
329"tie_word_embeddings": true,
330"tokenizer_class": null,
331"top_k": 50,
332"top_p": 1,
333"torchscript": false,
334"transformers_version": "4.3.2",
335"use_bfloat16": false,
336"vocab_size": 30522,
337"xla_device": null
338}
339}`
340}341}
342
343func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} {344assert.NotNil(t, parent[name])345child, ok := parent[name].(map[string]interface{})346assert.True(t, ok)347assert.NotNil(t, child)348
349return child350}
351