# haystack
# 153 lines · 6.4 KB
from unittest.mock import patch, MagicMock

import numpy as np
import pytest

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import Secret, ComponentDevice


class TestSentenceTransformersTextEmbedder:
    """Unit tests for ``SentenceTransformersTextEmbedder``.

    No real model is loaded. Tests either stop before ``warm_up`` is called,
    patch the backend factory, or assign a mock to ``embedding_backend``.
    """

    def test_init_default(self):
        """Only ``model`` is required; all other settings take their defaults."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        assert embedder.model == "model"
        assert embedder.device == ComponentDevice.resolve_device(None)
        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.normalize_embeddings is False

    def test_init_with_parameters(self):
        """Every explicitly passed init argument is stored on the instance."""
        embedder = SentenceTransformersTextEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        assert embedder.model == "model"
        assert embedder.device == ComponentDevice.from_str("cuda:0")
        assert embedder.token == Secret.from_token("fake-api-token")
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.normalize_embeddings is True

    def test_to_dict(self):
        """Default init parameters serialize to the expected dict.

        The default token is an env-var secret, so it is serialized.
        """
        component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "model": "model",
                "device": ComponentDevice.from_str("cpu").to_dict(),
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "normalize_embeddings": False,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Custom init parameters, including an env-var token, are serialized verbatim."""
        component = SentenceTransformersTextEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "model": "model",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
            },
        }

    def test_to_dict_not_serialize_token(self):
        """A token-based secret must never be written out; ``to_dict`` raises instead."""
        component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
        with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
            component.to_dict()

    def test_from_dict(self):
        """Deserializing a dict shaped like ``to_dict`` output restores every setting."""
        data = {
            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "model": "model",
                "device": ComponentDevice.from_str("cpu").to_dict(),
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "normalize_embeddings": False,
            },
        }
        component = SentenceTransformersTextEmbedder.from_dict(data)
        assert component.model == "model"
        assert component.device == ComponentDevice.from_str("cpu")
        assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert component.prefix == ""
        assert component.suffix == ""
        assert component.batch_size == 32
        assert component.progress_bar is True
        assert component.normalize_embeddings is False

    @patch(
        "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """The backend is created lazily, in ``warm_up`` and not in ``__init__``.

        The factory is expected to receive the model name, the device as a string,
        and the token.
        """
        embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

    @patch(
        "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Calling ``warm_up`` more than once requests the backend only once."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once()

    def test_run(self):
        """``run`` embeds one string and returns a single flat vector of floats."""
        embedding_dim = 16
        received_batches = []

        def fake_embed(texts, **kwargs):
            # Record exactly what the component hands to the backend,
            # then return one random vector per input text.
            received_batches.append(texts)
            return np.random.rand(len(texts), embedding_dim).tolist()

        embedder = SentenceTransformersTextEmbedder(model="model")
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed = fake_embed

        text = "a nice text to embed"

        result = embedder.run(text=text)
        embedding = result["embedding"]

        assert isinstance(embedding, list)
        # all() is vacuously True for an empty list, so an empty embedding would
        # pass the element-type check below. Pin the vector length too.
        assert len(embedding) == embedding_dim
        assert all(isinstance(el, float) for el in embedding)
        # prefix and suffix default to "", so the text should reach the backend
        # unchanged, as a single-element batch.
        assert received_batches == [[text]]

    def test_run_wrong_input_format(self):
        """Non-string input, such as a list of ints, is rejected with a ``TypeError``."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        embedder.embedding_backend = MagicMock()

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
            embedder.run(text=list_integers_input)