weaviate

Форк
0
186 строк · 5.8 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package modtransformers
13

14
import (
15
	"context"
16
	"net/http"
17
	"os"
18
	"time"
19

20
	"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21

22
	"github.com/pkg/errors"
23
	"github.com/sirupsen/logrus"
24
	"github.com/weaviate/weaviate/entities/models"
25
	"github.com/weaviate/weaviate/entities/modulecapabilities"
26
	"github.com/weaviate/weaviate/entities/moduletools"
27
	"github.com/weaviate/weaviate/modules/text2vec-transformers/clients"
28
	"github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
29
	"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
30
)
31

32
// New returns a pointer to a zero-valued TransformersModule. Its fields are
// populated afterwards by Init and InitExtension.
func New() *TransformersModule {
	return new(TransformersModule)
}
35

36
type TransformersModule struct {
37
	vectorizer                   text2vecbase.TextVectorizer
38
	metaProvider                 text2vecbase.MetaProvider
39
	graphqlProvider              modulecapabilities.GraphQLArguments
40
	searcher                     modulecapabilities.Searcher
41
	nearTextTransformer          modulecapabilities.TextTransform
42
	logger                       logrus.FieldLogger
43
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
44
}
45

46
// Name returns the fixed identifier of this module.
func (m *TransformersModule) Name() string {
	const moduleName = "text2vec-transformers"
	return moduleName
}
49

50
// Type reports this module's type, which is always
// modulecapabilities.Text2Vec.
func (m *TransformersModule) Type() modulecapabilities.ModuleType {
	return modulecapabilities.Text2Vec
}
53

54
func (m *TransformersModule) Init(ctx context.Context,
55
	params moduletools.ModuleInitParams,
56
) error {
57
	m.logger = params.GetLogger()
58

59
	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
60
		return errors.Wrap(err, "init vectorizer")
61
	}
62

63
	if err := m.initAdditionalPropertiesProvider(); err != nil {
64
		return errors.Wrap(err, "init additional properties provider")
65
	}
66

67
	return nil
68
}
69

70
// InitExtension wires this module up with the other loaded modules.
//
// Every module except this one (matched by name) that implements
// modulecapabilities.TextTransformers and returns a non-nil transformer map
// has its "nearText" entry stored as the module's nearText transformer. No
// early exit is taken, so the last such module in the slice wins, even when
// its map has no "nearText" entry (which stores nil).
//
// Afterwards initNearText is called; its error, if any, is returned wrapped.
func (m *TransformersModule) InitExtension(modules []modulecapabilities.Module) error {
	for _, module := range modules {
		if module.Name() == m.Name() {
			continue
		}
		// A successful type assertion never yields a nil interface value, so
		// no additional nil check on provider is needed.
		provider, ok := module.(modulecapabilities.TextTransformers)
		if !ok {
			continue
		}
		// Ask the provider once and reuse the returned map.
		if transformers := provider.TextTransformers(); transformers != nil {
			m.nearTextTransformer = transformers["nearText"]
		}
	}

	if err := m.initNearText(); err != nil {
		return errors.Wrap(err, "init graphql provider")
	}
	return nil
}
87

88
func (m *TransformersModule) initVectorizer(ctx context.Context, timeout time.Duration,
89
	logger logrus.FieldLogger,
90
) error {
91
	// TODO: gh-1486 proper config management
92
	uriPassage := os.Getenv("TRANSFORMERS_PASSAGE_INFERENCE_API")
93
	uriQuery := os.Getenv("TRANSFORMERS_QUERY_INFERENCE_API")
94
	uriCommon := os.Getenv("TRANSFORMERS_INFERENCE_API")
95

96
	if uriCommon == "" {
97
		if uriPassage == "" && uriQuery == "" {
98
			return errors.Errorf("required variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API are not set")
99
		}
100
		if uriPassage != "" && uriQuery == "" {
101
			return errors.Errorf("required variable TRANSFORMERS_QUERY_INFERENCE_API is not set")
102
		}
103
		if uriPassage == "" && uriQuery != "" {
104
			return errors.Errorf("required variable TRANSFORMERS_PASSAGE_INFERENCE_API is not set")
105
		}
106
	} else {
107
		if uriPassage != "" || uriQuery != "" {
108
			return errors.Errorf("either variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API should be set")
109
		}
110
		uriPassage = uriCommon
111
		uriQuery = uriCommon
112
	}
113

114
	client := clients.New(uriPassage, uriQuery, timeout, logger)
115
	if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
116
		return errors.Wrap(err, "init remote vectorizer")
117
	}
118

119
	m.vectorizer = vectorizer.New(client)
120
	m.metaProvider = client
121

122
	return nil
123
}
124

125
// initAdditionalPropertiesProvider stores a new text2vec additional
// properties provider on the module. It always returns nil.
func (m *TransformersModule) initAdditionalPropertiesProvider() error {
	provider := additional.NewText2VecProvider()
	m.additionalPropertiesProvider = provider
	return nil
}
129

130
// RootHandler always returns a nil http.Handler.
func (m *TransformersModule) RootHandler() http.Handler {
	// TODO: remove once this is a capability interface
	var none http.Handler
	return none
}
134

135
func (m *TransformersModule) VectorizeObject(ctx context.Context,
136
	obj *models.Object, cfg moduletools.ClassConfig,
137
) ([]float32, models.AdditionalProperties, error) {
138
	return m.vectorizer.Object(ctx, obj, cfg)
139
}
140

141
// VectorizeBatch is _slower_ if many requests are done in parallel. So do all objects sequentially
142
func (m *TransformersModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) {
143
	vecs := make([][]float32, len(objs))
144
	addProps := make([]models.AdditionalProperties, len(objs))
145
	// error should be the exception so dont preallocate
146
	errs := make(map[int]error, 0)
147
	for i, obj := range objs {
148
		if skipObject[i] {
149
			continue
150
		}
151
		vec, addProp, err := m.vectorizer.Object(ctx, obj, cfg)
152
		if err != nil {
153
			errs[i] = err
154
			continue
155
		}
156
		addProps[i] = addProp
157
		vecs[i] = vec
158
	}
159

160
	return vecs, addProps, errs
161
}
162

163
// MetaInfo returns whatever the module's meta provider reports, unchanged.
func (m *TransformersModule) MetaInfo() (map[string]interface{}, error) {
	meta, err := m.metaProvider.MetaInfo()
	return meta, err
}
166

167
// AdditionalProperties returns the additional properties exposed by the
// module's additional properties provider.
func (m *TransformersModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	props := m.additionalPropertiesProvider.AdditionalProperties()
	return props
}
170

171
func (m *TransformersModule) VectorizeInput(ctx context.Context,
172
	input string, cfg moduletools.ClassConfig,
173
) ([]float32, error) {
174
	return m.vectorizer.Texts(ctx, []string{input}, cfg)
175
}
176

177
// VectorizableProperties ignores cfg and always returns true, a nil property
// list and a nil error.
func (m *TransformersModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
	var props []string
	return true, props, nil
}
180

181
// verify we implement the modules.Module interface
182
var (
183
	_ = modulecapabilities.Module(New())
184
	_ = modulecapabilities.Vectorizer(New())
185
	_ = modulecapabilities.MetaProvider(New())
186
)
187

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.