weaviate
91 строка · 2.6 Кб
1// _ _
2// __ _____ __ ___ ___ __ _| |_ ___
3// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
4// \ V V / __/ (_| |\ V /| | (_| | || __/
5// \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
6//
7// Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
8//
9// CONTACT: hello@weaviate.io
10//
11
12package modrerankervoyageai
13
14import (
15"context"
16"net/http"
17"os"
18"time"
19
20"github.com/pkg/errors"
21"github.com/sirupsen/logrus"
22"github.com/weaviate/weaviate/entities/modulecapabilities"
23"github.com/weaviate/weaviate/entities/moduletools"
24"github.com/weaviate/weaviate/modules/reranker-voyageai/clients"
25rerankeradditional "github.com/weaviate/weaviate/usecases/modulecomponents/additional"
26"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
27)
28
// Name is the identifier under which this module is registered; it is the
// value returned by (*ReRankerVoyageAIModule).Name.
const Name = "reranker-voyageai"
30
31func New() *ReRankerVoyageAIModule {
32return &ReRankerVoyageAIModule{}
33}
34
35type ReRankerVoyageAIModule struct {
36reranker ReRankerVoyageAIClient
37additionalPropertiesProvider modulecapabilities.AdditionalProperties
38}
39
40type ReRankerVoyageAIClient interface {
41Rank(ctx context.Context, query string, documents []string, cfg moduletools.ClassConfig) (*ent.RankResult, error)
42MetaInfo() (map[string]interface{}, error)
43}
44
45func (m *ReRankerVoyageAIModule) Name() string {
46return Name
47}
48
49func (m *ReRankerVoyageAIModule) Type() modulecapabilities.ModuleType {
50return modulecapabilities.Text2TextReranker
51}
52
53func (m *ReRankerVoyageAIModule) Init(ctx context.Context,
54params moduletools.ModuleInitParams,
55) error {
56if err := m.initAdditional(ctx, params.GetConfig().ModuleHttpClientTimeout, params.GetLogger()); err != nil {
57return errors.Wrap(err, "init cross encoder")
58}
59
60return nil
61}
62
63func (m *ReRankerVoyageAIModule) initAdditional(ctx context.Context, timeout time.Duration,
64logger logrus.FieldLogger,
65) error {
66apiKey := os.Getenv("VOYAGEAI_APIKEY")
67client := clients.New(apiKey, timeout, logger)
68m.reranker = client
69m.additionalPropertiesProvider = rerankeradditional.NewRankerProvider(m.reranker)
70return nil
71}
72
73func (m *ReRankerVoyageAIModule) MetaInfo() (map[string]interface{}, error) {
74return m.reranker.MetaInfo()
75}
76
77func (m *ReRankerVoyageAIModule) RootHandler() http.Handler {
78// TODO: remove once this is a capability interface
79return nil
80}
81
82func (m *ReRankerVoyageAIModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
83return m.additionalPropertiesProvider.AdditionalProperties()
84}
85
86// verify we implement the modules.Module interface
87var (
88_ = modulecapabilities.Module(New())
89_ = modulecapabilities.AdditionalProperties(New())
90_ = modulecapabilities.MetaProvider(New())
91)
92