haystack
177 lines · 7.0 KB
1from typing import Dict, Any
2
3import pytest
4
5from haystack import Pipeline, DeserializationError
6from haystack.testing.factory import document_store_class
7from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
8from haystack.dataclasses import Document
9from haystack.document_stores.in_memory import InMemoryDocumentStore
10
11
@pytest.fixture()
def mock_docs():
    """Provide five short documents, one per programming language, for retrieval tests.

    The documents share the same sentence and differ only in the leading language name.
    """
    sentences = (
        "Javascript is a popular programming language",
        "Java is a popular programming language",
        "Python is a popular programming language",
        "Ruby is a popular programming language",
        "PHP is a popular programming language",
    )
    # Build a new list of Document objects each time the fixture is invoked.
    return [Document(content=sentence) for sentence in sentences]
21
22
class TestMemoryBM25Retriever:
    """Tests for ``InMemoryBM25Retriever``: construction, (de)serialization and retrieval."""

    def test_init_default(self):
        """Only a document store is given, so every other parameter keeps its default."""
        retriever = InMemoryBM25Retriever(InMemoryDocumentStore())
        assert retriever.filters is None
        assert retriever.top_k == 10
        assert retriever.scale_score is False

    def test_init_with_parameters(self):
        """Explicit ``filters``, ``top_k`` and ``scale_score`` are stored unchanged."""
        retriever = InMemoryBM25Retriever(
            InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
        )
        assert retriever.filters == {"name": "test.txt"}
        assert retriever.top_k == 5
        # Identity check (mirrors the ``is False`` check in test_init_default) so that a
        # merely truthy value would not pass.
        assert retriever.scale_score is True

    def test_init_with_invalid_top_k_parameter(self):
        """A negative ``top_k`` is rejected with ``ValueError`` at construction time."""
        with pytest.raises(ValueError):
            InMemoryBM25Retriever(InMemoryDocumentStore(), top_k=-2)

    def test_to_dict(self):
        """``to_dict`` nests the store's own serialization next to the default parameters."""
        MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        document_store = MyFakeStore()
        # Stub the store serializer so the expected payload does not depend on
        # InMemoryDocumentStore's real serialization format.
        document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
        component = InMemoryBM25Retriever(document_store=document_store)

        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": {"type": "MyFakeStore", "init_parameters": {}},
                "filters": None,
                "top_k": 10,
                "scale_score": False,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Non-default init parameters appear verbatim in the ``to_dict`` output."""
        ds = InMemoryDocumentStore()
        serialized_ds = ds.to_dict()

        # Hand the retriever the very store whose serialization is the expected value.
        # Comparing against a *different* instance would silently assume that every
        # InMemoryDocumentStore instance serializes identically.
        component = InMemoryBM25Retriever(
            document_store=ds, filters={"name": "test.txt"}, top_k=5, scale_score=True
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": serialized_ds,
                "filters": {"name": "test.txt"},
                "top_k": 5,
                "scale_score": True,
            },
        }

    def test_from_dict(self):
        """``from_dict`` rebuilds the retriever and its store.

        ``scale_score`` is absent from the payload, so it falls back to ``False``.
        """
        data = {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "filters": {"name": "test.txt"},
                "top_k": 5,
            },
        }
        component = InMemoryBM25Retriever.from_dict(data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.filters == {"name": "test.txt"}
        assert component.top_k == 5
        assert component.scale_score is False

    def test_from_dict_without_docstore(self):
        """A payload with no ``document_store`` entry raises ``DeserializationError``."""
        data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            InMemoryBM25Retriever.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        """A ``document_store`` entry lacking ``type`` raises ``DeserializationError``."""
        data = {"type": "InMemoryBM25Retriever", "init_parameters": {"document_store": {"init_parameters": {}}}}
        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
            InMemoryBM25Retriever.from_dict(data)

    def test_from_dict_nonexisting_docstore(self):
        """A document store ``type`` that does not exist raises ``DeserializationError``."""
        data = {
            "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            InMemoryBM25Retriever.from_dict(data)

    def test_retriever_valid_run(self, mock_docs):
        """With ``top_k=5`` all five documents come back, and the "PHP" document ranks first."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)

        retriever = InMemoryBM25Retriever(ds, top_k=5)
        result = retriever.run(query="PHP")

        assert "documents" in result
        assert len(result["documents"]) == 5
        assert result["documents"][0].content == "PHP is a popular programming language"

    def test_invalid_run_wrong_store_type(self):
        """A store that is not an ``InMemoryDocumentStore`` is rejected with ``ValueError``.

        Despite the test name, this exercises construction, not ``run``.
        """
        SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
        with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
            InMemoryBM25Retriever(SomeOtherDocumentStore())

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result",
        [
            ("Javascript", "Javascript is a popular programming language"),
            ("Java", "Java is a popular programming language"),
        ],
    )
    def test_run_with_pipeline(self, mock_docs, query: str, query_result: str):
        """Inside a ``Pipeline``, the document matching the query is returned first."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)
        retriever = InMemoryBM25Retriever(ds)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert results_docs[0].content == query_result

    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query, query_result, top_k",
        [
            ("Javascript", "Javascript is a popular programming language", 1),
            ("Java", "Java is a popular programming language", 2),
            ("Ruby", "Ruby is a popular programming language", 3),
        ],
    )
    def test_run_with_pipeline_and_top_k(self, mock_docs, query: str, query_result: str, top_k: int):
        """A ``top_k`` passed at run time caps the result count; the best match is still first."""
        ds = InMemoryDocumentStore()
        ds.write_documents(mock_docs)
        retriever = InMemoryBM25Retriever(ds)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        result: Dict[str, Any] = pipeline.run(data={"retriever": {"query": query, "top_k": top_k}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert len(results_docs) == top_k
        assert results_docs[0].content == query_result
178