outlines

Форк
0
/
test_openai.py 
105 строк · 3.5 Кб
1
import importlib
2
from unittest import mock
3
from unittest.mock import MagicMock
4

5
import pytest
6
from openai import AsyncOpenAI
7

8
from outlines.models.openai import (
9
    OpenAI,
10
    OpenAIConfig,
11
    build_optimistic_mask,
12
    find_longest_intersection,
13
    find_response_choices_intersection,
14
)
15

16

17
def module_patch(path):
18
    """Patch functions that have the same name as the module in which they're implemented."""
19
    target = path
20
    components = target.split(".")
21
    for i in range(len(components), 0, -1):
22
        try:
23
            # attempt to import the module
24
            imported = importlib.import_module(".".join(components[:i]))
25

26
            # module was imported, let's use it in the patch
27
            patch = mock.patch(path)
28
            patch.getter = lambda: imported
29
            patch.attribute = ".".join(components[i:])
30
            return patch
31
        except Exception:
32
            continue
33

34
    # did not find a module, just return the default mock
35
    return mock.patch(path)
36

37

38
def test_openai_call():
39
    with module_patch("outlines.models.openai.generate_chat") as mocked_generate_chat:
40
        mocked_generate_chat.return_value = ["foo"], 1, 2
41
        async_client = MagicMock(spec=AsyncOpenAI, api_key="key")
42

43
        model = OpenAI(
44
            async_client,
45
            OpenAIConfig(max_tokens=10, temperature=0.5, n=2, stop=["."]),
46
        )
47

48
        assert model("bar")[0] == "foo"
49
        assert model.prompt_tokens == 1
50
        assert model.completion_tokens == 2
51
        mocked_generate_chat_args = mocked_generate_chat.call_args
52
        mocked_generate_chat_arg_config = mocked_generate_chat_args[0][3]
53
        assert isinstance(mocked_generate_chat_arg_config, OpenAIConfig)
54
        assert mocked_generate_chat_arg_config.max_tokens == 10
55
        assert mocked_generate_chat_arg_config.temperature == 0.5
56
        assert mocked_generate_chat_arg_config.n == 2
57
        assert mocked_generate_chat_arg_config.stop == ["."]
58

59
        model("bar", samples=3)
60
        mocked_generate_chat_args = mocked_generate_chat.call_args
61
        mocked_generate_chat_arg_config = mocked_generate_chat_args[0][3]
62
        assert mocked_generate_chat_arg_config.n == 3
63

64

65
@pytest.mark.parametrize(
66
    "response,choice,expected_intersection,expected_choices_left",
67
    (
68
        ([1, 2, 3, 4], [[5, 6]], [], [[5, 6]]),
69
        ([1, 2, 3, 4], [[5, 6], [7, 8]], [], [[5, 6], [7, 8]]),
70
        ([1, 2, 3, 4], [[1, 2], [7, 8]], [1, 2], [[]]),
71
        ([1, 2], [[1, 2, 3, 4], [1, 2]], [1, 2], [[3, 4], []]),
72
        ([1, 2, 3], [[1, 2, 3, 4], [1, 2]], [1, 2, 3], [[4]]),
73
    ),
74
)
75
def test_find_response_choices_intersection(
76
    response, choice, expected_intersection, expected_choices_left
77
):
78
    intersection, choices_left = find_response_choices_intersection(response, choice)
79
    assert intersection == expected_intersection
80
    assert choices_left == expected_choices_left
81

82

83
@pytest.mark.parametrize(
84
    "response,choice,expected_prefix",
85
    (
86
        ([1, 2, 3], [1, 2, 3, 4], [1, 2, 3]),
87
        ([1, 2, 3], [1, 2, 3], [1, 2, 3]),
88
        ([4, 5], [1, 2, 3], []),
89
    ),
90
)
91
def test_find_longest_common_prefix(response, choice, expected_prefix):
92
    prefix = find_longest_intersection(response, choice)
93
    assert prefix == expected_prefix
94

95

96
@pytest.mark.parametrize(
97
    "transposed,mask_size,expected_mask",
98
    (
99
        ([{1, 2}, {3, 4}], 3, {1: 100, 2: 100, 3: 100}),
100
        ([{1, 2}, {3, 4}], 4, {1: 100, 2: 100, 3: 100, 4: 100}),
101
    ),
102
)
103
def test_build_optimistic_mask(transposed, mask_size, expected_mask):
104
    mask = build_optimistic_mask(transposed, mask_size)
105
    assert mask == expected_mask
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.