1
# SPDX-License-Identifier: MIT
4
from typing import Any, Iterator, Type, TypeVar
7
from typing_extensions import assert_type
10
from disnake.ui.button import V_co
12
T = TypeVar("T", bound=ui.Item)
15
@contextlib.contextmanager
16
def create_callback(item_type: Type[T]) -> Iterator["ui.item.ItemCallbackType[T]"]:
17
async def callback(self, item, inter) -> None:
18
pytest.fail("callback should not be invoked")
22
# ensure instantiation works
23
assert callback.__discord_ui_model_type__(**callback.__discord_ui_model_kwargs__)
26
class _CustomButton(ui.Button[V_co]):
27
def __init__(self, *, param: float = 42.0) -> None:
32
def test_default(self) -> None:
33
with create_callback(ui.Button[ui.View]) as func:
34
res = ui.button(custom_id="123")(func)
35
assert_type(res, ui.item.DecoratedItem[ui.Button[ui.View]])
37
assert func.__discord_ui_model_type__ is ui.Button
38
assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"}
40
with create_callback(ui.StringSelect[ui.View]) as func:
41
res = ui.string_select(custom_id="123")(func)
42
assert_type(res, ui.item.DecoratedItem[ui.StringSelect[ui.View]])
44
assert func.__discord_ui_model_type__ is ui.StringSelect
45
assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"}
47
# from here on out we're mostly only testing the button decorator,
48
# as @ui.string_select etc. works identically
50
@pytest.mark.parametrize("cls", [_CustomButton, _CustomButton[Any]])
51
def test_cls(self, cls: Type[_CustomButton]) -> None:
52
with create_callback(cls) as func:
53
res = ui.button(cls=cls, param=1337)(func)
54
assert_type(res, ui.item.DecoratedItem[cls])
56
# should strip to origin type
57
assert func.__discord_ui_model_type__ is _CustomButton
58
assert func.__discord_ui_model_kwargs__ == {"param": 1337}
61
def _test_typing_cls(self) -> None:
64
this_should_not_work="h", # type: ignore
67
@pytest.mark.parametrize(
68
("decorator", "invalid_cls"),
70
(ui.button, ui.StringSelect),
71
(ui.string_select, ui.Button),
72
(ui.user_select, ui.Button),
73
(ui.role_select, ui.Button),
74
(ui.mentionable_select, ui.Button),
75
(ui.channel_select, ui.Button),
78
def test_cls_invalid(self, decorator, invalid_cls) -> None:
79
for cls in [123, int, invalid_cls]:
80
with pytest.raises(TypeError, match=r"cls argument must be"):