import importlib
from unittest import mock
from unittest.mock import MagicMock

import pytest
from openai import AsyncOpenAI

from outlines.models.openai import (
    OpenAI,
    OpenAIConfig,
    build_optimistic_mask,
    find_longest_intersection,
    find_response_choices_intersection,
)


def module_patch(path):
    """Patch functions that have the same name as the module in which they're implemented.

    When a package exposes a function under the same name as one of its
    submodules, resolving the dotted target by attribute access can land on
    that function rather than on the submodule. This helper instead imports
    the longest importable module prefix of ``path`` with
    :func:`importlib.import_module` (which returns the module object itself)
    and patches the remaining attribute path relative to that module.

    Args:
        path: Dotted path of the object to patch, in the same form accepted
            by :func:`unittest.mock.patch`.

    Returns:
        An unstarted ``mock.patch`` object, usable as a context manager or
        decorator. If no module prefix of ``path`` can be imported, a plain
        ``mock.patch(path)`` is returned.
    """
    components = path.split(".")
    # The last component is always the attribute being replaced, so only
    # strictly shorter prefixes are candidate modules; try the longest first.
    for i in range(len(components) - 1, 0, -1):
        module_name = ".".join(components[:i])
        try:
            imported = importlib.import_module(module_name)
        except ImportError:
            # This prefix is not an importable module (e.g. it names a class
            # or function inside a module); try a shorter prefix.
            continue

        # Any components between the module and the final attribute (e.g. a
        # class name) are walked with getattr when the patch starts.
        owner_names = components[i:-1]

        def getter():
            owner = imported
            for name in owner_names:
                owner = getattr(owner, name)
            return owner

        patch = mock.patch(path)
        # Override mock's own target resolution so the patch is applied to
        # the object reached from the module returned by importlib.
        patch.getter = getter
        patch.attribute = components[-1]
        return patch

    # No module prefix could be imported: fall back to mock's own resolution.
    return mock.patch(path)


def test_openai_call():
    """Check that calling an ``OpenAI`` model goes through ``generate_chat``.

    ``generate_chat`` is stubbed to return ``(["foo"], 1, 2)``. The test
    expects the model to surface these as its first completion and as its
    ``prompt_tokens`` / ``completion_tokens`` counters. It also checks that
    the ``OpenAIConfig`` handed to ``generate_chat`` (fourth positional
    argument) carries the constructor settings, with ``n`` overridden by a
    per-call ``samples`` value.
    """
    with module_patch("outlines.models.openai.generate_chat") as mocked_generate_chat:
        mocked_generate_chat.return_value = ["foo"], 1, 2
        async_client = MagicMock(spec=AsyncOpenAI, api_key="key")

        model = OpenAI(
            async_client,
            OpenAIConfig(max_tokens=10, temperature=0.5, n=2, stop=["."]),
        )

        assert model("bar")[0] == "foo"
        assert model.prompt_tokens == 1
        assert model.completion_tokens == 2

        # The config is passed to generate_chat as its fourth positional argument.
        config = mocked_generate_chat.call_args.args[3]
        assert isinstance(config, OpenAIConfig)
        assert config.max_tokens == 10
        assert config.temperature == 0.5
        assert config.n == 2
        assert config.stop == ["."]

        # A per-call ``samples`` value overrides the configured ``n``.
        model("bar", samples=3)
        config = mocked_generate_chat.call_args.args[3]
        assert config.n == 3


@pytest.mark.parametrize(
    "response,choice,expected_intersection,expected_choices_left",
    [
        ([1, 2, 3, 4], [[5, 6]], [], [[5, 6]]),
        ([1, 2, 3, 4], [[5, 6], [7, 8]], [], [[5, 6], [7, 8]]),
        ([1, 2, 3, 4], [[1, 2], [7, 8]], [1, 2], [[]]),
        ([1, 2], [[1, 2, 3, 4], [1, 2]], [1, 2], [[3, 4], []]),
        ([1, 2, 3], [[1, 2, 3, 4], [1, 2]], [1, 2, 3], [[4]]),
    ],
)
def test_find_response_choices_intersection(
    response, choice, expected_intersection, expected_choices_left
):
    """Check the shared prefix between a response and candidate choices.

    Per the cases above, ``choices_left`` holds what remains of each choice
    that achieves the longest shared prefix with ``response``. When nothing
    is shared, the intersection is empty and every choice is returned
    unchanged.
    """
    intersection, choices_left = find_response_choices_intersection(response, choice)
    assert intersection == expected_intersection
    assert choices_left == expected_choices_left


@pytest.mark.parametrize(
    "response,choice,expected_prefix",
    [
        ([1, 2, 3], [1, 2, 3, 4], [1, 2, 3]),
        ([1, 2, 3], [1, 2, 3], [1, 2, 3]),
        ([4, 5], [1, 2, 3], []),
    ],
)
def test_find_longest_common_prefix(response, choice, expected_prefix):
    """Check ``find_longest_intersection`` returns the common leading tokens.

    The cases cover a response that is a strict prefix of the choice, an
    exact match, and sequences with no common prefix.
    """
    prefix = find_longest_intersection(response, choice)
    assert prefix == expected_prefix


@pytest.mark.parametrize(
    "transposed,mask_size,expected_mask",
    [
        ([{1, 2}, {3, 4}], 3, {1: 100, 2: 100, 3: 100}),
        ([{1, 2}, {3, 4}], 4, {1: 100, 2: 100, 3: 100, 4: 100}),
    ],
)
def test_build_optimistic_mask(transposed, mask_size, expected_mask):
    """Check ``build_optimistic_mask`` caps the mask at ``mask_size`` entries.

    Per the cases above, token ids are drawn from the position sets in
    order, each mapped to the value 100, until ``mask_size`` ids are
    included.
    """
    mask = build_optimistic_mask(transposed, mask_size)
    assert mask == expected_mask