haystack
160 строк · 6.8 Кб
1from typing import Dict, Any
2
3import pytest
4import numpy as np
5
6from haystack import Pipeline, DeserializationError
7from haystack.testing.factory import document_store_class
8from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
9from haystack.dataclasses import Document
10from haystack.document_stores.in_memory import InMemoryDocumentStore
11
12
class TestMemoryEmbeddingRetriever:
    """Tests for ``InMemoryEmbeddingRetriever``.

    Covers constructor defaults and validation, ``to_dict``/``from_dict``
    (including malformed serialization data), and retrieval both when the
    component is run directly and when it runs inside a ``Pipeline``.
    """

    # Fully-qualified type string the retriever is expected to serialize under.
    _RETRIEVER_TYPE = "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever"

    @staticmethod
    def _fake_store():
        """Return an instance of a throwaway ``InMemoryDocumentStore`` subclass with a stubbed ``to_dict``."""
        fake_store_cls = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        store = fake_store_cls()
        # Stub serialization so the expected retriever output is independent of the store's real parameters.
        store.to_dict = lambda: {"type": "test_module.MyFakeStore", "init_parameters": {}}
        return store

    @staticmethod
    def _populated_store():
        """Return a cosine-similarity ``InMemoryDocumentStore`` holding three embedded documents."""
        store = InMemoryDocumentStore(embedding_similarity_function="cosine")
        store.write_documents(
            [
                Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
                Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
                Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
            ]
        )
        return store

    @classmethod
    def _serialized(cls, init_parameters):
        """Wrap ``init_parameters`` in the retriever's ``{"type", "init_parameters"}`` envelope."""
        return {"type": cls._RETRIEVER_TYPE, "init_parameters": init_parameters}

    def test_init_default(self):
        """Default construction: no filters, ``top_k`` of 10, score scaling disabled."""
        retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())

        assert retriever.filters is None
        assert retriever.top_k == 10
        assert retriever.scale_score is False

    def test_init_with_parameters(self):
        """Explicit constructor arguments end up on the instance."""
        retriever = InMemoryEmbeddingRetriever(
            InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
        )

        assert retriever.filters == {"name": "test.txt"}
        assert retriever.top_k == 5
        assert retriever.scale_score

    def test_init_with_invalid_top_k_parameter(self):
        """A negative ``top_k`` is rejected with ``ValueError``."""
        with pytest.raises(ValueError):
            InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2)

    def test_to_dict(self):
        """Serialization with default parameters embeds the store's own ``to_dict`` output."""
        retriever = InMemoryEmbeddingRetriever(document_store=self._fake_store())

        expected_params = {
            "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
            "filters": None,
            "top_k": 10,
            "scale_score": False,
            "return_embedding": False,
        }
        assert retriever.to_dict() == self._serialized(expected_params)

    def test_to_dict_with_custom_init_parameters(self):
        """Serialization reflects every non-default constructor argument."""
        retriever = InMemoryEmbeddingRetriever(
            document_store=self._fake_store(),
            filters={"name": "test.txt"},
            top_k=5,
            scale_score=True,
            return_embedding=True,
        )

        expected_params = {
            "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
            "filters": {"name": "test.txt"},
            "top_k": 5,
            "scale_score": True,
            "return_embedding": True,
        }
        assert retriever.to_dict() == self._serialized(expected_params)

    def test_from_dict(self):
        """Deserialization rebuilds the document store and restores the serialized parameters."""
        store_data = {
            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
            "init_parameters": {},
        }
        data = self._serialized({"document_store": store_data, "filters": {"name": "test.txt"}, "top_k": 5})

        retriever = InMemoryEmbeddingRetriever.from_dict(data)

        assert isinstance(retriever.document_store, InMemoryDocumentStore)
        assert retriever.filters == {"name": "test.txt"}
        assert retriever.top_k == 5
        # scale_score is absent from the data and is expected to come back as False.
        assert retriever.scale_score is False

    def test_from_dict_without_docstore(self):
        """Missing ``document_store`` key raises a descriptive ``DeserializationError``."""
        data = self._serialized({})
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        """A ``document_store`` entry with no ``type`` raises ``DeserializationError``."""
        data = self._serialized({"document_store": {"init_parameters": {}}})
        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_from_dict_nonexisting_docstore(self):
        """An unresolvable ``document_store`` type raises ``DeserializationError``."""
        data = self._serialized({"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}})
        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_valid_run(self):
        """Direct ``run`` returns ``top_k`` documents, best cosine match first, with embeddings."""
        top_k = 3
        retriever = InMemoryEmbeddingRetriever(self._populated_store(), top_k=top_k)

        output = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)

        assert "documents" in output
        retrieved = output["documents"]
        assert len(retrieved) == top_k
        # The all-ones embedding is parallel to the query, so it has the highest cosine similarity.
        assert np.array_equal(retrieved[0].embedding, [1.0, 1.0, 1.0, 1.0])

    def test_invalid_run_wrong_store_type(self):
        """Construction (not ``run``) rejects a store that is not an ``InMemoryDocumentStore``."""
        other_store_cls = document_store_class("SomeOtherDocumentStore")
        with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
            InMemoryEmbeddingRetriever(other_store_cls())

    @pytest.mark.integration
    def test_run_with_pipeline(self):
        """Running inside a ``Pipeline`` yields the same ranking as a direct ``run``."""
        top_k = 2
        retriever = InMemoryEmbeddingRetriever(self._populated_store(), top_k=top_k)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        retriever_inputs = {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}
        result: Dict[str, Any] = pipeline.run(data={"retriever": retriever_inputs})

        assert result
        assert "retriever" in result
        retrieved = result["retriever"]["documents"]
        assert retrieved
        assert len(retrieved) == top_k
        assert np.array_equal(retrieved[0].embedding, [1.0, 1.0, 1.0, 1.0])
161