weaviate
164 lines · 4.7 KB
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package modaws
13
14import (
15"context"
16"net/http"
17"os"
18"time"
19
20"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21
22"github.com/weaviate/weaviate/usecases/modulecomponents/batch"
23
24"github.com/pkg/errors"
25"github.com/sirupsen/logrus"
26"github.com/weaviate/weaviate/entities/models"
27"github.com/weaviate/weaviate/entities/modulecapabilities"
28"github.com/weaviate/weaviate/entities/moduletools"
29"github.com/weaviate/weaviate/modules/text2vec-aws/clients"
30"github.com/weaviate/weaviate/modules/text2vec-aws/vectorizer"
31"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
32)
33
34const Name = "text2vec-aws"
35
36func New() *AwsModule {
37return &AwsModule{}
38}
39
// AwsModule implements the text2vec-aws module.
//
// Field notes:
//   - vectorizer: built in initVectorizer from the AWS client; backs
//     VectorizeObject, VectorizeBatch and VectorizeInput.
//   - metaProvider: the AWS client itself; backs MetaInfo.
//   - graphqlProvider, searcher: not assigned in this file's visible code;
//     presumably populated by initNearText — TODO confirm.
//   - nearTextTransformer: the "nearText" text transformer taken from another
//     module in InitExtension (may stay nil if no module provides one).
//   - logger: taken from the init params in Init; passed to the AWS client and
//     to batch vectorization.
//   - additionalPropertiesProvider: set in initAdditionalPropertiesProvider;
//     backs AdditionalProperties.
type AwsModule struct {
	vectorizer                   text2vecbase.TextVectorizer
	metaProvider                 text2vecbase.MetaProvider
	graphqlProvider              modulecapabilities.GraphQLArguments
	searcher                     modulecapabilities.Searcher
	nearTextTransformer          modulecapabilities.TextTransform
	logger                       logrus.FieldLogger
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
49
50func (m *AwsModule) Name() string {
51return "text2vec-aws"
52}
53
54func (m *AwsModule) Type() modulecapabilities.ModuleType {
55return modulecapabilities.Text2Vec
56}
57
58func (m *AwsModule) Init(ctx context.Context,
59params moduletools.ModuleInitParams,
60) error {
61m.logger = params.GetLogger()
62
63if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
64return errors.Wrap(err, "init vectorizer")
65}
66
67if err := m.initAdditionalPropertiesProvider(); err != nil {
68return errors.Wrap(err, "init additional properties provider")
69}
70
71return nil
72}
73
74func (m *AwsModule) InitExtension(modules []modulecapabilities.Module) error {
75for _, module := range modules {
76if module.Name() == m.Name() {
77continue
78}
79if arg, ok := module.(modulecapabilities.TextTransformers); ok {
80if arg != nil && arg.TextTransformers() != nil {
81m.nearTextTransformer = arg.TextTransformers()["nearText"]
82}
83}
84}
85
86if err := m.initNearText(); err != nil {
87return errors.Wrap(err, "init graphql provider")
88}
89return nil
90}
91
92func (m *AwsModule) initVectorizer(ctx context.Context, timeout time.Duration,
93logger logrus.FieldLogger,
94) error {
95awsAccessKey := m.getAWSAccessKey()
96awsSecret := m.getAWSSecretAccessKey()
97client := clients.New(awsAccessKey, awsSecret, timeout, logger)
98
99m.vectorizer = vectorizer.New(client)
100m.metaProvider = client
101
102return nil
103}
104
105func (m *AwsModule) getAWSAccessKey() string {
106if os.Getenv("AWS_ACCESS_KEY_ID") != "" {
107return os.Getenv("AWS_ACCESS_KEY_ID")
108}
109return os.Getenv("AWS_ACCESS_KEY")
110}
111
112func (m *AwsModule) getAWSSecretAccessKey() string {
113if os.Getenv("AWS_SECRET_ACCESS_KEY") != "" {
114return os.Getenv("AWS_SECRET_ACCESS_KEY")
115}
116return os.Getenv("AWS_SECRET_KEY")
117}
118
119func (m *AwsModule) initAdditionalPropertiesProvider() error {
120m.additionalPropertiesProvider = additional.NewText2VecProvider()
121return nil
122}
123
124func (m *AwsModule) RootHandler() http.Handler {
125// TODO: remove once this is a capability interface
126return nil
127}
128
129func (m *AwsModule) VectorizeObject(ctx context.Context,
130obj *models.Object, cfg moduletools.ClassConfig,
131) ([]float32, models.AdditionalProperties, error) {
132return m.vectorizer.Object(ctx, obj, cfg)
133}
134
135func (m *AwsModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) {
136return batch.VectorizeBatch(ctx, objs, skipObject, cfg, m.logger, m.vectorizer.Object)
137}
138
139func (m *AwsModule) MetaInfo() (map[string]interface{}, error) {
140return m.metaProvider.MetaInfo()
141}
142
143func (m *AwsModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
144return m.additionalPropertiesProvider.AdditionalProperties()
145}
146
147func (m *AwsModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
148return true, nil, nil
149}
150
151func (m *AwsModule) VectorizeInput(ctx context.Context,
152input string, cfg moduletools.ClassConfig,
153) ([]float32, error) {
154return m.vectorizer.Texts(ctx, []string{input}, cfg)
155}
156
157// verify we implement the modules.Module interface
158var (
159_ = modulecapabilities.Module(New())
160_ = modulecapabilities.Vectorizer(New())
161_ = modulecapabilities.MetaProvider(New())
162_ = modulecapabilities.Searcher(New())
163_ = modulecapabilities.GraphQLArguments(New())
164)
165