weaviate

Форк
0
91 строка · 2.6 Кб
1
//                           _       _
2
// __      _____  __ ___   ___  __ _| |_ ___
3
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
5
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6
//
7
//  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8
//
9
//  CONTACT: hello@weaviate.io
10
//
11

12
package modrerankervoyageai
13

14
import (
15
	"context"
16
	"net/http"
17
	"os"
18
	"time"
19

20
	"github.com/pkg/errors"
21
	"github.com/sirupsen/logrus"
22
	"github.com/weaviate/weaviate/entities/modulecapabilities"
23
	"github.com/weaviate/weaviate/entities/moduletools"
24
	"github.com/weaviate/weaviate/modules/reranker-voyageai/clients"
25
	rerankeradditional "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
26
	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
27
)
28

29
const Name = "reranker-voyageai"
30

31
func New() *ReRankerVoyageAIModule {
32
	return &ReRankerVoyageAIModule{}
33
}
34

35
type ReRankerVoyageAIModule struct {
36
	reranker                     ReRankerVoyageAIClient
37
	additionalPropertiesProvider modulecapabilities.AdditionalProperties
38
}
39

40
type ReRankerVoyageAIClient interface {
41
	Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (*ent.RankResult, error)
42
	MetaInfo() (map[string]interface{}, error)
43
}
44

45
func (m *ReRankerVoyageAIModule) Name() string {
46
	return Name
47
}
48

49
func (m *ReRankerVoyageAIModule) Type() modulecapabilities.ModuleType {
50
	return modulecapabilities.Text2TextReranker
51
}
52

53
func (m *ReRankerVoyageAIModule) Init(ctx context.Context,
54
	params moduletools.ModuleInitParams,
55
) error {
56
	if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
57
		return errors.Wrap(err, "init cross encoder")
58
	}
59

60
	return nil
61
}
62

63
func (m *ReRankerVoyageAIModule) initAdditional(ctx context.Context, timeout time.Duration,
64
	logger logrus.FieldLogger,
65
) error {
66
	apiKey := os.Getenv("VOYAGEAI_APIKEY")
67
	client := clients.New(apiKey, timeout, logger)
68
	m.reranker = client
69
	m.additionalPropertiesProvider = rerankeradditional.NewRankerProvider(m.reranker)
70
	return nil
71
}
72

73
func (m *ReRankerVoyageAIModule) MetaInfo() (map[string]interface{}, error) {
74
	return m.reranker.MetaInfo()
75
}
76

77
func (m *ReRankerVoyageAIModule) RootHandler() http.Handler {
78
	// TODO: remove once this is a capability interface
79
	return nil
80
}
81

82
func (m *ReRankerVoyageAIModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
83
	return m.additionalPropertiesProvider.AdditionalProperties()
84
}
85

86
// verify we implement the modules.Module interface
87
var (
88
	_ = modulecapabilities.Module(New())
89
	_ = modulecapabilities.AdditionalProperties(New())
90
	_ = modulecapabilities.MetaProvider(New())
91
)
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.