weaviate

Форк
0
161 строка · 4.0 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package vectorizer
13

14
import (
15
	"github.com/pkg/errors"
16

17
	"github.com/weaviate/weaviate/entities/models"
18
	"github.com/weaviate/weaviate/entities/moduletools"
19
	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
20
)
21

22
const (
23
	DefaultHuggingFaceModel      = "sentence-transformers/msmarco-bert-base-dot-v5"
24
	DefaultOptionWaitForModel    = false
25
	DefaultOptionUseGPU          = false
26
	DefaultOptionUseCache        = true
27
	DefaultVectorizeClassName    = true
28
	DefaultPropertyIndexed       = true
29
	DefaultVectorizePropertyName = false
30
)
31

32
type classSettings struct {
33
	basesettings.BaseClassSettings
34
	cfg moduletools.ClassConfig
35
}
36

37
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
38
	return &classSettings{cfg: cfg, BaseClassSettings: *basesettings.NewBaseClassSettings(cfg)}
39
}
40

41
func (cs *classSettings) EndpointURL() string {
42
	return cs.getEndpointURL()
43
}
44

45
func (cs *classSettings) PassageModel() string {
46
	model := cs.getPassageModel()
47
	if model == "" {
48
		return DefaultHuggingFaceModel
49
	}
50
	return model
51
}
52

53
func (cs *classSettings) QueryModel() string {
54
	model := cs.getQueryModel()
55
	if model == "" {
56
		return DefaultHuggingFaceModel
57
	}
58
	return model
59
}
60

61
func (cs *classSettings) OptionWaitForModel() bool {
62
	return cs.getOptionOrDefault("waitForModel", DefaultOptionWaitForModel)
63
}
64

65
func (cs *classSettings) OptionUseGPU() bool {
66
	return cs.getOptionOrDefault("useGPU", DefaultOptionUseGPU)
67
}
68

69
func (cs *classSettings) OptionUseCache() bool {
70
	return cs.getOptionOrDefault("useCache", DefaultOptionUseCache)
71
}
72

73
func (cs *classSettings) Validate(class *models.Class) error {
74
	return cs.BaseClassSettings.Validate(class)
75
}
76

77
func (cs *classSettings) validateClassSettings() error {
78
	if err := cs.BaseClassSettings.ValidateClassSettings(); err != nil {
79
		return err
80
	}
81

82
	endpointURL := cs.getEndpointURL()
83
	if endpointURL != "" {
84
		// endpoint is set, should be used for feature extraction
85
		// all other settings are not relevant
86
		return nil
87
	}
88

89
	model := cs.getProperty("model")
90
	passageModel := cs.getProperty("passageModel")
91
	queryModel := cs.getProperty("queryModel")
92

93
	if model != "" && (passageModel != "" || queryModel != "") {
94
		return errors.New("only one setting must be set either 'model' or 'passageModel' with 'queryModel'")
95
	}
96

97
	if model == "" {
98
		if passageModel != "" && queryModel == "" {
99
			return errors.New("'passageModel' is set, but 'queryModel' is empty")
100
		}
101
		if passageModel == "" && queryModel != "" {
102
			return errors.New("'queryModel' is set, but 'passageModel' is empty")
103
		}
104
	}
105
	return nil
106
}
107

108
func (cs *classSettings) getPassageModel() string {
109
	model := cs.getProperty("model")
110
	if model == "" {
111
		model = cs.getProperty("passageModel")
112
	}
113
	return model
114
}
115

116
func (cs *classSettings) getQueryModel() string {
117
	model := cs.getProperty("model")
118
	if model == "" {
119
		model = cs.getProperty("queryModel")
120
	}
121
	return model
122
}
123

124
func (cs *classSettings) getEndpointURL() string {
125
	endpointURL := cs.getProperty("endpointUrl")
126
	if endpointURL == "" {
127
		endpointURL = cs.getProperty("endpointURL")
128
	}
129
	return endpointURL
130
}
131

132
func (cs *classSettings) getOption(option string) *bool {
133
	if cs.cfg != nil {
134
		options, ok := cs.cfg.Class()["options"]
135
		if ok {
136
			asMap, ok := options.(map[string]interface{})
137
			if ok {
138
				option, ok := asMap[option]
139
				if ok {
140
					asBool, ok := option.(bool)
141
					if ok {
142
						return &asBool
143
					}
144
				}
145
			}
146
		}
147
	}
148
	return nil
149
}
150

151
func (cs *classSettings) getOptionOrDefault(option string, defaultValue bool) bool {
152
	optionValue := cs.getOption(option)
153
	if optionValue != nil {
154
		return *optionValue
155
	}
156
	return defaultValue
157
}
158

159
func (cs *classSettings) getProperty(name string) string {
160
	return cs.BaseClassSettings.GetPropertyAsString(name, "")
161
}
162

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.