weaviate
122 строки · 3.6 Кб
//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//
11
12package vectorizer13
14import (15"errors"16"testing"17
18"github.com/stretchr/testify/assert"19"github.com/weaviate/weaviate/entities/moduletools"20)
21
22func Test_classSettings_getPassageModel(t *testing.T) {23tests := []struct {24name string25cfg moduletools.ClassConfig26wantPassageModel string27wantQueryModel string28wantWaitForModel bool29wantUseGPU bool30wantUseCache bool31wantEndpointURL string32wantError error33}{34{35name: "CShorten/CORD-19-Title-Abstracts",36cfg: fakeClassConfig{37classConfig: map[string]interface{}{38"model": "CShorten/CORD-19-Title-Abstracts",39"options": map[string]interface{}{40"waitForModel": true,41"useGPU": false,42"useCache": false,43},44},45},46wantPassageModel: "CShorten/CORD-19-Title-Abstracts",47wantQueryModel: "CShorten/CORD-19-Title-Abstracts",48wantWaitForModel: true,49wantUseGPU: false,50wantUseCache: false,51},52{53name: "sentence-transformers/all-MiniLM-L6-v2",54cfg: fakeClassConfig{55classConfig: map[string]interface{}{56"model": "sentence-transformers/all-MiniLM-L6-v2",57},58},59wantPassageModel: "sentence-transformers/all-MiniLM-L6-v2",60wantQueryModel: "sentence-transformers/all-MiniLM-L6-v2",61wantWaitForModel: false,62wantUseGPU: false,63wantUseCache: true,64},65{66name: "DPR models",67cfg: fakeClassConfig{68classConfig: map[string]interface{}{69"passageModel": "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base",70"queryModel": "sentence-transformers/facebook-dpr-question_encoder-single-nq-base",71},72},73wantPassageModel: "sentence-transformers/facebook-dpr-ctx_encoder-single-nq-base",74wantQueryModel: "sentence-transformers/facebook-dpr-question_encoder-single-nq-base",75wantWaitForModel: false,76wantUseGPU: false,77wantUseCache: true,78},79{80name: "Hugging Face Inference API - endpointURL",81cfg: fakeClassConfig{82classConfig: map[string]interface{}{83"endpointURL": "http://endpoint.cloud",84},85},86wantPassageModel: "",87wantQueryModel: "",88wantWaitForModel: false,89wantUseGPU: false,90wantUseCache: true,91wantEndpointURL: "http://endpoint.cloud",92},93{94name: "Hugging Face Inference API - wrong properties",95cfg: fakeClassConfig{96classConfig: 
map[string]interface{}{97"endpointUrl": "http://endpoint.cloud",98"properties": "wrong-properties",99},100},101wantPassageModel: "",102wantQueryModel: "",103wantWaitForModel: false,104wantUseGPU: false,105wantUseCache: true,106wantEndpointURL: "http://endpoint.cloud",107wantError: errors.New("properties field needs to be of array type, got: string"),108},109}110for _, tt := range tests {111t.Run(tt.name, func(t *testing.T) {112ic := NewClassSettings(tt.cfg)113assert.Equal(t, tt.wantPassageModel, ic.getPassageModel())114assert.Equal(t, tt.wantQueryModel, ic.getQueryModel())115assert.Equal(t, tt.wantWaitForModel, ic.OptionWaitForModel())116assert.Equal(t, tt.wantUseGPU, ic.OptionUseGPU())117assert.Equal(t, tt.wantUseCache, ic.OptionUseCache())118assert.Equal(t, tt.wantEndpointURL, ic.EndpointURL())119assert.Equal(t, tt.wantError, ic.validateClassSettings())120})121}122}
123