haystack

Форк
0
/
test_in_memory_embedding_retriever.py 
160 строк · 6.8 Кб
1
from typing import Dict, Any
2

3
import pytest
4
import numpy as np
5

6
from haystack import Pipeline, DeserializationError
7
from haystack.testing.factory import document_store_class
8
from haystack.components.retrievers.in_memory.embedding_retriever import InMemoryEmbeddingRetriever
9
from haystack.dataclasses import Document
10
from haystack.document_stores.in_memory import InMemoryDocumentStore
11

12

13
class TestMemoryEmbeddingRetriever:
    """Tests for ``InMemoryEmbeddingRetriever``: init, (de)serialization and retrieval."""

    # Fully qualified type name the retriever serializes under and is deserialized from.
    _RETRIEVER_TYPE = "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever"

    @staticmethod
    def _fake_document_store():
        """Return an ``InMemoryDocumentStore`` subclass instance with a stubbed ``to_dict``.

        Stubbing ``to_dict`` keeps the serialization assertions independent of the real
        document store's own init parameters.
        """
        MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        document_store = MyFakeStore()
        document_store.to_dict = lambda: {"type": "test_module.MyFakeStore", "init_parameters": {}}
        return document_store

    @staticmethod
    def _sample_documents():
        """Return three documents with 4-dimensional embeddings for the retrieval tests."""
        return [
            Document(content="my document", embedding=[0.1, 0.2, 0.3, 0.4]),
            Document(content="another document", embedding=[1.0, 1.0, 1.0, 1.0]),
            Document(content="third document", embedding=[0.5, 0.7, 0.5, 0.7]),
        ]

    def test_init_default(self):
        retriever = InMemoryEmbeddingRetriever(InMemoryDocumentStore())
        assert retriever.filters is None
        assert retriever.top_k == 10
        assert retriever.scale_score is False

    def test_init_with_parameters(self):
        retriever = InMemoryEmbeddingRetriever(
            InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
        )
        assert retriever.filters == {"name": "test.txt"}
        assert retriever.top_k == 5
        assert retriever.scale_score is True

    def test_init_with_invalid_top_k_parameter(self):
        with pytest.raises(ValueError):
            InMemoryEmbeddingRetriever(InMemoryDocumentStore(), top_k=-2)

    def test_to_dict(self):
        component = InMemoryEmbeddingRetriever(document_store=self._fake_document_store())

        data = component.to_dict()
        assert data == {
            "type": self._RETRIEVER_TYPE,
            "init_parameters": {
                "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
                "filters": None,
                "top_k": 10,
                "scale_score": False,
                "return_embedding": False,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        component = InMemoryEmbeddingRetriever(
            document_store=self._fake_document_store(),
            filters={"name": "test.txt"},
            top_k=5,
            scale_score=True,
            return_embedding=True,
        )
        data = component.to_dict()
        assert data == {
            "type": self._RETRIEVER_TYPE,
            "init_parameters": {
                "document_store": {"type": "test_module.MyFakeStore", "init_parameters": {}},
                "filters": {"name": "test.txt"},
                "top_k": 5,
                "scale_score": True,
                "return_embedding": True,
            },
        }

    def test_from_dict(self):
        data = {
            "type": self._RETRIEVER_TYPE,
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "filters": {"name": "test.txt"},
                "top_k": 5,
            },
        }
        component = InMemoryEmbeddingRetriever.from_dict(data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.filters == {"name": "test.txt"}
        assert component.top_k == 5
        # scale_score is absent from the serialized data, so the default must apply.
        assert component.scale_score is False

    def test_from_dict_without_docstore(self):
        data = {"type": self._RETRIEVER_TYPE, "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        data = {
            "type": self._RETRIEVER_TYPE,
            "init_parameters": {"document_store": {"init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_from_dict_nonexisting_docstore(self):
        data = {
            "type": self._RETRIEVER_TYPE,
            "init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(DeserializationError):
            InMemoryEmbeddingRetriever.from_dict(data)

    def test_valid_run(self):
        top_k = 3
        ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
        ds.write_documents(self._sample_documents())

        retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)
        result = retriever.run(query_embedding=[0.1, 0.1, 0.1, 0.1], return_embedding=True)

        assert "documents" in result
        assert len(result["documents"]) == top_k
        # [1.0, 1.0, 1.0, 1.0] is parallel to the query, so it has the highest cosine similarity.
        assert np.array_equal(result["documents"][0].embedding, [1.0, 1.0, 1.0, 1.0])

    def test_invalid_run_wrong_store_type(self):
        """The store type is validated at construction time, before any run."""
        SomeOtherDocumentStore = document_store_class("SomeOtherDocumentStore")
        with pytest.raises(ValueError, match="document_store must be an instance of InMemoryDocumentStore"):
            InMemoryEmbeddingRetriever(SomeOtherDocumentStore())

    @pytest.mark.integration
    def test_run_with_pipeline(self):
        ds = InMemoryDocumentStore(embedding_similarity_function="cosine")
        top_k = 2
        ds.write_documents(self._sample_documents())
        retriever = InMemoryEmbeddingRetriever(ds, top_k=top_k)

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)
        result: Dict[str, Any] = pipeline.run(
            data={"retriever": {"query_embedding": [0.1, 0.1, 0.1, 0.1], "return_embedding": True}}
        )

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert len(results_docs) == top_k
        # Same ranking as test_valid_run: the all-ones embedding is closest to the query.
        assert np.array_equal(results_docs[0].embedding, [1.0, 1.0, 1.0, 1.0])

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.