weaviate

Форк
0
152 строки · 4.4 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package modpalm
13

14
import (
15
	"context"
16
	"net/http"
17
	"os"
18
	"time"
19

20
	"github.com/weaviate/weaviate/usecases/modulecomponents/text2vecbase"
21

22
	"github.com/weaviate/weaviate/usecases/modulecomponents/batch"
23

24
	"github.com/pkg/errors"
25
	"github.com/sirupsen/logrus"
26
	"github.com/weaviate/weaviate/entities/models"
27
	"github.com/weaviate/weaviate/entities/modulecapabilities"
28
	"github.com/weaviate/weaviate/entities/moduletools"
29
	"github.com/weaviate/weaviate/modules/text2vec-palm/clients"
30
	"github.com/weaviate/weaviate/modules/text2vec-palm/vectorizer"
31
	"github.com/weaviate/weaviate/usecases/modulecomponents/additional"
32
)
33

34
// Name is this module's identifier, "text2vec-palm". It is the same string
// reported by (*PalmModule).Name.
const Name = "text2vec-palm"
35

36
// New returns an empty PalmModule. Its fields are filled in later by Init
// and InitExtension.
func New() *PalmModule {
	module := new(PalmModule)
	return module
}
39

40
type PalmModule struct {
41
	vectorizer                   text2vecbase.TextVectorizer
42
	metaProvider                 text2vecbase.MetaProvider
43
	graphqlProvider              modulecapabilities.GraphQLArguments
44
	searcher                     modulecapabilities.Searcher
45
	nearTextTransformer          modulecapabilities.TextTransform
46
	logger                       logrus.FieldLogger
47
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
48
}
49

50
// Name returns the module's identifier. It returns the package-level Name
// constant rather than repeating the literal, so the two cannot drift apart.
func (m *PalmModule) Name() string {
	return Name
}
53

54
// Type classifies this module as a text-to-vector (Text2Vec) module.
func (m *PalmModule) Type() modulecapabilities.ModuleType { return modulecapabilities.Text2Vec }
57

58
func (m *PalmModule) Init(ctx context.Context,
59
	params moduletools.ModuleInitParams,
60
) error {
61
	m.logger = params.GetLogger()
62

63
	if err := m.initVectorizer(ctx, params.GetConfig().ModuleHttpClientTimeout, m.logger); err != nil {
64
		return errors.Wrap(err, "init vectorizer")
65
	}
66

67
	if err := m.initAdditionalPropertiesProvider(); err != nil {
68
		return errors.Wrap(err, "init additional properties provider")
69
	}
70

71
	return nil
72
}
73

74
// InitExtension receives the full list of enabled modules. From every other
// module that implements modulecapabilities.TextTransformers it adopts the
// transformer registered under "nearText". It then calls initNearText
// (defined elsewhere in the package); a failure there is returned wrapped as
// "init graphql provider".
func (m *PalmModule) InitExtension(modules []modulecapabilities.Module) error {
	for _, module := range modules {
		// Only other modules can contribute transformers.
		if module.Name() == m.Name() {
			continue
		}
		provider, ok := module.(modulecapabilities.TextTransformers)
		if !ok {
			continue
		}
		// Ask the provider for its transformers once. Adopt one only when it
		// is really registered under "nearText", so a module without it cannot
		// clobber a transformer found on an earlier module with nil. Indexing
		// a nil map is safe and simply yields found == false.
		if transformer, found := provider.TextTransformers()["nearText"]; found && transformer != nil {
			m.nearTextTransformer = transformer
		}
	}

	if err := m.initNearText(); err != nil {
		return errors.Wrap(err, "init graphql provider")
	}
	return nil
}
91

92
func (m *PalmModule) initVectorizer(ctx context.Context, timeout time.Duration,
93
	logger logrus.FieldLogger,
94
) error {
95
	apiKey := os.Getenv("GOOGLE_APIKEY")
96
	if apiKey == "" {
97
		apiKey = os.Getenv("PALM_APIKEY")
98
	}
99
	client := clients.New(apiKey, timeout, logger)
100

101
	m.vectorizer = vectorizer.New(client)
102
	m.metaProvider = client
103

104
	return nil
105
}
106

107
// initAdditionalPropertiesProvider installs the shared Text2Vec
// additional-properties provider. It always returns nil.
func (m *PalmModule) initAdditionalPropertiesProvider() error {
	provider := additional.NewText2VecProvider()
	m.additionalPropertiesProvider = provider
	return nil
}
111

112
// RootHandler returns nil: this module exposes no HTTP handler.
//
// TODO: remove once this is a capability interface
func (m *PalmModule) RootHandler() http.Handler { return nil }
116

117
func (m *PalmModule) VectorizeObject(ctx context.Context,
118
	obj *models.Object, cfg moduletools.ClassConfig,
119
) ([]float32, models.AdditionalProperties, error) {
120
	return m.vectorizer.Object(ctx, obj, cfg)
121
}
122

123
// VectorizeBatch hands the whole batch to batch.VectorizeBatch, passing this
// module's single-object vectorizer as the per-object function. Errors come
// back as a map keyed by int (presumably the object's index in objs —
// confirm against the batch package).
func (m *PalmModule) VectorizeBatch(ctx context.Context,
	objs []*models.Object, skipObject []bool, cfg moduletools.ClassConfig,
) ([][]float32, []models.AdditionalProperties, map[int]error) {
	vectorizeOne := m.vectorizer.Object
	return batch.VectorizeBatch(ctx, objs, skipObject, cfg, m.logger, vectorizeOne)
}
126

127
// MetaInfo forwards to the meta provider (the client set up in
// initVectorizer).
func (m *PalmModule) MetaInfo() (map[string]interface{}, error) {
	info, err := m.metaProvider.MetaInfo()
	return info, err
}
130

131
// AdditionalProperties returns the additional properties exposed by the
// Text2Vec provider installed during Init.
func (m *PalmModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
	provider := m.additionalPropertiesProvider
	return provider.AdditionalProperties()
}
134

135
func (m *PalmModule) VectorizeInput(ctx context.Context,
136
	input string, cfg moduletools.ClassConfig,
137
) ([]float32, error) {
138
	return m.vectorizer.Texts(ctx, []string{input}, cfg)
139
}
140

141
// VectorizableProperties always reports (true, nil, nil); cfg is not
// consulted.
func (m *PalmModule) VectorizableProperties(cfg moduletools.ClassConfig) (bool, []string, error) {
	var props []string
	return true, props, nil
}
144

145
// verify we implement the modules.Module interface
146
var (
147
	_ = modulecapabilities.Module(New())
148
	_ = modulecapabilities.Vectorizer(New())
149
	_ = modulecapabilities.MetaProvider(New())
150
	_ = modulecapabilities.Searcher(New())
151
	_ = modulecapabilities.GraphQLArguments(New())
152
)
153

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.