# haystack — tests for HuggingFaceTEIDocumentEmbedder (original file: 286 lines, 11.5 KB)
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError

from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
from haystack.dataclasses import Document
from haystack.utils.auth import Secret
11
@pytest.fixture
def mock_check_valid_model():
    """Patch out the HF Hub model-validation helper so tests make no network calls.

    Yields the MagicMock standing in for ``check_valid_model`` so individual
    tests can set side effects (e.g. raise RepositoryNotFoundError).
    """
    target = "haystack.components.embedders.hugging_face_tei_document_embedder.check_valid_model"
    with patch(target, MagicMock(return_value=None)) as mocked:
        yield mocked
19
20
def mock_embedding_generation(text, **kwargs):
    """Fake TEI feature-extraction: one random 384-dim vector per input string.

    Mirrors the shape of ``InferenceClient.feature_extraction`` output:
    a (len(text), 384) float array.
    """
    return np.array([np.random.rand(384) for _ in text])
24
25
class TestHuggingFaceTEIDocumentEmbedder:
    """Unit tests for HuggingFaceTEIDocumentEmbedder.

    All HF Hub interaction is mocked: ``mock_check_valid_model`` suppresses
    model validation, and ``InferenceClient.feature_extraction`` is patched
    wherever embeddings would be fetched.
    """

    def test_init_default(self, monkeypatch, mock_check_valid_model):
        """Defaults: token read from env var, BGE small model, batch of 32."""
        monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")

        embedder = HuggingFaceTEIDocumentEmbedder()

        assert embedder.model == "BAAI/bge-small-en-v1.5"
        assert embedder.url is None
        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    def test_init_with_parameters(self, mock_check_valid_model):
        """Every constructor argument is stored unmodified on the instance."""
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )

        assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
        assert embedder.url == "https://some_embedding_model.com"
        assert embedder.token == Secret.from_token("fake-api-token")
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.meta_fields_to_embed == ["test_field"]
        assert embedder.embedding_separator == " | "

    def test_initialize_with_invalid_url(self, mock_check_valid_model):
        """A malformed URL is rejected at construction time."""
        with pytest.raises(ValueError):
            HuggingFaceTEIDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")

    def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        """With a custom TEI endpoint URL, the model id must still be a valid HF Hub repo."""
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceTEIDocumentEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")

    def test_to_dict(self, mock_check_valid_model):
        """Serialization of a default-constructed embedder."""
        component = HuggingFaceTEIDocumentEmbedder()

        data = component.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
            "init_parameters": {
                "model": "BAAI/bge-small-en-v1.5",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "url": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
            },
        }

    def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
        """Serialization round-trips every non-default init parameter."""
        component = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )

        data = component.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "model": "sentence-transformers/all-mpnet-base-v2",
                "url": "https://some_embedding_model.com",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "meta_fields_to_embed": ["test_field"],
                "embedding_separator": " | ",
            },
        }

    def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
        """Selected meta fields are joined before the content with the separator."""
        documents = [
            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" | ",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "meta_value 0 | document number 0: content",
            "meta_value 1 | document number 1: content",
            "meta_value 2 | document number 2: content",
            "meta_value 3 | document number 3: content",
            "meta_value 4 | document number 4: content",
        ]

    def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
        """Prefix and suffix wrap each document's content."""
        documents = [Document(content=f"document number {i}") for i in range(5)]

        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            prefix="my_prefix ",
            suffix=" my_suffix",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "my_prefix document number 0 my_suffix",
            "my_prefix document number 1 my_suffix",
            "my_prefix document number 2 my_suffix",
            "my_prefix document number 3 my_suffix",
            "my_prefix document number 4 my_suffix",
        ]

    def test_embed_batch(self, mock_check_valid_model):
        """Five texts with batch_size=2 yield three API calls and 384-dim float lists."""
        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                url="https://some_embedding_model.com",
                token=Secret.from_token("fake-api-token"),
            )

            embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

            # ceil(5 / 2) batches
            assert mock_embedding_patch.call_count == 3

        assert isinstance(embeddings, list)
        assert len(embeddings) == len(texts)
        for embedding in embeddings:
            assert isinstance(embedding, list)
            assert len(embedding) == 384
            assert all(isinstance(x, float) for x in embedding)

    def test_run(self, mock_check_valid_model):
        """run() embeds prefix + meta | content + suffix and attaches embeddings to docs."""
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
            )

            result = embedder.run(documents=docs)

            mock_embedding_patch.assert_called_once_with(
                text=[
                    "prefix Cuisine | I love cheese suffix",
                    "prefix ML | A transformer is a deep learning architecture suffix",
                ]
            )

        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

    def test_run_custom_batch_size(self, mock_check_valid_model):
        """batch_size=1 triggers one API call per document."""
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
                batch_size=1,
            )

            result = embedder.run(documents=docs)

            assert mock_embedding_patch.call_count == 2

        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

    def test_run_wrong_input_format(self, mock_check_valid_model):
        """Non-list and non-Document inputs raise TypeError."""
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="BAAI/bge-small-en-v1.5",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
        )

        string_input = "text"
        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=string_input)

        with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=list_integers_input)

    def test_run_on_empty_list(self, mock_check_valid_model):
        """An empty document list produces an empty (but present) result list."""
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="BAAI/bge-small-en-v1.5",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
        )

        result = embedder.run(documents=[])

        assert result["documents"] is not None
        assert not result["documents"]  # empty list
287