weaviate
152 строки · 4.4 Кб
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package modpalm
13
14import (
15"context"
16"net/http"
17"os"
18"time"
19
20"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21
22"github.com/weaviate/weaviate/usecases/modulecomponents/batch"
23
24"github.com/pkg/errors"
25"github.com/sirupsen/logrus"
26"github.com/weaviate/weaviate/entities/models"
27"github.com/weaviate/weaviate/entities/modulecapabilities"
28"github.com/weaviate/weaviate/entities/moduletools"
29"github.com/weaviate/weaviate/modules/text2vec-palm/clients"
30"github.com/weaviate/weaviate/modules/text2vec-palm/vectorizer"
31"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
32)
33
const (
	// Name is the identifier under which this module is known: "text2vec-palm".
	Name = "text2vec-palm"
)
35
36func New() *PalmModule {
37return &PalmModule{}
38}
39
// PalmModule implements the text2vec-palm module. New returns it empty; most
// fields are filled in by Init and InitExtension.
type PalmModule struct {
	// vectorizer turns objects and input strings into vectors; assigned in
	// initVectorizer and used by VectorizeObject, VectorizeBatch and VectorizeInput.
	vectorizer text2vecbase.TextVectorizer

	// metaProvider backs MetaInfo; initVectorizer sets it to the PaLM client.
	metaProvider text2vecbase.MetaProvider

	// graphqlProvider is not assigned in this part of the file.
	// NOTE(review): presumably set by initNearText elsewhere in the package — confirm.
	graphqlProvider modulecapabilities.GraphQLArguments

	// searcher is not assigned in this part of the file.
	// NOTE(review): presumably set by initNearText elsewhere in the package — confirm.
	searcher modulecapabilities.Searcher

	// nearTextTransformer is the "nearText" transformer borrowed from another
	// module in InitExtension; it stays nil if no module provides one.
	nearTextTransformer modulecapabilities.TextTransform

	// logger is taken from the module init params in Init.
	logger logrus.FieldLogger

	// additionalPropertiesProvider backs AdditionalProperties; assigned in
	// initAdditionalPropertiesProvider.
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
}
49
50func (m *PalmModule) Name() string {
51return "text2vec-palm"
52}
53
54func (m *PalmModule) Type() modulecapabilities.ModuleType {
55return modulecapabilities.Text2Vec
56}
57
58func (m *PalmModule) Init(ctx context.Context,
59params moduletools.ModuleInitParams,
60) error {
61m.logger = params.GetLogger()
62
63if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
64return errors.Wrap(err, "init vectorizer")
65}
66
67if err := m.initAdditionalPropertiesProvider(); err != nil {
68return errors.Wrap(err, "init additional properties provider")
69}
70
71return nil
72}
73
74func (m *PalmModule) InitExtension(modules []modulecapabilities.Module) error {
75for _, module := range modules {
76if module.Name() == m.Name() {
77continue
78}
79if arg, ok := module.(modulecapabilities.TextTransformers); ok {
80if arg != nil && arg.TextTransformers() != nil {
81m.nearTextTransformer = arg.TextTransformers()["nearText"]
82}
83}
84}
85
86if err := m.initNearText(); err != nil {
87return errors.Wrap(err, "init graphql provider")
88}
89return nil
90}
91
92func (m *PalmModule) initVectorizer(ctx context.Context, timeout time.Duration,
93logger logrus.FieldLogger,
94) error {
95apiKey := os.Getenv("GOOGLE_APIKEY")
96if apiKey == "" {
97apiKey = os.Getenv("PALM_APIKEY")
98}
99client := clients.New(apiKey, timeout, logger)
100
101m.vectorizer = vectorizer.New(client)
102m.metaProvider = client
103
104return nil
105}
106
107func (m *PalmModule) initAdditionalPropertiesProvider() error {
108m.additionalPropertiesProvider = additional.NewText2VecProvider()
109return nil
110}
111
112func (m *PalmModule) RootHandler() http.Handler {
113// TODO: remove once this is a capability interface
114return nil
115}
116
117func (m *PalmModule) VectorizeObject(ctx context.Context,
118obj *models.Object, cfg moduletools.ClassConfig,
119) ([]float32, models.AdditionalProperties, error) {
120return m.vectorizer.Object(ctx, obj, cfg)
121}
122
123func (m *PalmModule) VectorizeBatch(ctx context.Context, objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig) ([][]float32, []models.AdditionalProperties, map[int]error) {
124return batch.VectorizeBatch(ctx, objs, skipObject, cfg, m.logger, m.vectorizer.Object)
125}
126
127func (m *PalmModule) MetaInfo() (map[string]interface{}, error) {
128return m.metaProvider.MetaInfo()
129}
130
131func (m *PalmModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
132return m.additionalPropertiesProvider.AdditionalProperties()
133}
134
135func (m *PalmModule) VectorizeInput(ctx context.Context,
136input string, cfg moduletools.ClassConfig,
137) ([]float32, error) {
138return m.vectorizer.Texts(ctx, []string{input}, cfg)
139}
140
141func (m *PalmModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
142return true, nil, nil
143}
144
145// verify we implement the modules.Module interface
146var (
147_ = modulecapabilities.Module(New())
148_ = modulecapabilities.Vectorizer(New())
149_ = modulecapabilities.MetaProvider(New())
150_ = modulecapabilities.Searcher(New())
151_ = modulecapabilities.GraphQLArguments(New())
152)
153