weaviate

Форк
0
164 строки · 4.7 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package modaws
13

14
import (
15
	"context"
16
	"net/http"
17
	"os"
18
	"time"
19

20
	"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21

22
	"github.com/weaviate/weaviate/usecases/modulecomponents/batch"
23

24
	"github.com/pkg/errors"
25
	"github.com/sirupsen/logrus"
26
	"github.com/weaviate/weaviate/entities/models"
27
	"github.com/weaviate/weaviate/entities/modulecapabilities"
28
	"github.com/weaviate/weaviate/entities/moduletools"
29
	"github.com/weaviate/weaviate/modules/text2vec-aws/clients"
30
	"github.com/weaviate/weaviate/modules/text2vec-aws/vectorizer"
31
	"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
32
)
33

34
// Name is the module name "text2vec-aws"; it is the same string that
// (*AwsModule).Name returns.
const Name = "text2vec-aws"
35

36
// New returns a zero-valued AwsModule. Its dependencies are filled in
// afterwards by Init and InitExtension.
func New() *AwsModule {
	mod := new(AwsModule)
	return mod
}
39

40
// AwsModule is the text2vec-aws module. A value returned by New is empty;
// fields are populated during Init and InitExtension.
type AwsModule struct {
	// vectorizer produces vectors for objects and texts; set in initVectorizer.
	vectorizer                   text2vecbase.TextVectorizer
	// metaProvider backs MetaInfo; initVectorizer sets it to the AWS client.
	metaProvider                 text2vecbase.MetaProvider
	// NOTE(review): graphqlProvider and searcher are not assigned in this
	// file — presumably set by initNearText (called from InitExtension); confirm.
	graphqlProvider              modulecapabilities.GraphQLArguments
	searcher                     modulecapabilities.Searcher
	// nearTextTransformer is the "nearText" transformer taken from another
	// module in InitExtension; it stays nil if no module provides one.
	nearTextTransformer          modulecapabilities.TextTransform
	// logger is taken from the module init params in Init.
	logger                       logrus.FieldLogger
	// additionalPropertiesProvider backs AdditionalProperties; set in
	// initAdditionalPropertiesProvider.
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
49

50
// Name returns the module's registered name. It returns the package-level
// Name constant so the string "text2vec-aws" is defined in exactly one place.
func (m *AwsModule) Name() string {
	return Name
}
53

54
// Type reports this module's kind, which is always modulecapabilities.Text2Vec.
func (m *AwsModule) Type() modulecapabilities.ModuleType {
	var kind modulecapabilities.ModuleType = modulecapabilities.Text2Vec
	return kind
}
57

58
func (m *AwsModule) Init(ctx context.Context,
59
	params moduletools.ModuleInitParams,
60
) error {
61
	m.logger = params.GetLogger()
62

63
	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
64
		return errors.Wrap(err, "init vectorizer")
65
	}
66

67
	if err := m.initAdditionalPropertiesProvider(); err != nil {
68
		return errors.Wrap(err, "init additional properties provider")
69
	}
70

71
	return nil
72
}
73

74
// InitExtension looks through the other enabled modules for a "nearText" text
// transformer, then initializes the nearText integration via initNearText.
//
// The module skips itself. TextTransformers() is queried once per module. A
// module whose transformer map has no (or a nil) "nearText" entry does not
// overwrite a transformer already found in an earlier module.
//
// Errors from initNearText are wrapped with "init graphql provider".
func (m *AwsModule) InitExtension(modules []modulecapabilities.Module) error {
	for _, module := range modules {
		if module.Name() == m.Name() {
			continue
		}
		provider, ok := module.(modulecapabilities.TextTransformers)
		if !ok {
			continue
		}
		// Indexing a nil map is safe in Go and yields nil, so no separate
		// nil-map check is needed.
		if t := provider.TextTransformers()["nearText"]; t != nil {
			m.nearTextTransformer = t
		}
	}

	if err := m.initNearText(); err != nil {
		return errors.Wrap(err, "init graphql provider")
	}
	return nil
}
91

92
func (m *AwsModule) initVectorizer(ctx context.Context, timeout time.Duration,
93
	logger logrus.FieldLogger,
94
) error {
95
	awsAccessKey := m.getAWSAccessKey()
96
	awsSecret := m.getAWSSecretAccessKey()
97
	client := clients.New(awsAccessKey, awsSecret, timeout, logger)
98

99
	m.vectorizer = vectorizer.New(client)
100
	m.metaProvider = client
101

102
	return nil
103
}
104

105
// getAWSAccessKey returns the AWS access key ID from the environment.
// AWS_ACCESS_KEY_ID takes precedence; AWS_ACCESS_KEY is the fallback.
// It returns "" when neither variable holds a non-empty value.
func (m *AwsModule) getAWSAccessKey() string {
	for _, envName := range []string{"AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY"} {
		if value := os.Getenv(envName); value != "" {
			return value
		}
	}
	return ""
}
111

112
// getAWSSecretAccessKey returns the AWS secret access key from the
// environment. AWS_SECRET_ACCESS_KEY takes precedence; AWS_SECRET_KEY is the
// fallback. It returns "" when neither variable holds a non-empty value.
func (m *AwsModule) getAWSSecretAccessKey() string {
	for _, envName := range []string{"AWS_SECRET_ACCESS_KEY", "AWS_SECRET_KEY"} {
		if value := os.Getenv(envName); value != "" {
			return value
		}
	}
	return ""
}
118

119
// initAdditionalPropertiesProvider installs the text2vec additional-properties
// provider from the additional package. It always returns nil.
func (m *AwsModule) initAdditionalPropertiesProvider() error {
	provider := additional.NewText2VecProvider()
	m.additionalPropertiesProvider = provider
	return nil
}
123

124
// RootHandler returns a nil http.Handler; this module exposes no HTTP routes.
func (m *AwsModule) RootHandler() http.Handler {
	// TODO: remove once this is a capability interface
	var handler http.Handler
	return handler
}
128

129
func (m *AwsModule) VectorizeObject(ctx context.Context,
130
	obj *models.Object, cfg moduletools.ClassConfig,
131
) ([]float32, models.AdditionalProperties, error) {
132
	return m.vectorizer.Object(ctx, obj, cfg)
133
}
134

135
// VectorizeBatch delegates to batch.VectorizeBatch. It supplies the
// vectorizer's Object method as the per-object function, along with
// skipObject, cfg and the module logger.
// NOTE(review): the map[int]error result presumably holds per-object errors
// keyed by position in objs — confirm in the batch package.
func (m *AwsModule) VectorizeBatch(ctx context.Context, objs []*models.Object,
	skipObject []bool, cfg moduletools.ClassConfig,
) ([][]float32, []models.AdditionalProperties, map[int]error) {
	logger := m.logger
	vectorizeOne := m.vectorizer.Object
	return batch.VectorizeBatch(ctx, objs, skipObject, cfg, logger, vectorizeOne)
}
138

139
// MetaInfo returns whatever the configured meta provider reports. In this file
// that provider is the AWS client installed by initVectorizer.
func (m *AwsModule) MetaInfo() (map[string]interface{}, error) {
	info, err := m.metaProvider.MetaInfo()
	return info, err
}
142

143
// AdditionalProperties returns the properties exposed by the provider that
// initAdditionalPropertiesProvider installed.
func (m *AwsModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	provider := m.additionalPropertiesProvider
	return provider.AdditionalProperties()
}
146

147
// VectorizableProperties always returns true, a nil property list and a nil
// error; cfg is ignored.
// NOTE(review): presumably true with a nil list means every property is
// eligible — confirm against the modulecapabilities contract.
func (m *AwsModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
	var props []string
	return true, props, nil
}
150

151
func (m *AwsModule) VectorizeInput(ctx context.Context,
152
	input string, cfg moduletools.ClassConfig,
153
) ([]float32, error) {
154
	return m.vectorizer.Texts(ctx, []string{input}, cfg)
155
}
156

157
// verify we implement the modules.Module interface
158
var (
159
	_ = modulecapabilities.Module(New())
160
	_ = modulecapabilities.Vectorizer(New())
161
	_ = modulecapabilities.MetaProvider(New())
162
	_ = modulecapabilities.Searcher(New())
163
	_ = modulecapabilities.GraphQLArguments(New())
164
)
165

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.