llama-index

Форк
0
103 строки · 3.8 Кб
1
import logging
2
from typing import Callable, List, Optional, cast
3

4
from nltk.stem import PorterStemmer
5

6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.constants import DEFAULT_SIMILARITY_TOP_K
8
from llama_index.legacy.core.base_retriever import BaseRetriever
9
from llama_index.legacy.indices.keyword_table.utils import simple_extract_keywords
10
from llama_index.legacy.indices.vector_store.base import VectorStoreIndex
11
from llama_index.legacy.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
12
from llama_index.legacy.storage.docstore.types import BaseDocumentStore
13

14
logger = logging.getLogger(__name__)
15

16

17
def tokenize_remove_stopwords(text: str) -> List[str]:
18
    # lowercase and stem words
19
    text = text.lower()
20
    stemmer = PorterStemmer()
21
    words = list(simple_extract_keywords(text))
22
    return [stemmer.stem(word) for word in words]
23

24

25
class BM25Retriever(BaseRetriever):
26
    def __init__(
27
        self,
28
        nodes: List[BaseNode],
29
        tokenizer: Optional[Callable[[str], List[str]]],
30
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
31
        callback_manager: Optional[CallbackManager] = None,
32
        objects: Optional[List[IndexNode]] = None,
33
        object_map: Optional[dict] = None,
34
        verbose: bool = False,
35
    ) -> None:
36
        try:
37
            from rank_bm25 import BM25Okapi
38
        except ImportError:
39
            raise ImportError("Please install rank_bm25: pip install rank-bm25")
40

41
        self._nodes = nodes
42
        self._tokenizer = tokenizer or tokenize_remove_stopwords
43
        self._similarity_top_k = similarity_top_k
44
        self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
45
        self.bm25 = BM25Okapi(self._corpus)
46
        super().__init__(
47
            callback_manager=callback_manager,
48
            object_map=object_map,
49
            objects=objects,
50
            verbose=verbose,
51
        )
52

53
    @classmethod
54
    def from_defaults(
55
        cls,
56
        index: Optional[VectorStoreIndex] = None,
57
        nodes: Optional[List[BaseNode]] = None,
58
        docstore: Optional[BaseDocumentStore] = None,
59
        tokenizer: Optional[Callable[[str], List[str]]] = None,
60
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
61
        verbose: bool = False,
62
    ) -> "BM25Retriever":
63
        # ensure only one of index, nodes, or docstore is passed
64
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
65
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")
66

67
        if index is not None:
68
            docstore = index.docstore
69

70
        if docstore is not None:
71
            nodes = cast(List[BaseNode], list(docstore.docs.values()))
72

73
        assert (
74
            nodes is not None
75
        ), "Please pass exactly one of index, nodes, or docstore."
76

77
        tokenizer = tokenizer or tokenize_remove_stopwords
78
        return cls(
79
            nodes=nodes,
80
            tokenizer=tokenizer,
81
            similarity_top_k=similarity_top_k,
82
            verbose=verbose,
83
        )
84

85
    def _get_scored_nodes(self, query: str) -> List[NodeWithScore]:
86
        tokenized_query = self._tokenizer(query)
87
        doc_scores = self.bm25.get_scores(tokenized_query)
88

89
        nodes: List[NodeWithScore] = []
90
        for i, node in enumerate(self._nodes):
91
            nodes.append(NodeWithScore(node=node, score=doc_scores[i]))
92

93
        return nodes
94

95
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
96
        if query_bundle.custom_embedding_strs or query_bundle.embedding:
97
            logger.warning("BM25Retriever does not support embeddings, skipping...")
98

99
        scored_nodes = self._get_scored_nodes(query_bundle.query_str)
100

101
        # Sort and get top_k nodes, score range => 0..1, closer to 1 means more relevant
102
        nodes = sorted(scored_nodes, key=lambda x: x.score or 0.0, reverse=True)
103
        return nodes[: self._similarity_top_k]
104

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.