llama-index

Форк
0
91 строка · 3.0 Кб
1
from typing import Any, Dict, List, Optional, Sequence

from llama_index.legacy.core.base_selector import (
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.legacy.embeddings.base import BaseEmbedding
from llama_index.legacy.embeddings.utils import resolve_embed_model
from llama_index.legacy.indices.query.embedding_utils import get_top_k_embeddings
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.schema import QueryBundle
from llama_index.legacy.tools.types import ToolMetadata
14

15

16
class EmbeddingSingleSelector(BaseSelector):
    """Embedding selector.

    Embedding selector that chooses one out of many options. The query string
    and each choice's ``description`` are embedded with the configured model,
    and the single choice ranked most similar by ``get_top_k_embeddings`` is
    returned.

    Args:
        embed_model (BaseEmbedding): An embedding model.
    """

    def __init__(
        self,
        embed_model: BaseEmbedding,
    ) -> None:
        self._embed_model = embed_model

    @classmethod
    def from_defaults(
        cls,
        embed_model: Optional[BaseEmbedding] = None,
    ) -> "EmbeddingSingleSelector":
        """Create a selector, falling back to the default embedding model.

        Args:
            embed_model: Embedding model to use. When ``None``, the model
                returned by ``resolve_embed_model("default")`` is used.

        Returns:
            A new ``EmbeddingSingleSelector``.
        """
        if embed_model is None:
            embed_model = resolve_embed_model("default")
        return cls(embed_model)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts (this selector uses none, so the dict is empty)."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (no-op: this selector has no prompts)."""

    @staticmethod
    def _check_choices(choices: Sequence[ToolMetadata]) -> None:
        """Raise ``ValueError`` if there is nothing to select from."""
        # Checked before any embedding call: an empty ``choices`` would
        # otherwise waste a query-embedding request and then fail with an
        # opaque IndexError when reading the (empty) top-k result.
        if not choices:
            raise ValueError("EmbeddingSingleSelector requires at least one choice.")

    @staticmethod
    def _build_result(
        choices: Sequence[ToolMetadata],
        query_embedding: List[float],
        text_embeddings: List[List[float]],
    ) -> SelectorResult:
        """Pick the single choice most similar to the query and wrap it.

        Args:
            choices: The candidate choices, aligned index-for-index with
                ``text_embeddings``.
            query_embedding: Embedding of the query string.
            text_embeddings: One embedding per choice description.

        Returns:
            A ``SelectorResult`` holding exactly one ``SingleSelection``.
        """
        top_similarities, top_ids = get_top_k_embeddings(
            query_embedding,
            text_embeddings,
            similarity_top_k=1,
            # Positional ids let the winning id index straight back into
            # ``choices``.
            embedding_ids=list(range(len(choices))),
        )
        top_id = top_ids[0]
        reason = (
            f"Top similarity match: {top_similarities[0]:.2f}, "
            f"{choices[top_id].name}"
        )
        return SelectorResult(selections=[SingleSelection(index=top_id, reason=reason)])

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Select the choice whose description best matches the query.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        self._check_choices(choices)
        query_embedding = self._embed_model.get_query_embedding(query.query_str)
        text_embeddings = [
            self._embed_model.get_text_embedding(choice.description)
            for choice in choices
        ]
        return self._build_result(choices, query_embedding, text_embeddings)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Async variant of ``_select``; embeddings are awaited one at a time.

        Raises:
            ValueError: If ``choices`` is empty.
        """
        self._check_choices(choices)
        query_embedding = await self._embed_model.aget_query_embedding(query.query_str)
        text_embeddings = [
            await self._embed_model.aget_text_embedding(choice.description)
            for choice in choices
        ]
        return self._build_result(choices, query_embedding, text_embeddings)
92

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.