weaviate
232 lines · 6.4 KB
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package ent
13
14import (
15"fmt"
16
17"github.com/pkg/errors"
18
19"github.com/weaviate/weaviate/entities/models"
20"github.com/weaviate/weaviate/entities/moduletools"
21basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings"
22)
23
// Fallback values that the classSettings getters pass to
// GetPropertyAsString when reading the class config.
const (
	DefaultOpenAIDocumentType = "text"
	DefaultOpenAIModel        = "ada"
	DefaultBaseURL            = "https://api.openai.com"
	DefaultApiVersion         = "2024-02-01"
)

// Boolean defaults exported for use outside this file (none is read here).
// NOTE(review): presumably consumed by the module's class/property
// vectorization logic — confirm against callers.
const (
	DefaultVectorizeClassName    = true
	DefaultPropertyIndexed       = true
	DefaultVectorizePropertyName = false
)
33
// Names of the OpenAI "text-embedding-3" (V3) embedding models.
const (
	TextEmbedding3Small = "text-embedding-3-small"
	TextEmbedding3Large = "text-embedding-3-large"
)

// Default embedding sizes for the V3 models, returned by PickDefaultDimensions
// and also listed as allowed sizes in availableV3ModelsDimensions.
var (
	TextEmbedding3SmallDefaultDimensions = int64(1536)
	TextEmbedding3LargeDefaultDimensions = int64(3072)
)

// availableOpenAITypes holds the accepted values of the "type" setting.
var availableOpenAITypes = []string{"text", "code"}

// availableV3Models holds the V3 model names; Validate only allows a
// "dimensions" setting for these, and they skip model-version checks.
var availableV3Models = []string{TextEmbedding3Small, TextEmbedding3Large}

// availableV3ModelsDimensions maps each V3 model to the dimension values
// Validate accepts for it.
var availableV3ModelsDimensions = map[string][]int64{
	TextEmbedding3Small: {512, TextEmbedding3SmallDefaultDimensions},
	TextEmbedding3Large: {256, 1024, TextEmbedding3LargeDefaultDimensions},
}
56
// availableOpenAIModels lists the legacy (pre-V3) model names accepted by
// Validate. validateModelVersion separately restricts which versions each
// of them may use.
var availableOpenAIModels = []string{
	"ada",     // versions 001 and 002
	"babbage", // version 001 only
	"curie",   // version 001 only
	// NOTE(review): this entry used to be annotated "only supports 001", but
	// validateModelVersion also accepts 002 and 003 for davinci — confirm.
	"davinci",
}

// availableApiVersions lists the apiVersion values accepted by
// validateAzureConfig. The slice is printed verbatim in that function's
// error message, so its (not strictly chronological) order is preserved.
var availableApiVersions = []string{
	// 2022
	"2022-12-01",
	// 2023
	"2023-03-15-preview",
	"2023-05-15",
	"2023-06-01-preview",
	"2023-07-01-preview",
	"2023-08-01-preview",
	"2023-09-01-preview",
	"2023-12-01-preview",
	// 2024; the last entry equals DefaultApiVersion.
	"2024-02-15-preview",
	"2024-03-01-preview",
	"2024-02-01",
}
77
// classSettings is the settings view over a class config for this module.
// It embeds basesettings.BaseClassSettings, so that type's methods (such as
// GetPropertyAsString) are promoted, and adds module-specific getters and
// validation on top.
type classSettings struct {
	basesettings.BaseClassSettings
	// cfg is the class config this value was built from. It is assigned by
	// NewClassSettings but not read by any method in this file.
	// NOTE(review): presumably kept for use elsewhere in the package — confirm.
	cfg moduletools.ClassConfig
}
82
83func NewClassSettings(cfg moduletools.ClassConfig) *classSettings {
84return &classSettings{cfg: cfg, BaseClassSettings: *basesettings.NewBaseClassSettings(cfg)}
85}
86
87func (cs *classSettings) Model() string {
88return cs.BaseClassSettings.GetPropertyAsString("model", DefaultOpenAIModel)
89}
90
91func (cs *classSettings) Type() string {
92return cs.BaseClassSettings.GetPropertyAsString("type", DefaultOpenAIDocumentType)
93}
94
95func (cs *classSettings) ModelVersion() string {
96defaultVersion := PickDefaultModelVersion(cs.Model(), cs.Type())
97return cs.BaseClassSettings.GetPropertyAsString("modelVersion", defaultVersion)
98}
99
100func (cs *classSettings) ResourceName() string {
101return cs.BaseClassSettings.GetPropertyAsString("resourceName", "")
102}
103
104func (cs *classSettings) BaseURL() string {
105return cs.BaseClassSettings.GetPropertyAsString("baseURL", DefaultBaseURL)
106}
107
108func (cs *classSettings) DeploymentID() string {
109return cs.BaseClassSettings.GetPropertyAsString("deploymentId", "")
110}
111
112func (cs *classSettings) ApiVersion() string {
113return cs.BaseClassSettings.GetPropertyAsString("apiVersion", DefaultApiVersion)
114}
115
116func (cs *classSettings) IsAzure() bool {
117return cs.ResourceName() != "" && cs.DeploymentID() != ""
118}
119
120func (cs *classSettings) Dimensions() *int64 {
121defaultValue := PickDefaultDimensions(cs.Model())
122return cs.BaseClassSettings.GetPropertyAsInt64("dimensions", defaultValue)
123}
124
125func (cs *classSettings) Validate(class *models.Class) error {
126if err := cs.BaseClassSettings.Validate(class); err != nil {
127return err
128}
129
130docType := cs.Type()
131if !basesettings.ValidateSetting[string](docType, availableOpenAITypes) {
132return errors.Errorf("wrong OpenAI type name, available model names are: %v", availableOpenAITypes)
133}
134
135availableModels := append(availableOpenAIModels, availableV3Models...)
136model := cs.Model()
137if !basesettings.ValidateSetting[string](model, availableModels) {
138return errors.Errorf("wrong OpenAI model name, available model names are: %v", availableModels)
139}
140
141dimensions := cs.Dimensions()
142if dimensions != nil {
143if !basesettings.ValidateSetting[string](model, availableV3Models) {
144return errors.Errorf("dimensions setting can only be used with V3 embedding models: %v", availableV3Models)
145}
146availableDimensions := availableV3ModelsDimensions[model]
147if !basesettings.ValidateSetting[int64](*dimensions, availableDimensions) {
148return errors.Errorf("wrong dimensions setting for %s model, available dimensions are: %v", model, availableDimensions)
149}
150}
151
152version := cs.ModelVersion()
153if err := cs.validateModelVersion(version, model, docType); err != nil {
154return err
155}
156
157err := cs.validateAzureConfig(cs.ResourceName(), cs.DeploymentID(), cs.ApiVersion())
158if err != nil {
159return err
160}
161
162return nil
163}
164
165func (cs *classSettings) validateModelVersion(version, model, docType string) error {
166for i := range availableV3Models {
167if model == availableV3Models[i] {
168return nil
169}
170}
171
172if version == "001" {
173// no restrictions
174return nil
175}
176
177if version == "002" {
178// only ada/davinci 002
179if model != "ada" && model != "davinci" {
180return fmt.Errorf("unsupported version %s", version)
181}
182}
183
184if version == "003" && model != "davinci" {
185// only davinci 003
186return fmt.Errorf("unsupported version %s", version)
187}
188
189if version != "002" && version != "003" {
190// all other fallback
191return fmt.Errorf("model %s is only available in version 001", model)
192}
193
194if docType != "text" {
195return fmt.Errorf("ada-002 no longer distinguishes between text/code, use 'text' for all use cases")
196}
197
198return nil
199}
200
201func (cs *classSettings) validateAzureConfig(resourceName, deploymentId, apiVersion string) error {
202if (resourceName == "" && deploymentId != "") || (resourceName != "" && deploymentId == "") {
203return fmt.Errorf("both resourceName and deploymentId must be provided")
204}
205if !basesettings.ValidateSetting[string](apiVersion, availableApiVersions) {
206return errors.Errorf("wrong Azure OpenAI apiVersion setting, available api versions are: %v", availableApiVersions)
207}
208return nil
209}
210
211func PickDefaultModelVersion(model, docType string) string {
212for i := range availableV3Models {
213if model == availableV3Models[i] {
214return ""
215}
216}
217if model == "ada" && docType == "text" {
218return "002"
219}
220// for all other combinations stick with "001"
221return "001"
222}
223
224func PickDefaultDimensions(model string) *int64 {
225if model == TextEmbedding3Small {
226return &TextEmbedding3SmallDefaultDimensions
227}
228if model == TextEmbedding3Large {
229return &TextEmbedding3LargeDefaultDimensions
230}
231return nil
232}
233