weaviate

Форк
0
159 строк · 4.6 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package modclip
13

14
import (
15
	"context"
16
	"net/http"
17
	"os"
18
	"time"
19

20
	"github.com/weaviate/weaviate/usecases/modulecomponents/batch"
21

22
	"github.com/pkg/errors"
23
	"github.com/sirupsen/logrus"
24
	"github.com/weaviate/weaviate/entities/models"
25
	"github.com/weaviate/weaviate/entities/modulecapabilities"
26
	"github.com/weaviate/weaviate/entities/moduletools"
27
	"github.com/weaviate/weaviate/modules/multi2vec-clip/clients"
28
	"github.com/weaviate/weaviate/modules/multi2vec-clip/vectorizer"
29
)
30

31
// New returns a pointer to a zero-valued ClipModule. No field is populated
// here; the logger, vectorizers and meta client are assigned later by Init.
func New() *ClipModule {
	return new(ClipModule)
}
34

35
type ClipModule struct {
36
	imageVectorizer          imageVectorizer
37
	nearImageGraphqlProvider modulecapabilities.GraphQLArguments
38
	nearImageSearcher        modulecapabilities.Searcher
39
	textVectorizer           textVectorizer
40
	nearTextGraphqlProvider  modulecapabilities.GraphQLArguments
41
	nearTextSearcher         modulecapabilities.Searcher
42
	nearTextTransformer      modulecapabilities.TextTransform
43
	metaClient               metaClient
44
	logger                   logrus.FieldLogger
45
}
46

47
// metaClient is the dependency ClipModule.MetaInfo delegates to.
type metaClient interface {
	// MetaInfo returns a string-keyed map of metadata, or an error.
	MetaInfo() (map[string]interface{}, error)
}
50

51
type imageVectorizer interface {
52
	Object(ctx context.Context, obj *models.Object, cfg moduletools.ClassConfig) ([]float32, models.AdditionalProperties, error)
53
	VectorizeImage(ctx context.Context, id, image string, cfg moduletools.ClassConfig) ([]float32, error)
54
}
55

56
type textVectorizer interface {
57
	Texts(ctx context.Context, input []string,
58
		cfg moduletools.ClassConfig) ([]float32, error)
59
}
60

61
// Name returns the fixed identifier of this module. InitExtension compares
// it against other modules' names to skip itself.
func (m *ClipModule) Name() string {
	const moduleName = "multi2vec-clip"
	return moduleName
}
64

65
// Type reports the module category, which is always Multi2Vec.
func (m *ClipModule) Type() modulecapabilities.ModuleType {
	var moduleType modulecapabilities.ModuleType = modulecapabilities.Multi2Vec
	return moduleType
}
68

69
func (m *ClipModule) Init(ctx context.Context,
70
	params moduletools.ModuleInitParams,
71
) error {
72
	m.logger = params.GetLogger()
73
	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
74
		return errors.Wrap(err, "init vectorizer")
75
	}
76

77
	if err := m.initNearImage(); err != nil {
78
		return errors.Wrap(err, "init near text")
79
	}
80

81
	return nil
82
}
83

84
// InitExtension looks through the other registered modules for one that
// exposes text transformers and adopts its "nearText" transformer, then
// initializes the nearText capability.
//
// Modules sharing this module's name are skipped. If several modules provide
// non-nil transformer maps, the last one visited determines the value (which
// is nil when its map has no "nearText" entry). A failure of initNearText is
// returned wrapped with "init near text".
func (m *ClipModule) InitExtension(modules []modulecapabilities.Module) error {
	for _, other := range modules {
		if other.Name() == m.Name() {
			// never pick up transformers from ourselves
			continue
		}

		provider, ok := other.(modulecapabilities.TextTransformers)
		if !ok || provider == nil || provider.TextTransformers() == nil {
			continue
		}

		m.nearTextTransformer = provider.TextTransformers()["nearText"]
	}

	err := m.initNearText()
	if err != nil {
		return errors.Wrap(err, "init near text")
	}

	return nil
}
102

103
func (m *ClipModule) initVectorizer(ctx context.Context, timeout time.Duration,
104
	logger logrus.FieldLogger,
105
) error {
106
	uri := os.Getenv("CLIP_INFERENCE_API")
107
	if uri == "" {
108
		return errors.Errorf("required variable CLIP_INFERENCE_API is not set")
109
	}
110

111
	client := clients.New(uri, timeout, logger)
112
	if err := client.WaitForStartup(ctx, 1*time.Second); err != nil {
113
		return errors.Wrap(err, "init remote vectorizer")
114
	}
115

116
	m.imageVectorizer = vectorizer.New(client)
117
	m.textVectorizer = vectorizer.New(client)
118
	m.metaClient = client
119

120
	return nil
121
}
122

123
// RootHandler always returns a nil handler; this module registers no HTTP
// routes of its own.
func (m *ClipModule) RootHandler() http.Handler {
	// TODO: remove once this is a capability interface
	var none http.Handler
	return none
}
127

128
func (m *ClipModule) VectorizeObject(ctx context.Context,
129
	obj *models.Object, cfg moduletools.ClassConfig,
130
) ([]float32, models.AdditionalProperties, error) {
131
	return m.imageVectorizer.Object(ctx, obj, cfg)
132
}
133

134
// VectorizeBatch vectorizes several objects by handing them, together with
// the skipObject flags, the module logger and the image vectorizer's Object
// method, to batch.VectorizeBatch. Its results are returned unchanged.
func (m *ClipModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) {
	vectorizeOne := m.imageVectorizer.Object
	return batch.VectorizeBatch(ctx, objs, skipObject, cfg, m.logger, vectorizeOne)
}
137

138
// MetaInfo returns whatever the meta client (assigned in initVectorizer)
// reports, passing both the map and the error through untouched.
func (m *ClipModule) MetaInfo() (map[string]interface{}, error) {
	info, err := m.metaClient.MetaInfo()
	return info, err
}
141

142
func (m *ClipModule) VectorizeInput(ctx context.Context,
143
	input string, cfg moduletools.ClassConfig,
144
) ([]float32, error) {
145
	return m.textVectorizer.Texts(ctx, []string{input}, cfg)
146
}
147

148
// VectorizableProperties returns the property names reported by the class
// settings built from cfg, along with any error from that lookup. The first
// return value is always false.
func (m *ClipModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
	settings := vectorizer.NewClassSettings(cfg)
	props, err := settings.Properties()
	return false, props, err
}
153

154
// verify we implement the modules.Module interface
155
var (
156
	_ = modulecapabilities.Module(New())
157
	_ = modulecapabilities.Vectorizer(New())
158
	_ = modulecapabilities.InputVectorizer(New())
159
)
160

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.