haystack
229 строк · 9.4 Кб
1from unittest.mock import patch, MagicMock
2import pytest
3import numpy as np
4from haystack.utils import Secret, ComponentDevice
5
6from haystack import Document
7from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
8
9
class TestSentenceTransformersDocumentEmbedder:
    """Unit tests for ``SentenceTransformersDocumentEmbedder``.

    Tests that exercise ``run`` or ``warm_up`` substitute the embedding
    backend (or patch its factory) with mocks instead of a real backend.
    """

    # Dotted paths shared by the serialization tests and the factory patches.
    _MODULE = "haystack.components.embedders.sentence_transformers_document_embedder"
    _COMPONENT_TYPE = f"{_MODULE}.SentenceTransformersDocumentEmbedder"
    _BACKEND_FACTORY_PATH = f"{_MODULE}._SentenceTransformersEmbeddingBackendFactory"

    def test_init_default(self):
        """Passing only ``model`` leaves every other attribute at its default."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model="model")

        assert doc_embedder.model == "model"
        assert doc_embedder.device == ComponentDevice.resolve_device(None)
        assert doc_embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert doc_embedder.prefix == ""
        assert doc_embedder.suffix == ""
        assert doc_embedder.batch_size == 32
        assert doc_embedder.progress_bar is True
        assert doc_embedder.normalize_embeddings is False
        assert doc_embedder.meta_fields_to_embed == []
        assert doc_embedder.embedding_separator == "\n"

    def test_init_with_parameters(self):
        """Every explicitly passed constructor argument is stored on the instance."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )

        assert doc_embedder.model == "model"
        assert doc_embedder.device == ComponentDevice.from_str("cuda:0")
        assert doc_embedder.token == Secret.from_token("fake-api-token")
        assert doc_embedder.prefix == "prefix"
        assert doc_embedder.suffix == "suffix"
        assert doc_embedder.batch_size == 64
        assert doc_embedder.progress_bar is False
        assert doc_embedder.normalize_embeddings is True
        assert doc_embedder.meta_fields_to_embed == ["test_field"]
        assert doc_embedder.embedding_separator == " | "

    def test_to_dict(self):
        """``to_dict`` on a default-configured embedder emits the default init parameters."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model", device=ComponentDevice.from_str("cpu")
        )
        serialized = doc_embedder.to_dict()

        expected = {
            "type": self._COMPONENT_TYPE,
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cpu").to_dict(),
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "normalize_embeddings": False,
                "embedding_separator": "\n",
                "meta_fields_to_embed": [],
            },
        }
        assert serialized == expected

    def test_to_dict_with_custom_init_parameters(self):
        """``to_dict`` reflects every non-default constructor argument."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" - ",
        )
        serialized = doc_embedder.to_dict()

        expected = {
            "type": self._COMPONENT_TYPE,
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
                "embedding_separator": " - ",
                "meta_fields_to_embed": ["meta_field"],
            },
        }
        assert serialized == expected

    def test_from_dict(self):
        """``from_dict`` rebuilds an embedder whose attributes match the serialized values."""
        payload = {
            "type": self._COMPONENT_TYPE,
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
                "embedding_separator": " - ",
                "meta_fields_to_embed": ["meta_field"],
            },
        }

        restored = SentenceTransformersDocumentEmbedder.from_dict(payload)

        assert restored.model == "model"
        assert restored.device == ComponentDevice.from_str("cuda:0")
        assert restored.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert restored.prefix == "prefix"
        assert restored.suffix == "suffix"
        assert restored.batch_size == 64
        assert restored.progress_bar is False
        assert restored.normalize_embeddings is True
        assert restored.embedding_separator == " - "
        assert restored.meta_fields_to_embed == ["meta_field"]

    @patch(_BACKEND_FACTORY_PATH)
    def test_warmup(self, mocked_factory):
        """The backend is requested from the factory on ``warm_up``, not at construction."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model", token=None, device=ComponentDevice.from_str("cpu")
        )
        get_backend = mocked_factory.get_embedding_backend

        # Constructing the component alone must not touch the factory.
        get_backend.assert_not_called()

        doc_embedder.warm_up()
        get_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

    @patch(_BACKEND_FACTORY_PATH)
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Calling ``warm_up`` twice requests the backend from the factory only once."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model="model")
        get_backend = mocked_factory.get_embedding_backend
        get_backend.assert_not_called()

        for _ in range(2):
            doc_embedder.warm_up()

        get_backend.assert_called_once()

    def test_run(self):
        """``run`` returns one ``Document`` per input, each with a list-of-float embedding."""

        def fake_embed(texts, **kwargs):
            # Stand-in backend call: one random 16-dimensional vector per input text.
            return np.random.rand(len(texts), 16).tolist()

        doc_embedder = SentenceTransformersDocumentEmbedder(model="model")
        doc_embedder.embedding_backend = MagicMock()
        doc_embedder.embedding_backend.embed = fake_embed

        docs = [Document(content=f"document number {idx}") for idx in range(5)]

        output = doc_embedder.run(documents=docs)
        embedded_docs = output["documents"]

        assert isinstance(embedded_docs, list)
        assert len(embedded_docs) == len(docs)
        for embedded in output["documents"]:
            assert isinstance(embedded, Document)
            assert isinstance(embedded.embedding, list)
            assert isinstance(embedded.embedding[0], float)

    def test_run_wrong_input_format(self):
        """``run`` raises ``TypeError`` for a plain string and for a list of integers."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model="model")
        expected_message = "SentenceTransformersDocumentEmbedder expects a list of Documents as input"

        # First a bare string, then a list whose items are not Documents.
        for invalid_input in ("text", [1, 2, 3]):
            with pytest.raises(TypeError, match=expected_message):
                doc_embedder.run(documents=invalid_input)

    def test_embed_metadata(self):
        """Configured meta fields are joined to the content with the separator before embedding."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        doc_embedder.embedding_backend = MagicMock()

        docs = [
            Document(content=f"document number {idx}", meta={"meta_field": f"meta_value {idx}"})
            for idx in range(5)
        ]

        doc_embedder.run(documents=docs)

        expected_texts = [
            "meta_value 0\ndocument number 0",
            "meta_value 1\ndocument number 1",
            "meta_value 2\ndocument number 2",
            "meta_value 3\ndocument number 3",
            "meta_value 4\ndocument number 4",
        ]
        doc_embedder.embedding_backend.embed.assert_called_once_with(
            expected_texts,
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )

    def test_prefix_suffix(self):
        """The prefix and suffix wrap the combined meta-plus-content text sent to the backend."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            prefix="my_prefix ",
            suffix=" my_suffix",
            meta_fields_to_embed=["meta_field"],
            embedding_separator="\n",
        )
        doc_embedder.embedding_backend = MagicMock()

        docs = [
            Document(content=f"document number {idx}", meta={"meta_field": f"meta_value {idx}"})
            for idx in range(5)
        ]

        doc_embedder.run(documents=docs)

        expected_texts = [
            "my_prefix meta_value 0\ndocument number 0 my_suffix",
            "my_prefix meta_value 1\ndocument number 1 my_suffix",
            "my_prefix meta_value 2\ndocument number 2 my_suffix",
            "my_prefix meta_value 3\ndocument number 3 my_suffix",
            "my_prefix meta_value 4\ndocument number 4 my_suffix",
        ]
        doc_embedder.embedding_backend.embed.assert_called_once_with(
            expected_texts,
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )
230