weaviate

Форк
0
/
class_settings.go 
232 строки · 6.4 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package ent
13

14
import (
15
	"fmt"
16

17
	"github.com/pkg/errors"
18

19
	"github.com/weaviate/weaviate/entities/models"
20
	"github.com/weaviate/weaviate/entities/moduletools"
21
	basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
22
)
23

24
// Fallback values for the class-level settings of this module.
const (
	// DefaultOpenAIDocumentType is the fallback handed to the "type" setting lookup.
	DefaultOpenAIDocumentType = "text"
	// DefaultOpenAIModel is the fallback handed to the "model" setting lookup.
	DefaultOpenAIModel = "ada"
	// DefaultVectorizeClassName is not referenced in this file; presumably it is
	// the default for vectorizing the class name (confirm against its consumers).
	DefaultVectorizeClassName = true
	// DefaultPropertyIndexed is not referenced in this file; presumably it is the
	// default for whether a property is indexed (confirm against its consumers).
	DefaultPropertyIndexed = true
	// DefaultVectorizePropertyName is not referenced in this file; presumably it
	// is the default for vectorizing property names (confirm against its consumers).
	DefaultVectorizePropertyName = false
	// DefaultBaseURL is the fallback handed to the "baseURL" setting lookup.
	DefaultBaseURL = "https://api.openai.com"
	// DefaultApiVersion is the fallback handed to the "apiVersion" setting lookup.
	DefaultApiVersion = "2024-02-01"
)
33

34
// Names of the v3 embedding models accepted by the "model" setting.
const (
	// TextEmbedding3Small is the model name "text-embedding-3-small".
	TextEmbedding3Small = "text-embedding-3-small"
	// TextEmbedding3Large is the model name "text-embedding-3-large".
	TextEmbedding3Large = "text-embedding-3-large"
)
38

39
// Default "dimensions" values for the v3 embedding models. They are variables
// rather than constants because PickDefaultDimensions takes their address.
var (
	// TextEmbedding3SmallDefaultDimensions is the default for TextEmbedding3Small.
	TextEmbedding3SmallDefaultDimensions int64 = 1536
	// TextEmbedding3LargeDefaultDimensions is the default for TextEmbedding3Large.
	TextEmbedding3LargeDefaultDimensions int64 = 3072
)
43

44
// availableOpenAITypes lists the values Validate accepts for the "type" setting.
var availableOpenAITypes = []string{
	"text",
	"code",
}
45

46
var availableV3Models = []string{
47
	// new v3 models
48
	TextEmbedding3Small,
49
	TextEmbedding3Large,
50
}
51

52
var availableV3ModelsDimensions = map[string][]int64{
53
	TextEmbedding3Small: {512, TextEmbedding3SmallDefaultDimensions},
54
	TextEmbedding3Large: {256, 1024, TextEmbedding3LargeDefaultDimensions},
55
}
56

57
// availableOpenAIModels lists the pre-v3 model names accepted by the "model"
// setting. The order is kept stable because the list is printed in validation
// error messages.
var availableOpenAIModels = []string{
	// supports 001 and 002
	"ada",
	// only supports 001
	"babbage",
	// only supports 001
	"curie",
	// only supports 001
	"davinci",
}
63

64
// availableApiVersions lists the values accepted for the "apiVersion" setting.
// The order (which is not strictly chronological) is kept as-is because the
// list is printed verbatim in the validation error message.
var availableApiVersions = []string{
	// 2022 - 2023
	"2022-12-01",
	"2023-03-15-preview",
	"2023-05-15",
	"2023-06-01-preview",
	"2023-07-01-preview",
	"2023-08-01-preview",
	"2023-09-01-preview",
	"2023-12-01-preview",
	// 2024
	"2024-02-15-preview",
	"2024-03-01-preview",
	"2024-02-01",
}
77

78
type classSettings struct {
79
	basesettings.BaseClassSettings
80
	cfg moduletools.ClassConfig
81
}
82

83
// NewClassSettings builds a classSettings for cfg. The embedded base settings
// are a copy of the value returned by basesettings.NewBaseClassSettings(cfg).
func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
	base := basesettings.NewBaseClassSettings(cfg)
	return &classSettings{
		BaseClassSettings: *base,
		cfg:               cfg,
	}
}
86

87
// Model returns the "model" setting, looked up through the base settings with
// DefaultOpenAIModel passed as the default.
func (cs *classSettings) Model() string {
	const key = "model"
	return cs.BaseClassSettings.GetPropertyAsString(key, DefaultOpenAIModel)
}
90

91
// Type returns the "type" setting, looked up through the base settings with
// DefaultOpenAIDocumentType passed as the default.
func (cs *classSettings) Type() string {
	const key = "type"
	return cs.BaseClassSettings.GetPropertyAsString(key, DefaultOpenAIDocumentType)
}
94

95
// ModelVersion returns the "modelVersion" setting. The default passed to the
// lookup is derived from the current model and type via PickDefaultModelVersion.
func (cs *classSettings) ModelVersion() string {
	model, docType := cs.Model(), cs.Type()
	fallback := PickDefaultModelVersion(model, docType)
	return cs.BaseClassSettings.GetPropertyAsString("modelVersion", fallback)
}
99

100
// ResourceName returns the "resourceName" setting, looked up through the base
// settings with the empty string passed as the default.
func (cs *classSettings) ResourceName() string {
	const key = "resourceName"
	return cs.BaseClassSettings.GetPropertyAsString(key, "")
}
103

104
// BaseURL returns the "baseURL" setting, looked up through the base settings
// with DefaultBaseURL passed as the default.
func (cs *classSettings) BaseURL() string {
	const key = "baseURL"
	return cs.BaseClassSettings.GetPropertyAsString(key, DefaultBaseURL)
}
107

108
// DeploymentID returns the "deploymentId" setting, looked up through the base
// settings with the empty string passed as the default.
func (cs *classSettings) DeploymentID() string {
	const key = "deploymentId"
	return cs.BaseClassSettings.GetPropertyAsString(key, "")
}
111

112
// ApiVersion returns the "apiVersion" setting, looked up through the base
// settings with DefaultApiVersion passed as the default.
func (cs *classSettings) ApiVersion() string {
	const key = "apiVersion"
	return cs.BaseClassSettings.GetPropertyAsString(key, DefaultApiVersion)
}
115

116
// IsAzure reports whether both ResourceName and DeploymentID are non-empty.
func (cs *classSettings) IsAzure() bool {
	if cs.ResourceName() == "" {
		return false
	}
	return cs.DeploymentID() != ""
}
119

120
// Dimensions returns the "dimensions" setting. The default passed to the lookup
// comes from PickDefaultDimensions for the current model, which is nil for any
// model other than the two text-embedding-3 models.
func (cs *classSettings) Dimensions() *int64 {
	fallback := PickDefaultDimensions(cs.Model())
	return cs.BaseClassSettings.GetPropertyAsInt64("dimensions", fallback)
}
124

125
// Validate checks the class-level settings of this module and returns the
// first problem found, or nil when every check passes.
//
// Checks run in this order:
//  1. the embedded BaseClassSettings validation for class;
//  2. "type" against availableOpenAITypes;
//  3. "model" against availableOpenAIModels plus availableV3Models;
//  4. "dimensions", when non-nil: the model must be a v3 model and the value
//     must be listed for that model in availableV3ModelsDimensions;
//  5. "modelVersion" for the chosen model and type;
//  6. the Azure-related settings (resourceName, deploymentId, apiVersion).
func (cs *classSettings) Validate(class *models.Class) error {
	if err := cs.BaseClassSettings.Validate(class); err != nil {
		return err
	}

	docType := cs.Type()
	if !basesettings.ValidateSetting[string](docType, availableOpenAITypes) {
		// The message says "model names" although it lists type names; the
		// wording is kept unchanged so existing error matching keeps working.
		return errors.Errorf("wrong OpenAI type name, available model names are: %v", availableOpenAITypes)
	}

	// Build the combined list in a fresh slice. Appending directly onto the
	// package-level availableOpenAIModels is only safe while that slice has no
	// spare capacity; copying removes the dependency on that detail.
	availableModels := make([]string, 0, len(availableOpenAIModels)+len(availableV3Models))
	availableModels = append(availableModels, availableOpenAIModels...)
	availableModels = append(availableModels, availableV3Models...)

	model := cs.Model()
	if !basesettings.ValidateSetting[string](model, availableModels) {
		return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableModels)
	}

	if dimensions := cs.Dimensions(); dimensions != nil {
		if !basesettings.ValidateSetting[string](model, availableV3Models) {
			return errors.Errorf("dimensions setting can only be used with V3 embedding models: %v", availableV3Models)
		}
		availableDimensions := availableV3ModelsDimensions[model]
		if !basesettings.ValidateSetting[int64](*dimensions, availableDimensions) {
			return errors.Errorf("wrong dimensions setting for %s model, available dimensions are: %v", model, availableDimensions)
		}
	}

	if err := cs.validateModelVersion(cs.ModelVersion(), model, docType); err != nil {
		return err
	}

	return cs.validateAzureConfig(cs.ResourceName(), cs.DeploymentID(), cs.ApiVersion())
}
164

165
// validateModelVersion checks that version is usable with model and docType.
//
// Rules, in order:
//   - v3 models (availableV3Models) are not versioned: always nil;
//   - version "001" is accepted for every model;
//   - version "002" is accepted only for "ada" and "davinci";
//   - version "003" is accepted only for "davinci";
//   - any other version is rejected;
//   - an accepted "002"/"003" combination additionally requires docType "text".
//
// It returns nil when the combination is accepted and a descriptive error
// otherwise.
func (cs *classSettings) validateModelVersion(version, model, docType string) error {
	for _, v3 := range availableV3Models {
		if model == v3 {
			return nil
		}
	}

	switch version {
	case "001":
		// no restrictions
		return nil
	case "002":
		// only ada/davinci 002
		if model != "ada" && model != "davinci" {
			return fmt.Errorf("unsupported version %s", version)
		}
	case "003":
		// only davinci 003
		if model != "davinci" {
			return fmt.Errorf("unsupported version %s", version)
		}
	default:
		// all other fallback
		return fmt.Errorf("model %s is only available in version 001", model)
	}

	if docType != "text" {
		// Name the model/version that was actually rejected. This branch is
		// reached for davinci 002/003 as well as ada 002; for ada/002 the text
		// is identical to the previous hard-coded "ada-002" message.
		return fmt.Errorf("%s-%s no longer distinguishes between text/code, use 'text' for all use cases", model, version)
	}

	return nil
}
200

201
// validateAzureConfig checks the Azure-related settings. resourceName and
// deploymentId must be either both empty or both set, and apiVersion must pass
// basesettings.ValidateSetting against availableApiVersions. The apiVersion
// check runs regardless of whether the Azure names are set.
func (cs *classSettings) validateAzureConfig(resourceName, deploymentId, apiVersion string) error {
	hasResource := resourceName != ""
	hasDeployment := deploymentId != ""
	if hasResource != hasDeployment {
		// exactly one of the two was supplied
		return fmt.Errorf("both resourceName and deploymentId must be provided")
	}

	if basesettings.ValidateSetting[string](apiVersion, availableApiVersions) {
		return nil
	}
	return errors.Errorf("wrong Azure OpenAI apiVersion setting, available api versions are: %v", availableApiVersions)
}
210

211
func PickDefaultModelVersion(model, docType string) string {
212
	for i := range availableV3Models {
213
		if model == availableV3Models[i] {
214
			return ""
215
		}
216
	}
217
	if model == "ada" && docType == "text" {
218
		return "002"
219
	}
220
	// for all other combinations stick with "001"
221
	return "001"
222
}
223

224
func PickDefaultDimensions(model string) *int64 {
225
	if model == TextEmbedding3Small {
226
		return &TextEmbedding3SmallDefaultDimensions
227
	}
228
	if model == TextEmbedding3Large {
229
		return &TextEmbedding3LargeDefaultDimensions
230
	}
231
	return nil
232
}
233

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.