haystack

Форк
0
/
test_in_memory_bm25_retriever.py 
177 строк · 7.0 Кб
1
from typing import Dict, Any
2

3
import pytest
4

5
from haystack import Pipeline, DeserializationError
6
from haystack.testing.factory import document_store_class
7
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
8
from haystack.dataclasses import Document
9
from haystack.document_stores.in_memory import InMemoryDocumentStore
10

11

12
@pytest.fixture()
def mock_docs():
    """Provide five documents, each naming a different programming language."""
    contents = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    return [Document(content=text) for text in contents]
21

22

23
class TestMemoryBM25Retriever:
    """Tests for ``InMemoryBM25Retriever``: construction, (de)serialization and retrieval."""

    def test_init_default(self):
        """A retriever built with only a document store exposes the default settings."""
        retriever = InMemoryBM25Retriever(InMemoryDocumentStore())
        assert retriever.filters is None
        assert retriever.top_k == 10
        assert retriever.scale_score is False

    def test_init_with_parameters(self):
        """Explicit ``filters``, ``top_k`` and ``scale_score`` are stored on the retriever."""
        retriever = InMemoryBM25Retriever(
            InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
        )
        assert retriever.filters == {"name": "test.txt"}
        assert retriever.top_k == 5
        # Identity check (not mere truthiness), mirroring the ``is False`` checks elsewhere.
        assert retriever.scale_score is True

    def test_init_with_invalid_top_k_parameter(self):
        """A negative ``top_k`` is rejected at construction time with ``ValueError``."""
        with pytest.raises(ValueError):
            InMemoryBM25Retriever(InMemoryDocumentStore(), top_k=-2)

    def test_to_dict(self):
        """``to_dict`` embeds whatever the document store's own ``to_dict`` returns."""
        MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        document_store = MyFakeStore()
        # Stub the store's serialization so the expected payload is fully known.
        document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
        component = InMemoryBM25Retriever(document_store=document_store)

        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": {"type": "MyFakeStore", "init_parameters": {}},
                "filters": None,
                "top_k": 10,
                "scale_score": False,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Non-default init parameters are written to the serialized form."""
        ds = InMemoryDocumentStore()
        serialized_ds = ds.to_dict()

        # Hand the retriever the very store that was serialized above, so the
        # expectation does not rely on two separate stores serializing identically.
        component = InMemoryBM25Retriever(
            document_store=ds, filters={"name": "test.txt"}, top_k=5, scale_score=True
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": serialized_ds,
                "filters": {"name": "test.txt"},
                "top_k": 5,
                "scale_score": True,
            },
        }

    def test_from_dict(self):
        """``from_dict`` rebuilds the retriever; ``scale_score`` is omitted and must come back ``False``."""
        data = {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "filters": {"name": "test.txt"},
                "top_k": 5,
            },
        }
        component = InMemoryBM25Retriever.from_dict(data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.filters == {"name": "test.txt"}
        assert component.top_k == 5
        assert component.scale_score is False

    def test_from_dict_without_docstore(self):
        """Serialized data lacking ``document_store`` raises ``DeserializationError``."""
        data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            InMemoryBM25Retriever.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        """A ``document_store`` entry lacking ``type`` raises ``DeserializationError``."""
        data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
            InMemoryBM25Retriever.from_dict(data)

    def test_from_dict_nonexisting_docstore(self):
        """A ``document_store`` of type ``Nonexisting.Docstore`` raises ``DeserializationError``."""
        data = {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            InMemoryBM25Retriever.from_dict(data)

    def test_retriever_valid_run(self, mock_docs):
        """``run`` returns ``top_k`` documents with the PHP document ranked first for "PHP"."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)

        retriever = InMemoryBM25Retriever(ds, top_k=5)
        result = retriever.run(query="PHP")

        assert "documents" in result
        assert len(result["documents"]) == 5
        assert result["documents"][0].content == "PHP is a popular programming language"

    def test_invalid_run_wrong_store_type(self):
        """Constructing the retriever with a store that is not an ``InMemoryDocumentStore`` raises ``ValueError``."""
        SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
        with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
            InMemoryBM25Retriever(SomeOtherDocumentStore())

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result",
        [
            ("Javascript", "Javascript is a popular programming language"),
            ("Java", "Java is a popular programming language"),
        ],
    )
    def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
        """Run inside a ``Pipeline``: the top document is the one naming the queried language."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)
        retriever = InMemoryBM25Retriever(ds)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert results_docs[0].content == query_result

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result, top_k",
        [
            ("Javascript", "Javascript is a popular programming language", 1),
            ("Java", "Java is a popular programming language", 2),
            ("Ruby", "Ruby is a popular programming language", 3),
        ],
    )
    def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
        """Run inside a ``Pipeline`` with a run-time ``top_k``: exactly ``top_k`` documents come back."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)
        retriever = InMemoryBM25Retriever(ds)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        # ``top_k`` is passed at run time here rather than at construction.
        result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query, "top_k": top_k}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert len(results_docs) == top_k
        assert results_docs[0].content == query_result
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.