# Tests for haystack's TransformersSimilarityRanker (serialization, mocking, and integration).
# Standard library / third-party / haystack imports (original statement order preserved).
from unittest.mock import MagicMock, patch
from haystack.utils.auth import Secret

import pytest
import logging
import torch
from transformers.modeling_outputs import SequenceClassifierOutput

from haystack import ComponentError, Document
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
from haystack.utils.device import ComponentDevice, DeviceMap
class TestSimilarityRanker:
    def test_to_dict(self):
        """A default-constructed ranker serializes all init parameters at their defaults."""
        component = TransformersSimilarityRanker()
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",
            "init_parameters": {
                # `device` is serialized as None; the resolved device lands in model_kwargs below.
                "device": None,
                "top_k": 10,
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "query_prefix": "",
                "document_prefix": "",
                "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
                "scale_score": True,
                "calibration_factor": 1.0,
                "score_threshold": None,
                "model_kwargs": {"device_map": ComponentDevice.resolve_device(None).to_hf()},
            },
        }
36def test_to_dict_with_custom_init_parameters(self):37component = TransformersSimilarityRanker(38model="my_model",39device=ComponentDevice.from_str("cuda:0"),40token=Secret.from_env_var("ENV_VAR", strict=False),41top_k=5,42query_prefix="query_instruction: ",43document_prefix="document_instruction: ",44scale_score=False,45calibration_factor=None,46score_threshold=0.01,47model_kwargs={"torch_dtype": torch.float16},48)49data = component.to_dict()50assert data == {51"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",52"init_parameters": {53"device": None,54"model": "my_model",55"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},56"top_k": 5,57"query_prefix": "query_instruction: ",58"document_prefix": "document_instruction: ",59"meta_fields_to_embed": [],60"embedding_separator": "\n",61"scale_score": False,62"calibration_factor": None,63"score_threshold": 0.01,64"model_kwargs": {65"torch_dtype": "torch.float16",66"device_map": ComponentDevice.from_str("cuda:0").to_hf(),67}, # torch_dtype is correctly serialized68},69}70
71def test_to_dict_with_quantization_options(self):72component = TransformersSimilarityRanker(73model_kwargs={74"load_in_4bit": True,75"bnb_4bit_use_double_quant": True,76"bnb_4bit_quant_type": "nf4",77"bnb_4bit_compute_dtype": torch.bfloat16,78}79)80data = component.to_dict()81assert data == {82"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",83"init_parameters": {84"device": None,85"top_k": 10,86"query_prefix": "",87"document_prefix": "",88"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},89"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",90"meta_fields_to_embed": [],91"embedding_separator": "\n",92"scale_score": True,93"calibration_factor": 1.0,94"score_threshold": None,95"model_kwargs": {96"load_in_4bit": True,97"bnb_4bit_use_double_quant": True,98"bnb_4bit_quant_type": "nf4",99"bnb_4bit_compute_dtype": "torch.bfloat16",100"device_map": ComponentDevice.resolve_device(None).to_hf(),101},102},103}104
105@pytest.mark.parametrize(106"device_map,expected",107[108("auto", "auto"),109("cpu:0", ComponentDevice.from_str("cpu:0").to_hf()),110({"": "cpu:0"}, ComponentDevice.from_multiple(DeviceMap.from_hf({"": "cpu:0"})).to_hf()),111],112)113def test_to_dict_device_map(self, device_map, expected):114component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map}, token=None)115data = component.to_dict()116
117assert data == {118"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",119"init_parameters": {120"device": None,121"top_k": 10,122"token": None,123"query_prefix": "",124"document_prefix": "",125"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",126"meta_fields_to_embed": [],127"embedding_separator": "\n",128"scale_score": True,129"calibration_factor": 1.0,130"score_threshold": None,131"model_kwargs": {"device_map": expected},132},133}134
135def test_from_dict(self):136data = {137"type": "haystack.components.rankers.transformers_similarity.TransformersSimilarityRanker",138"init_parameters": {139"device": None,140"model": "my_model",141"token": None,142"top_k": 5,143"query_prefix": "",144"document_prefix": "",145"meta_fields_to_embed": [],146"embedding_separator": "\n",147"scale_score": False,148"calibration_factor": None,149"score_threshold": 0.01,150"model_kwargs": {"torch_dtype": "torch.float16"},151},152}153
154component = TransformersSimilarityRanker.from_dict(data)155assert component.device is None156assert component.model_name_or_path == "my_model"157assert component.token is None158assert component.top_k == 5159assert component.query_prefix == ""160assert component.document_prefix == ""161assert component.meta_fields_to_embed == []162assert component.embedding_separator == "\n"163assert not component.scale_score164assert component.calibration_factor is None165assert component.score_threshold == 0.01166# torch_dtype is correctly deserialized167assert component.model_kwargs == {168"torch_dtype": torch.float16,169"device_map": ComponentDevice.resolve_device(None).to_hf(),170}171
    @patch("torch.sigmoid")
    @patch("torch.sort")
    def test_embed_meta(self, mocked_sort, mocked_sigmoid):
        """Meta fields listed in meta_fields_to_embed are prepended to the document content."""
        # Stub the torch ops used during ranking so run() completes with a mocked model.
        mocked_sort.return_value = (None, torch.tensor([0]))
        mocked_sigmoid.return_value = torch.tensor([0])
        embedder = TransformersSimilarityRanker(
            model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        # Pre-set model/tokenizer/device so warm_up() does not load a real model.
        # NOTE(review): assumes warm_up() keeps these mocks when already set — confirm.
        embedder.model = MagicMock()
        embedder.tokenizer = MagicMock()
        embedder.device = MagicMock()
        embedder.warm_up()

        documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]

        embedder.run(query="test", documents=documents)

        # The tokenizer must see "<meta value><separator><content>" for each document.
        embedder.tokenizer.assert_called_once_with(
            [
                ["test", "meta_value 0\ndocument number 0"],
                ["test", "meta_value 1\ndocument number 1"],
                ["test", "meta_value 2\ndocument number 2"],
                ["test", "meta_value 3\ndocument number 3"],
                ["test", "meta_value 4\ndocument number 4"],
            ],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
202@patch("torch.sigmoid")203@patch("torch.sort")204def test_prefix(self, mocked_sort, mocked_sigmoid):205mocked_sort.return_value = (None, torch.tensor([0]))206mocked_sigmoid.return_value = torch.tensor([0])207embedder = TransformersSimilarityRanker(208model="model", query_prefix="query_instruction: ", document_prefix="document_instruction: "209)210embedder.model = MagicMock()211embedder.tokenizer = MagicMock()212embedder.device = MagicMock()213embedder.warm_up()214
215documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]216
217embedder.run(query="test", documents=documents)218
219embedder.tokenizer.assert_called_once_with(220[221["query_instruction: test", "document_instruction: document number 0"],222["query_instruction: test", "document_instruction: document number 1"],223["query_instruction: test", "document_instruction: document number 2"],224["query_instruction: test", "document_instruction: document number 3"],225["query_instruction: test", "document_instruction: document number 4"],226],227padding=True,228truncation=True,229return_tensors="pt",230)231
232@patch("torch.sort")233def test_scale_score_false(self, mocked_sort):234mocked_sort.return_value = (None, torch.tensor([0, 1]))235embedder = TransformersSimilarityRanker(model="model", scale_score=False)236embedder.model = MagicMock()237embedder.model.return_value = SequenceClassifierOutput(238loss=None, logits=torch.FloatTensor([[-10.6859], [-8.9874]]), hidden_states=None, attentions=None239)240embedder.tokenizer = MagicMock()241embedder.device = MagicMock()242
243documents = [Document(content="document number 0"), Document(content="document number 1")]244out = embedder.run(query="test", documents=documents)245assert out["documents"][0].score == pytest.approx(-10.6859, abs=1e-4)246assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4)247
248@patch("torch.sort")249def test_score_threshold(self, mocked_sort):250mocked_sort.return_value = (None, torch.tensor([0, 1]))251embedder = TransformersSimilarityRanker(model="model", scale_score=False, score_threshold=0.1)252embedder.model = MagicMock()253embedder.model.return_value = SequenceClassifierOutput(254loss=None, logits=torch.FloatTensor([[0.955], [0.001]]), hidden_states=None, attentions=None255)256embedder.tokenizer = MagicMock()257embedder.device = MagicMock()258
259documents = [Document(content="document number 0"), Document(content="document number 1")]260out = embedder.run(query="test", documents=documents)261assert len(out["documents"]) == 1262
    def test_device_map_and_device_raises(self, caplog):
        """Supplying both `device` and a model_kwargs `device_map` logs a warning."""
        # NOTE(review): despite "raises" in the test name, this asserts a WARNING log,
        # not an exception — `device` is ignored in favor of `device_map`.
        with caplog.at_level(logging.WARNING):
            _ = TransformersSimilarityRanker(
                "model", model_kwargs={"device_map": "cpu"}, device=ComponentDevice.from_str("cuda")
            )
        assert (
            "The parameters `device` and `device_map` from `model_kwargs` are both provided. Ignoring `device` and using `device_map`."
            in caplog.text
        )
    @patch("haystack.components.rankers.transformers_similarity.AutoTokenizer.from_pretrained")
    @patch("haystack.components.rankers.transformers_similarity.AutoModelForSequenceClassification.from_pretrained")
    def test_device_map_dict(self, mocked_automodel, mocked_autotokenizer):
        """A dict device_map is forwarded to from_pretrained and reflected in ranker.device."""
        ranker = TransformersSimilarityRanker("model", model_kwargs={"device_map": {"layer_1": 1, "classifier": "cpu"}})

        # Minimal stand-in exposing only hf_device_map, which warm_up reads back.
        class MockedModel:
            def __init__(self):
                self.hf_device_map = {"layer_1": 1, "classifier": "cpu"}

        mocked_automodel.return_value = MockedModel()
        ranker.warm_up()

        mocked_automodel.assert_called_once_with("model", token=None, device_map={"layer_1": 1, "classifier": "cpu"})
        assert ranker.device == ComponentDevice.from_multiple(DeviceMap.from_hf({"layer_1": 1, "classifier": "cpu"}))
    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query,docs_before_texts,expected_first_text,scores",
        [
            (
                "City in Bosnia and Herzegovina",
                ["Berlin", "Belgrade", "Sarajevo"],
                "Sarajevo",
                [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331],
            ),
            (
                "Machine learning",
                ["Python", "Bakery in Paris", "Tesla Giga Berlin"],
                "Python",
                [1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05],
            ),
            (
                "Cubist movement",
                ["Nirvana", "Pablo Picasso", "Coffee"],
                "Pablo Picasso",
                [1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05],
            ),
        ],
    )
    def test_run(self, query, docs_before_texts, expected_first_text, scores):
        """
        Test if the component ranks documents correctly.
        """
        # Integration test: downloads and runs the real cross-encoder model.
        ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
        ranker.warm_up()
        docs_before = [Document(content=text) for text in docs_before_texts]
        output = ranker.run(query=query, documents=docs_before)
        docs_after = output["documents"]

        assert len(docs_after) == 3
        assert docs_after[0].content == expected_first_text

        # Scores come back sorted descending; compare against the sorted reference scores.
        sorted_scores = sorted(scores, reverse=True)
        assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
        assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
        assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)
330# Returns an empty list if no documents are provided331@pytest.mark.integration332def test_returns_empty_list_if_no_documents_are_provided(self):333sampler = TransformersSimilarityRanker()334sampler.warm_up()335output = sampler.run(query="City in Germany", documents=[])336assert not output["documents"]337
338# Raises ComponentError if model is not warmed up339@pytest.mark.integration340def test_raises_component_error_if_model_not_warmed_up(self):341sampler = TransformersSimilarityRanker()342with pytest.raises(ComponentError):343sampler.run(query="query", documents=[Document(content="document")])344
    @pytest.mark.integration
    @pytest.mark.parametrize(
        "query,docs_before_texts,expected_first_text",
        [
            ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
            ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
            ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
        ],
    )
    def test_run_top_k(self, query, docs_before_texts, expected_first_text):
        """
        Test if the component ranks documents correctly with a custom top_k.
        """
        # Integration test: top_k=2 must truncate the ranked list to two documents.
        ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
        ranker.warm_up()
        docs_before = [Document(content=text) for text in docs_before_texts]
        output = ranker.run(query=query, documents=docs_before)
        docs_after = output["documents"]

        assert len(docs_after) == 2
        assert docs_after[0].content == expected_first_text

        # Output must be in descending score order.
        sorted_scores = sorted([doc.score for doc in docs_after], reverse=True)
        assert [doc.score for doc in docs_after] == sorted_scores
370@pytest.mark.integration371def test_run_single_document(self):372"""373Test if the component runs with a single document.
374"""
375ranker = TransformersSimilarityRanker(model="cross-encoder/ms-marco-MiniLM-L-6-v2", device=None)376ranker.warm_up()377docs_before = [Document(content="Berlin")]378output = ranker.run(query="City in Germany", documents=docs_before)379docs_after = output["documents"]380
381assert len(docs_after) == 1382