haystack

Форк
0
/
test_sentence_transformers_text_embedder.py 
153 строки · 6.4 Кб
1
from unittest.mock import patch, MagicMock
2
import pytest
3
from haystack.utils import Secret, ComponentDevice
4

5
import numpy as np
6

7
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
8

9

10
class TestSentenceTransformersTextEmbedder:
    """Unit tests for ``SentenceTransformersTextEmbedder``.

    The tests cover construction, ``to_dict`` / ``from_dict`` round-tripping, ``warm_up``
    and ``run``. The embedding backend is either patched at the factory level or replaced
    by a ``MagicMock`` in every test that would otherwise need it.
    """

    # Fully qualified name the component is expected to report under "type" when serialized.
    _COMPONENT_TYPE = (
        "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder"
    )
    # Patch target for the backend factory referenced by the embedder module.
    _FACTORY_TARGET = (
        "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
    )

    def test_init_default(self):
        """Passing only ``model`` leaves every other attribute at its default value."""
        instance = SentenceTransformersTextEmbedder(model="model")

        default_device = ComponentDevice.resolve_device(None)
        default_token = Secret.from_env_var("HF_API_TOKEN", strict=False)

        assert instance.model == "model"
        assert instance.device == default_device
        assert instance.token == default_token
        assert instance.prefix == ""
        assert instance.suffix == ""
        assert instance.batch_size == 32
        assert instance.progress_bar is True
        assert instance.normalize_embeddings is False

    def test_init_with_parameters(self):
        """Explicit constructor arguments are stored unchanged on the instance."""
        custom_settings = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0"),
            "token": Secret.from_token("fake-api-token"),
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
        }
        instance = SentenceTransformersTextEmbedder(**custom_settings)

        # Compare against freshly built objects rather than the ones passed in.
        assert instance.model == "model"
        assert instance.device == ComponentDevice.from_str("cuda:0")
        assert instance.token == Secret.from_token("fake-api-token")
        assert instance.prefix == "prefix"
        assert instance.suffix == "suffix"
        assert instance.batch_size == 64
        assert instance.progress_bar is False
        assert instance.normalize_embeddings is True

    def test_to_dict(self):
        """Serializing a mostly-default component yields the default init parameters."""
        cpu_embedder = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
        serialized = cpu_embedder.to_dict()

        expected_init_parameters = {
            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
            "model": "model",
            "device": ComponentDevice.from_str("cpu").to_dict(),
            "prefix": "",
            "suffix": "",
            "batch_size": 32,
            "progress_bar": True,
            "normalize_embeddings": False,
        }
        assert serialized == {"type": self._COMPONENT_TYPE, "init_parameters": expected_init_parameters}

    def test_to_dict_with_custom_init_parameters(self):
        """Serializing a fully customized component reflects every custom value."""
        custom_settings = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0"),
            "token": Secret.from_env_var("ENV_VAR", strict=False),
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
        }
        serialized = SentenceTransformersTextEmbedder(**custom_settings).to_dict()

        expected_init_parameters = {
            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0").to_dict(),
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
        }
        assert serialized == {"type": self._COMPONENT_TYPE, "init_parameters": expected_init_parameters}

    def test_to_dict_not_serialize_token(self):
        """``to_dict`` must raise ``ValueError`` when the component holds a token-based secret."""
        token_embedder = SentenceTransformersTextEmbedder(
            model="model",
            token=Secret.from_token("fake-api-token"),
        )
        with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
            token_embedder.to_dict()

    def test_from_dict(self):
        """``from_dict`` rebuilds a component whose attributes match the serialized values."""
        init_parameters = {
            "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
            "model": "model",
            "device": ComponentDevice.from_str("cpu").to_dict(),
            "prefix": "",
            "suffix": "",
            "batch_size": 32,
            "progress_bar": True,
            "normalize_embeddings": False,
        }
        payload = {"type": self._COMPONENT_TYPE, "init_parameters": init_parameters}

        restored = SentenceTransformersTextEmbedder.from_dict(payload)

        assert restored.model == "model"
        assert restored.device == ComponentDevice.from_str("cpu")
        assert restored.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert restored.prefix == ""
        assert restored.suffix == ""
        assert restored.batch_size == 32
        assert restored.progress_bar is True
        assert restored.normalize_embeddings is False

    @patch(_FACTORY_TARGET)
    def test_warmup(self, mocked_factory):
        """The backend is requested from the factory by ``warm_up``, not by the constructor."""
        embedder = SentenceTransformersTextEmbedder(
            model="model",
            token=None,
            device=ComponentDevice.from_str("cpu"),
        )
        get_backend = mocked_factory.get_embedding_backend

        # Constructing the component alone must not touch the factory.
        get_backend.assert_not_called()

        embedder.warm_up()

        get_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

    @patch(_FACTORY_TARGET)
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Calling ``warm_up`` twice requests the backend from the factory only once."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        get_backend = mocked_factory.get_embedding_backend
        get_backend.assert_not_called()

        for _ in range(2):
            embedder.warm_up()

        get_backend.assert_called_once()

    def test_run(self):
        """``run`` returns a dict whose ``"embedding"`` entry is a list of floats."""

        def fake_embed(x, **kwargs):
            # One random 16-value row per input element, as plain Python lists.
            return np.random.rand(len(x), 16).tolist()

        embedder = SentenceTransformersTextEmbedder(model="model")
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed = fake_embed

        embedding = embedder.run(text="a nice text to embed")["embedding"]

        assert isinstance(embedding, list)
        assert all(isinstance(el, float) for el in embedding)

    def test_run_wrong_input_format(self):
        """``run`` raises ``TypeError`` when given a list of integers instead of a string."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        embedder.embedding_backend = MagicMock()

        not_a_string = [1, 2, 3]

        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
            embedder.run(text=not_a_string)
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.