test_hugging_face_tei_document_embedder.py

from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.utils.auth import Secret

from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
from haystack.dataclasses import Document


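# Patch check_valid_model so the tests never call the Hugging Face Hub to validate the model id.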
@pytest.fixture
def mock_check_valid_model():
    with patch(
        "haystack.components.embedders.hugging_face_tei_document_embedder.check_valid_model",
        MagicMock(return_value=None),
    ) as mock:
        yield mock


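# Stand-in for InferenceClient.feature_extraction: returns one random 384-dimensional vector per input text.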
def mock_embedding_generation(text, **kwargs):
    response = np.array([np.random.rand(384) for _ in range(len(text))])
    return response


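# Unit tests for HuggingFaceTEIDocumentEmbedder; model validation and the TEI endpoint are mocked throughout.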
class TestHuggingFaceTEIDocumentEmbedder:
    def test_init_default(self, monkeypatch, mock_check_valid_model):
        monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
        embedder = HuggingFaceTEIDocumentEmbedder()

        assert embedder.model == "BAAI/bge-small-en-v1.5"
        assert embedder.url is None
        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    def test_init_with_parameters(self, mock_check_valid_model):
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )

        assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
        assert embedder.url == "https://some_embedding_model.com"
        assert embedder.token == Secret.from_token("fake-api-token")
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.meta_fields_to_embed == ["test_field"]
        assert embedder.embedding_separator == " | "

    def test_initialize_with_invalid_url(self, mock_check_valid_model):
        with pytest.raises(ValueError):
            HuggingFaceTEIDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", url="invalid_url")

    def test_initialize_with_url_but_invalid_model(self, mock_check_valid_model):
        # When a custom TEI endpoint is used via URL, the model must still be a valid Hugging Face Hub model id
        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
        with pytest.raises(RepositoryNotFoundError):
            HuggingFaceTEIDocumentEmbedder(model="invalid_model_id", url="https://some_embedding_model.com")

    def test_to_dict(self, mock_check_valid_model):
        component = HuggingFaceTEIDocumentEmbedder()
        data = component.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
            "init_parameters": {
                "model": "BAAI/bge-small-en-v1.5",
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "url": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
            },
        }

    def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
        component = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )

        data = component.to_dict()

        assert data == {
            "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "model": "sentence-transformers/all-mpnet-base-v2",
                "url": "https://some_embedding_model.com",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "meta_fields_to_embed": ["test_field"],
                "embedding_separator": " | ",
            },
        }

    def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):
        documents = [
            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" | ",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "meta_value 0 | document number 0: content",
            "meta_value 1 | document number 1: content",
            "meta_value 2 | document number 2: content",
            "meta_value 3 | document number 3: content",
            "meta_value 4 | document number 4: content",
        ]

    def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
        documents = [Document(content=f"document number {i}") for i in range(5)]

        embedder = HuggingFaceTEIDocumentEmbedder(
            model="sentence-transformers/all-mpnet-base-v2",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
            prefix="my_prefix ",
            suffix=" my_suffix",
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "my_prefix document number 0 my_suffix",
            "my_prefix document number 1 my_suffix",
            "my_prefix document number 2 my_suffix",
            "my_prefix document number 3 my_suffix",
            "my_prefix document number 4 my_suffix",
        ]

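    # Five texts with batch_size=2 should be embedded in three feature_extraction calls (2 + 2 + 1).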
    def test_embed_batch(self, mock_check_valid_model):
        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                url="https://some_embedding_model.com",
                token=Secret.from_token("fake-api-token"),
            )
            embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

            assert mock_embedding_patch.call_count == 3

        assert isinstance(embeddings, list)
        assert len(embeddings) == len(texts)
        for embedding in embeddings:
            assert isinstance(embedding, list)
            assert len(embedding) == 384
            assert all(isinstance(x, float) for x in embedding)

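    # With the default batch_size (32), both documents fit into a single feature_extraction call;
    # prefix, suffix and the embedded "topic" meta field must appear in the text sent to the endpoint.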
    def test_run(self, mock_check_valid_model):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
            )

            result = embedder.run(documents=docs)

            mock_embedding_patch.assert_called_once_with(
                text=[
                    "prefix Cuisine | I love cheese suffix",
                    "prefix ML | A transformer is a deep learning architecture suffix",
                ]
            )
        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

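    # batch_size=1 forces a separate feature_extraction call per document.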
    def test_run_custom_batch_size(self, mock_check_valid_model):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]

        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation

            embedder = HuggingFaceTEIDocumentEmbedder(
                model="BAAI/bge-small-en-v1.5",
                token=Secret.from_token("fake-api-token"),
                prefix="prefix ",
                suffix=" suffix",
                meta_fields_to_embed=["topic"],
                embedding_separator=" | ",
                batch_size=1,
            )

            result = embedder.run(documents=docs)

            assert mock_embedding_patch.call_count == 2

        documents_with_embeddings = result["documents"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 384
            assert all(isinstance(x, float) for x in doc.embedding)

    def test_run_wrong_input_format(self, mock_check_valid_model):
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="BAAI/bge-small-en-v1.5",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
        )

        # wrong formats
        string_input = "text"
        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=string_input)

        with pytest.raises(TypeError, match="HuggingFaceTEIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=list_integers_input)

    def test_run_on_empty_list(self, mock_check_valid_model):
        embedder = HuggingFaceTEIDocumentEmbedder(
            model="BAAI/bge-small-en-v1.5",
            url="https://some_embedding_model.com",
            token=Secret.from_token("fake-api-token"),
        )

        empty_list_input = []
        result = embedder.run(documents=empty_list_input)

        assert result["documents"] is not None
        assert not result["documents"]  # empty list