# haystack
# 104 lines · 5.3 KB
1import pytest2from haystack import Document3from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker4
5
class TestLostInTheMiddleRanker:
    """Unit tests for ``LostInTheMiddleRanker``.

    The ordering tests pin the layout the ranker is expected to produce: for
    input documents numbered 1..N, the odd-numbered ones appear first in
    ascending order, followed by the even-numbered ones in descending order
    (e.g. ``1 3 5 7 9 8 6 4 2`` for nine documents).
    """

    @staticmethod
    def _contents(result):
        """Return the ``content`` of each document in a ``run`` result, in order.

        Comparing this list to the expected list with ``==`` checks both the
        order and the number of returned documents. An ``all(...)`` over
        ``enumerate(result["documents"])`` would pass vacuously if the ranker
        returned too few (or zero) documents.
        """
        return [doc.content for doc in result["documents"]]

    def test_lost_in_the_middle_order_odd(self):
        """The expected order is produced for an odd number of documents."""
        docs = [Document(content=str(i)) for i in range(1, 10)]
        ranker = LostInTheMiddleRanker()
        result = ranker.run(documents=docs)
        expected_order = "1 3 5 7 9 8 6 4 2".split()
        assert self._contents(result) == expected_order

    def test_lost_in_the_middle_order_even(self):
        """The expected order is produced for an even number of documents."""
        docs = [Document(content=str(i)) for i in range(1, 11)]
        ranker = LostInTheMiddleRanker()
        result = ranker.run(documents=docs)
        expected_order = "1 3 5 7 9 10 8 6 4 2".split()
        assert self._contents(result) == expected_order

    def test_lost_in_the_middle_order_two_docs(self):
        """Two documents are returned in their original order."""
        ranker = LostInTheMiddleRanker()
        docs = [Document(content="1"), Document(content="2")]
        result = ranker.run(documents=docs)
        assert self._contents(result) == ["1", "2"]

    def test_lost_in_the_middle_init(self):
        """``word_count_threshold`` defaults to ``None`` and stores a given value."""
        ranker = LostInTheMiddleRanker()
        assert ranker.word_count_threshold is None

        ranker = LostInTheMiddleRanker(word_count_threshold=10)
        assert ranker.word_count_threshold == 10

    @pytest.mark.parametrize("word_count_threshold", [0, -5])
    def test_lost_in_the_middle_init_invalid_word_count_threshold(self, word_count_threshold: int):
        """A ``word_count_threshold`` that is <= 0 is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
            LostInTheMiddleRanker(word_count_threshold=word_count_threshold)

    def test_lost_in_the_middle_with_word_count_threshold(self):
        """The result is cut off once the word count threshold is reached."""
        # Every document holds exactly one word, so the threshold equals the
        # number of documents expected in the result.
        docs = [Document(content="word" + str(i)) for i in range(1, 10)]

        ranker = LostInTheMiddleRanker(word_count_threshold=6)
        result = ranker.run(documents=docs)
        expected_order = "word1 word3 word5 word6 word4 word2".split()
        assert self._contents(result) == expected_order

        ranker = LostInTheMiddleRanker(word_count_threshold=9)
        result = ranker.run(documents=docs)
        expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
        assert self._contents(result) == expected_order

    def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents(self):
        """A threshold above the total word count keeps every document."""
        ranker = LostInTheMiddleRanker(word_count_threshold=100)
        docs = [Document(content="word" + str(i)) for i in range(1, 10)]
        ordered_docs = ranker.run(documents=docs)
        expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
        # List equality also verifies that no document was dropped.
        assert self._contents(ordered_docs) == expected_order

    def test_empty_documents_returns_empty_list(self):
        """An empty input list yields an empty ``documents`` list."""
        ranker = LostInTheMiddleRanker()
        result = ranker.run(documents=[])
        assert result == {"documents": []}

    def test_list_of_one_document_returns_same_document(self):
        """A single document is returned unchanged."""
        ranker = LostInTheMiddleRanker()
        doc = Document(content="test")
        assert ranker.run(documents=[doc]) == {"documents": [doc]}

    @pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 20])
    def test_lost_in_the_middle_order_with_top_k(self, top_k: int):
        """``top_k`` limits the documents that are reordered and returned."""
        docs = [Document(content=str(i)) for i in range(1, 10)]
        ranker = LostInTheMiddleRanker()
        result = ranker.run(documents=docs, top_k=top_k)
        if top_k < len(docs):
            # Fewer documents requested than supplied: the result must match
            # ranking only the first top_k documents.
            assert len(result["documents"]) == top_k
            expected = ranker.run(documents=[Document(content=str(i)) for i in range(1, top_k + 1)])
            assert result == expected
        else:
            # top_k is greater than or equal to the number of documents, so
            # every document is returned, in the same order as without top_k.
            assert len(result["documents"]) == len(docs)
            assert result == ranker.run(documents=docs)

    def test_to_dict(self):
        """``to_dict`` serializes the component type and its default init parameters."""
        component = LostInTheMiddleRanker()
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.rankers.lost_in_the_middle.LostInTheMiddleRanker",
            "init_parameters": {"word_count_threshold": None, "top_k": None},
        }