haystack

Форк
0
/
test_sentence_transformers_document_embedder.py 
229 строк · 9.4 Кб
1
from unittest.mock import patch, MagicMock
2
import pytest
3
import numpy as np
4
from haystack.utils import Secret, ComponentDevice
5

6
from haystack import Document
7
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
8

9

10
class TestSentenceTransformersDocumentEmbedder:
    """Unit tests for ``SentenceTransformersDocumentEmbedder``.

    The tests cover construction, ``to_dict``/``from_dict`` round-tripping,
    ``warm_up`` against a patched backend factory, and ``run`` with the
    embedding backend replaced by a mock.
    """

    def test_init_default(self):
        """Passing only ``model`` leaves every other attribute at its default."""
        component = SentenceTransformersDocumentEmbedder(model="model")

        assert component.model == "model"
        assert component.device == ComponentDevice.resolve_device(None)
        assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert component.prefix == ""
        assert component.suffix == ""
        assert component.batch_size == 32
        assert component.progress_bar is True
        assert component.normalize_embeddings is False
        assert component.meta_fields_to_embed == []
        assert component.embedding_separator == "\n"

    def test_init_with_parameters(self):
        """Every explicitly supplied init argument is stored on the instance."""
        init_kwargs = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0"),
            "token": Secret.from_token("fake-api-token"),
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "meta_fields_to_embed": ["test_field"],
            "embedding_separator": " | ",
        }
        component = SentenceTransformersDocumentEmbedder(**init_kwargs)

        # Compare against freshly built objects rather than the ones passed in.
        assert component.model == "model"
        assert component.device == ComponentDevice.from_str("cuda:0")
        assert component.token == Secret.from_token("fake-api-token")
        assert component.prefix == "prefix"
        assert component.suffix == "suffix"
        assert component.batch_size == 64
        assert component.progress_bar is False
        assert component.normalize_embeddings is True
        assert component.meta_fields_to_embed == ["test_field"]
        assert component.embedding_separator == " | "

    def test_to_dict(self):
        """Serializing a mostly-default component yields the expected dictionary."""
        cpu_device = ComponentDevice.from_str("cpu")
        serialized = SentenceTransformersDocumentEmbedder(model="model", device=cpu_device).to_dict()

        expected_init_parameters = {
            "model": "model",
            "device": ComponentDevice.from_str("cpu").to_dict(),
            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
            "prefix": "",
            "suffix": "",
            "batch_size": 32,
            "progress_bar": True,
            "normalize_embeddings": False,
            "embedding_separator": "\n",
            "meta_fields_to_embed": [],
        }
        expected = {
            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
            "init_parameters": expected_init_parameters,
        }
        assert serialized == expected

    def test_to_dict_with_custom_init_parameters(self):
        """Serializing a fully customized component reflects each custom value."""
        init_kwargs = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0"),
            "token": Secret.from_env_var("ENV_VAR", strict=False),
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "meta_fields_to_embed": ["meta_field"],
            "embedding_separator": " - ",
        }
        serialized = SentenceTransformersDocumentEmbedder(**init_kwargs).to_dict()

        expected_init_parameters = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0").to_dict(),
            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "embedding_separator": " - ",
            "meta_fields_to_embed": ["meta_field"],
        }
        expected = {
            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
            "init_parameters": expected_init_parameters,
        }
        assert serialized == expected

    def test_from_dict(self):
        """Deserializing a dictionary restores every init parameter on the instance."""
        payload = {
            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
                "embedding_separator": " - ",
                "meta_fields_to_embed": ["meta_field"],
            },
        }
        restored = SentenceTransformersDocumentEmbedder.from_dict(payload)

        assert restored.model == "model"
        assert restored.device == ComponentDevice.from_str("cuda:0")
        assert restored.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert restored.prefix == "prefix"
        assert restored.suffix == "suffix"
        assert restored.batch_size == 64
        assert restored.progress_bar is False
        assert restored.normalize_embeddings is True
        assert restored.embedding_separator == " - "
        assert restored.meta_fields_to_embed == ["meta_field"]

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """The backend factory is untouched at init and called once by ``warm_up``."""
        cpu_device = ComponentDevice.from_str("cpu")
        component = SentenceTransformersDocumentEmbedder(model="model", token=None, device=cpu_device)
        get_backend = mocked_factory.get_embedding_backend

        # Construction alone must not ask the factory for a backend.
        get_backend.assert_not_called()

        component.warm_up()
        get_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Calling ``warm_up`` twice still results in a single factory call."""
        component = SentenceTransformersDocumentEmbedder(model="model")
        get_backend = mocked_factory.get_embedding_backend
        get_backend.assert_not_called()

        for _ in range(2):
            component.warm_up()

        get_backend.assert_called_once()

    def test_run(self):
        """``run`` returns one ``Document`` per input, each with a list-of-float embedding."""

        def fake_embed(texts, **kwargs):
            # One random 16-dimensional vector per input text, as plain Python lists.
            return np.random.rand(len(texts), 16).tolist()

        component = SentenceTransformersDocumentEmbedder(model="model")
        component.embedding_backend = MagicMock()
        component.embedding_backend.embed = fake_embed

        inputs = [Document(content=f"document number {idx}") for idx in range(5)]
        output = component.run(documents=inputs)

        assert isinstance(output["documents"], list)
        assert len(output["documents"]) == len(inputs)
        for embedded in output["documents"]:
            assert isinstance(embedded, Document)
            assert isinstance(embedded.embedding, list)
            assert isinstance(embedded.embedding[0], float)

    def test_run_wrong_input_format(self):
        """A plain string or a list of integers passed as ``documents`` raises ``TypeError``."""
        component = SentenceTransformersDocumentEmbedder(model="model")

        string_input = "text"
        list_integers_input = [1, 2, 3]

        for invalid_input in (string_input, list_integers_input):
            with pytest.raises(
                TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
            ):
                component.run(documents=invalid_input)

    def test_embed_metadata(self):
        """The selected meta field is placed before the content, joined by the separator."""
        component = SentenceTransformersDocumentEmbedder(
            model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        component.embedding_backend = MagicMock()

        inputs = [
            Document(content=f"document number {idx}", meta={"meta_field": f"meta_value {idx}"}) for idx in range(5)
        ]
        component.run(documents=inputs)

        expected_texts = [f"meta_value {idx}\ndocument number {idx}" for idx in range(5)]
        component.embedding_backend.embed.assert_called_once_with(
            expected_texts, batch_size=32, show_progress_bar=True, normalize_embeddings=False
        )

    def test_prefix_suffix(self):
        """Prefix and suffix wrap the combined meta-field-plus-content text sent to the backend."""
        component = SentenceTransformersDocumentEmbedder(
            model="model",
            prefix="my_prefix ",
            suffix=" my_suffix",
            meta_fields_to_embed=["meta_field"],
            embedding_separator="\n",
        )
        component.embedding_backend = MagicMock()

        inputs = [
            Document(content=f"document number {idx}", meta={"meta_field": f"meta_value {idx}"}) for idx in range(5)
        ]
        component.run(documents=inputs)

        expected_texts = [f"my_prefix meta_value {idx}\ndocument number {idx} my_suffix" for idx in range(5)]
        component.embedding_backend.embed.assert_called_once_with(
            expected_texts, batch_size=32, show_progress_bar=True, normalize_embeddings=False
        )
230

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.