weaviate
186 строк · 5.8 Кб
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package modtransformers
13
14import (
15"context"
16"net/http"
17"os"
18"time"
19
20"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21
22"github.com/pkg/errors"
23"github.com/sirupsen/logrus"
24"github.com/weaviate/weaviate/entities/models"
25"github.com/weaviate/weaviate/entities/modulecapabilities"
26"github.com/weaviate/weaviate/entities/moduletools"
27"github.com/weaviate/weaviate/modules/text2vec-transformers/clients"
28"github.com/weaviate/weaviate/modules/text2vec-transformers/vectorizer"
29"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
30)
31
32func New() *TransformersModule {
33return &TransformersModule{}
34}
35
36type TransformersModule struct {
37vectorizer text2vecbase.TextVectorizer
38metaProvider text2vecbase.MetaProvider
39graphqlProvider modulecapabilities.GraphQLArguments
40searcher modulecapabilities.Searcher
41nearTextTransformer modulecapabilities.TextTransform
42logger logrus.FieldLogger
43additionalPropertiesProvider modulecapabilities.AdditionalProperties
44}
45
46func (m *TransformersModule) Name() string {
47return "text2vec-transformers"
48}
49
50func (m *TransformersModule) Type() modulecapabilities.ModuleType {
51return modulecapabilities.Text2Vec
52}
53
54func (m *TransformersModule) Init(ctx context.Context,
55params moduletools.ModuleInitParams,
56) error {
57m.logger = params.GetLogger()
58
59if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
60return errors.Wrap(err, "init vectorizer")
61}
62
63if err := m.initAdditionalPropertiesProvider(); err != nil {
64return errors.Wrap(err, "init additional properties provider")
65}
66
67return nil
68}
69
70func (m *TransformersModule) InitExtension(modules []modulecapabilities.Module) error {
71for _, module := range modules {
72if module.Name() == m.Name() {
73continue
74}
75if arg, ok := module.(modulecapabilities.TextTransformers); ok {
76if arg != nil && arg.TextTransformers() != nil {
77m.nearTextTransformer = arg.TextTransformers()["nearText"]
78}
79}
80}
81
82if err := m.initNearText(); err != nil {
83return errors.Wrap(err, "init graphql provider")
84}
85return nil
86}
87
88func (m *TransformersModule) initVectorizer(ctx context.Context, timeout time.Duration,
89logger logrus.FieldLogger,
90) error {
91// TODO: gh-1486 proper config management
92uriPassage := os.Getenv("TRANSFORMERS_PASSAGE_INFERENCE_API")
93uriQuery := os.Getenv("TRANSFORMERS_QUERY_INFERENCE_API")
94uriCommon := os.Getenv("TRANSFORMERS_INFERENCE_API")
95
96if uriCommon == "" {
97if uriPassage == "" && uriQuery == "" {
98return errors.Errorf("required variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API are not set")
99}
100if uriPassage != "" && uriQuery == "" {
101return errors.Errorf("required variable TRANSFORMERS_QUERY_INFERENCE_API is not set")
102}
103if uriPassage == "" && uriQuery != "" {
104return errors.Errorf("required variable TRANSFORMERS_PASSAGE_INFERENCE_API is not set")
105}
106} else {
107if uriPassage != "" || uriQuery != "" {
108return errors.Errorf("either variable TRANSFORMERS_INFERENCE_API or both variables TRANSFORMERS_PASSAGE_INFERENCE_API and TRANSFORMERS_QUERY_INFERENCE_API should be set")
109}
110uriPassage = uriCommon
111uriQuery = uriCommon
112}
113
114client := clients.New(uriPassage, uriQuery, timeout, logger)
115if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
116return errors.Wrap(err, "init remote vectorizer")
117}
118
119m.vectorizer = vectorizer.New(client)
120m.metaProvider = client
121
122return nil
123}
124
125func (m *TransformersModule) initAdditionalPropertiesProvider() error {
126m.additionalPropertiesProvider = additional.NewText2VecProvider()
127return nil
128}
129
130func (m *TransformersModule) RootHandler() http.Handler {
131// TODO: remove once this is a capability interface
132return nil
133}
134
135func (m *TransformersModule) VectorizeObject(ctx context.Context,
136obj *models.Object, cfg moduletools.ClassConfig,
137) ([]float32, models.AdditionalProperties, error) {
138return m.vectorizer.Object(ctx, obj, cfg)
139}
140
141// VectorizeBatch is _slower_ if many requests are done in parallel. So do all objects sequentially
142func (m *TransformersModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) {
143vecs := make([][]float32, len(objs))
144addProps := make([]models.AdditionalProperties, len(objs))
145// error should be the exception so dont preallocate
146errs := make(map[int]error, 0)
147for i, obj := range objs {
148if skipObject[i] {
149continue
150}
151vec, addProp, err := m.vectorizer.Object(ctx, obj, cfg)
152if err != nil {
153errs[i] = err
154continue
155}
156addProps[i] = addProp
157vecs[i] = vec
158}
159
160return vecs, addProps, errs
161}
162
163func (m *TransformersModule) MetaInfo() (map[string]interface{}, error) {
164return m.metaProvider.MetaInfo()
165}
166
167func (m *TransformersModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
168return m.additionalPropertiesProvider.AdditionalProperties()
169}
170
171func (m *TransformersModule) VectorizeInput(ctx context.Context,
172input string, cfg moduletools.ClassConfig,
173) ([]float32, error) {
174return m.vectorizer.Texts(ctx, []string{input}, cfg)
175}
176
177func (m *TransformersModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
178return true, nil, nil
179}
180
181// verify we implement the modules.Module interface
182var (
183_ = modulecapabilities.Module(New())
184_ = modulecapabilities.Vectorizer(New())
185_ = modulecapabilities.MetaProvider(New())
186)
187