rag-chatbot-2
62 строки · 1.7 Кб
1from unittest.mock import patch
2
3import pytest
4from bot.client.ctransformers_client import CtransformersClient
5from bot.model.model_settings import ModelType, get_model_setting
6
7
@pytest.fixture
def cpu_config():
    """Provide a ctransformers config dict intended for CPU-only runs.

    ``gpu_layers`` is 0, so no model layers are offloaded to a GPU. The dict is
    patched onto the model settings by the ``ctransformers_client`` fixture.
    """
    return dict(
        top_k=40,
        top_p=0.95,
        temperature=0.7,
        repetition_penalty=1.1,
        last_n_tokens=64,
        seed=-1,
        batch_size=8,
        threads=-1,
        max_new_tokens=1024,
        stop=None,
        stream=False,
        reset=True,
        context_length=2048,
        gpu_layers=0,
        mmap=True,
        mlock=False,
    )
29
30
@pytest.fixture
def valid_model_settings():
    """Provide the ZEPHYR model settings, which the client tests construct successfully."""
    return get_model_setting(ModelType.ZEPHYR.value)
35
36
@pytest.fixture
def invalid_model_settings():
    """Provide the OPENCHAT model settings, used as a negative case.

    The tests in this module expect ``CtransformersClient`` to raise
    ``ValueError`` for these settings. Presumably this is because they do not
    list the ctransformers client type (per the test name); verify against
    ``bot.model.model_settings``.
    """
    openchat_settings = get_model_setting(ModelType.OPENCHAT.value)
    return openchat_settings
40
41
@pytest.fixture
def ctransformers_client(mock_model_folder, valid_model_settings, cpu_config):
    """Build a ``CtransformersClient`` for the ZEPHYR settings with ``cpu_config`` patched in.

    ``valid_model_settings.config`` is replaced by ``cpu_config`` for the whole
    lifetime of the test, not just during construction. The fixture yields
    inside the ``patch.object`` context, so any later read of
    ``model_settings.config`` by the client also sees the CPU config. The
    original attribute is restored at fixture teardown.

    ``mock_model_folder`` is a fixture defined outside this module
    (presumably in a conftest).
    """
    # Yield (not return) so the patch stays active until teardown; returning
    # from inside the `with` would undo the patch before the test body runs.
    with patch.object(valid_model_settings, "config", cpu_config):
        yield CtransformersClient(mock_model_folder, valid_model_settings)
46
47
def test_init_raises_value_error_for_invalid_client_type(mock_model_folder, invalid_model_settings):
    """Constructing the client with the OPENCHAT settings must raise ValueError."""
    # Callable form of pytest.raises: invokes the constructor with the given args
    # and asserts that ValueError propagates.
    pytest.raises(ValueError, CtransformersClient, mock_model_folder, invalid_model_settings)
51
52
def test_encode_prompt(ctransformers_client):
    """The private ``_encode_prompt`` helper returns a non-None value for a text prompt."""
    assert ctransformers_client._encode_prompt("Test prompt") is not None
57
58
def test_generate_answer(ctransformers_client):
    """``generate_answer`` returns a non-None result when capped at 10 new tokens."""
    answer = ctransformers_client.generate_answer("Tell me a joke", max_new_tokens=10)
    assert answer is not None
63