weaviate
161 строка · 4.0 Кб
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package vectorizer
13
14import (
15"github.com/pkg/errors"
16
17"github.com/weaviate/weaviate/entities/models"
18"github.com/weaviate/weaviate/entities/moduletools"
19basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
20)
21
// DefaultHuggingFaceModel is returned by PassageModel and QueryModel when the
// class config sets neither "model" nor the matching "passageModel"/"queryModel".
const DefaultHuggingFaceModel = "sentence-transformers/msmarco-bert-base-dot-v5"

// Fallback values for the boolean entries "waitForModel", "useGPU" and
// "useCache" of the class config's "options" map, used when an entry is
// missing or is not a bool.
const (
	DefaultOptionWaitForModel = false
	DefaultOptionUseGPU       = false
	DefaultOptionUseCache     = true
)

// Defaults for the generic vectorization settings. They are not referenced in
// this file; presumably they are consumed elsewhere in the module — verify
// before changing them.
const (
	DefaultVectorizeClassName    = true
	DefaultPropertyIndexed       = true
	DefaultVectorizePropertyName = false
)
31
32type classSettings struct {
33basesettings.BaseClassSettings
34cfg moduletools.ClassConfig
35}
36
37func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
38return &classSettings{cfg: cfg, BaseClassSettings: *basesettings.NewBaseClassSettings(cfg)}
39}
40
41func (cs *classSettings) EndpointURL() string {
42return cs.getEndpointURL()
43}
44
45func (cs *classSettings) PassageModel() string {
46model := cs.getPassageModel()
47if model == "" {
48return DefaultHuggingFaceModel
49}
50return model
51}
52
53func (cs *classSettings) QueryModel() string {
54model := cs.getQueryModel()
55if model == "" {
56return DefaultHuggingFaceModel
57}
58return model
59}
60
61func (cs *classSettings) OptionWaitForModel() bool {
62return cs.getOptionOrDefault("waitForModel", DefaultOptionWaitForModel)
63}
64
65func (cs *classSettings) OptionUseGPU() bool {
66return cs.getOptionOrDefault("useGPU", DefaultOptionUseGPU)
67}
68
69func (cs *classSettings) OptionUseCache() bool {
70return cs.getOptionOrDefault("useCache", DefaultOptionUseCache)
71}
72
73func (cs *classSettings) Validate(class *models.Class) error {
74return cs.BaseClassSettings.Validate(class)
75}
76
77func (cs *classSettings) validateClassSettings() error {
78if err := cs.BaseClassSettings.ValidateClassSettings(); err != nil {
79return err
80}
81
82endpointURL := cs.getEndpointURL()
83if endpointURL != "" {
84// endpoint is set, should be used for feature extraction
85// all other settings are not relevant
86return nil
87}
88
89model := cs.getProperty("model")
90passageModel := cs.getProperty("passageModel")
91queryModel := cs.getProperty("queryModel")
92
93if model != "" && (passageModel != "" || queryModel != "") {
94return errors.New("only one setting must be set either 'model' or 'passageModel' with 'queryModel'")
95}
96
97if model == "" {
98if passageModel != "" && queryModel == "" {
99return errors.New("'passageModel' is set, but 'queryModel' is empty")
100}
101if passageModel == "" && queryModel != "" {
102return errors.New("'queryModel' is set, but 'passageModel' is empty")
103}
104}
105return nil
106}
107
108func (cs *classSettings) getPassageModel() string {
109model := cs.getProperty("model")
110if model == "" {
111model = cs.getProperty("passageModel")
112}
113return model
114}
115
116func (cs *classSettings) getQueryModel() string {
117model := cs.getProperty("model")
118if model == "" {
119model = cs.getProperty("queryModel")
120}
121return model
122}
123
124func (cs *classSettings) getEndpointURL() string {
125endpointURL := cs.getProperty("endpointUrl")
126if endpointURL == "" {
127endpointURL = cs.getProperty("endpointURL")
128}
129return endpointURL
130}
131
132func (cs *classSettings) getOption(option string) *bool {
133if cs.cfg != nil {
134options, ok := cs.cfg.Class()["options"]
135if ok {
136asMap, ok := options.(map[string]interface{})
137if ok {
138option, ok := asMap[option]
139if ok {
140asBool, ok := option.(bool)
141if ok {
142return &asBool
143}
144}
145}
146}
147}
148return nil
149}
150
151func (cs *classSettings) getOptionOrDefault(option string, defaultValue bool) bool {
152optionValue := cs.getOption(option)
153if optionValue != nil {
154return *optionValue
155}
156return defaultValue
157}
158
159func (cs *classSettings) getProperty(name string) string {
160return cs.BaseClassSettings.GetPropertyAsString(name, "")
161}
162