10
from dataclasses import dataclass
11
from datetime import timedelta, timezone
12
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union
13
from unittest import mock
19
from disnake import utils
21
from . import helpers, utils_helper_module
24
from typing_extensions import TypeAliasType
25
elif sys.version_info >= (3, 12):
27
from typing import TypeAliasType
30
def test_missing() -> None:
    """``utils.MISSING`` is unequal even to itself and evaluates as falsy."""
    sentinel = utils.MISSING
    assert sentinel != sentinel
    assert not sentinel
35
def test_cached_property() -> None:
37
@utils.cached_property
38
def prop(self) -> object:
43
assert inst.prop is inst.prop
44
assert Test.prop.__doc__ == "Does things"
45
assert isinstance(Test.prop, utils.cached_property)
48
def test_cached_slot_property() -> None:
50
__slots__ = ("_cs_prop",)
52
@utils.cached_slot_property("_cs_prop")
53
def prop(self) -> object:
58
assert inst.prop is inst.prop
59
assert Test.prop.__doc__ == "Does things"
60
assert isinstance(Test.prop, utils.CachedSlotProperty)
63
def test_parse_time() -> None:
    """``parse_time`` passes ``None`` through and turns an ISO-8601 string into a UTC datetime."""
    assert utils.parse_time(None) is None

    parsed = utils.parse_time("2021-08-29T13:50:00+00:00")
    expected = datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc)
    assert parsed == expected
70
def test_copy_doc() -> None:
71
def func(num: int, *, arg: str) -> float:
72
"""Returns the best number"""
76
def func2(*args: Any, **kwargs: Any) -> Any:
79
assert func2.__doc__ == func.__doc__
80
assert inspect.signature(func) == inspect.signature(func2)
83
@mock.patch.object(warnings, "warn")
84
@pytest.mark.parametrize(
86
[(None, "stuff is deprecated."), ("other", "stuff is deprecated, use other instead.")],
88
def test_deprecated(mock_warn: mock.Mock, instead, msg) -> None:
89
@utils.deprecated(instead)
90
def stuff(num: int) -> int:
93
assert stuff(42) == 42
94
mock_warn.assert_called_once_with(msg, stacklevel=3, category=DeprecationWarning)
97
@pytest.mark.parametrize(
98
("expected", "perms", "guild", "redirect", "scopes", "disable_select"),
110
"scope": "bot applications.commands",
113
"response_type": "code",
114
"redirect_uri": "http://endless.horse",
115
"disable_guild_select": "true",
117
disnake.Permissions(42),
118
disnake.Object(9999),
119
"http://endless.horse",
120
["bot", "applications.commands"],
125
def test_oauth_url(expected, perms, guild, redirect, scopes, disable_select) -> None:
126
url = utils.oauth_url(
130
redirect_uri=redirect,
132
disable_guild_select=disable_select,
134
assert dict(yarl.URL(url).query) == {"client_id": "1234", **expected}
137
@pytest.mark.parametrize(
    ("num", "expected"),
    [
        (0, datetime.datetime(2015, 1, 1, tzinfo=timezone.utc)),
        (881536165478499999, datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc)),
        (10000000000000000000, datetime.datetime(2090, 7, 20, 17, 49, 51, tzinfo=timezone.utc)),
    ],
)
def test_snowflake_time(num: int, expected) -> None:
    """``snowflake_time`` maps a snowflake ID to its (UTC) creation timestamp."""
    created_at = utils.snowflake_time(num)
    # The expectations are whole-second values, so drop sub-second precision first.
    truncated = created_at.replace(microsecond=0)
    assert truncated == expected
149
@pytest.mark.parametrize(
152
(datetime.datetime(2015, 1, 1, tzinfo=timezone.utc), 0),
153
(datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc), 881536165478400000),
156
def test_time_snowflake(dt, expected) -> None:
157
low = utils.time_snowflake(dt)
158
assert low == expected
160
high = utils.time_snowflake(dt, high=True)
162
assert high + 1 == utils.time_snowflake(dt + timedelta(milliseconds=1))
165
def test_find() -> None:
    """``find`` returns the first element satisfying the predicate, or ``None``."""

    def is_42(value):
        return value == 42

    assert utils.find(is_42, []) is None
    assert utils.find(is_42, [42]) == 42
    assert utils.find(is_42, [1, 2, 42, 3, 4]) == 42

    def has_id_42(obj):
        return obj.id == 42

    objects = [disnake.Object(i) for i in (1, 42, 42, 2)]
    # Two objects share id 42; the earlier one (index 1) must be returned.
    assert utils.find(has_id_42, objects) is objects[1]
176
def test_get() -> None:
186
lst = [B(123, A(42))]
187
with pytest.raises(AttributeError):
188
utils.get(lst, something=None)
191
assert utils.get(lst, value=123) == lst[0]
192
assert utils.get(lst, a__value=42) == lst[0]
193
assert utils.get(lst, value=111111) is None
196
lst += [B(456, A(42)), B(789, A(99999))]
197
assert utils.get(lst, value=789) == lst[2]
198
assert utils.get(lst, a__value=42) == lst[0]
200
assert utils.get(lst, value=456, a__value=42) is lst[1]
201
assert utils.get(lst, value=789, a__value=42) is None
204
@pytest.mark.parametrize(
208
([3, 2, 1, 2], [3, 2, 1]),
213
def test_unique(it, expected) -> None:
214
assert utils._unique(it) == expected
217
@pytest.mark.parametrize(
218
("data", "expected"),
222
({"key": None}, None),
227
def test_get_as_snowflake(data, expected) -> None:
228
assert utils._get_as_snowflake(data, "key") == expected
231
def test_maybe_cast() -> None:
232
convert = lambda v: v + 1
235
assert utils._maybe_cast(utils.MISSING, convert) is None
236
assert utils._maybe_cast(utils.MISSING, convert, default) is default
238
assert utils._maybe_cast(42, convert) == 43
239
assert utils._maybe_cast(42, convert, default) == 43
242
@pytest.mark.parametrize(
    ("data", "expected_mime", "expected_ext"),
    [
        (b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A", "image/png", ".png"),
        (b"\xFF\xD8\xFFxxxJFIF", "image/jpeg", ".jpg"),
        (b"\xFF\xD8\xFFxxxExif", "image/jpeg", ".jpg"),
        (b"\xFF\xD8\xFFxxxxxxxxxxxx", "image/jpeg", ".jpg"),
        (b"xxxxxxJFIF", "image/jpeg", ".jpg"),
        (b"\x47\x49\x46\x38\x37\x61", "image/gif", ".gif"),
        (b"\x47\x49\x46\x38\x39\x61", "image/gif", ".gif"),
        (b"RIFFxxxxWEBP", "image/webp", ".webp"),
    ],
)
def test_mime_type_valid(data, expected_mime, expected_ext) -> None:
    """Known image signatures map to the expected MIME type and file extension."""
    # Extra bytes after the signature must not affect detection.
    candidates = [data, data + b"\xFF"]
    for blob in candidates:
        assert utils._get_mime_type_for_image(blob) == expected_mime
        assert utils._get_extension_for_image(blob) == expected_ext

    # A leading junk byte moves the signature out of position, so detection must fail.
    shifted = b"\xFF" + data
    with pytest.raises(ValueError, match=r"Unsupported image type given"):
        utils._get_mime_type_for_image(shifted)
    assert utils._get_extension_for_image(shifted) is None
266
@pytest.mark.parametrize(
269
b"\x89\x50\x4E\x47\x0D\x0A\x1A\xFF",
270
b"\x47\x49\x46\x38\x38\x61",
276
def test_mime_type_invalid(data) -> None:
277
with pytest.raises(ValueError, match=r"Unsupported image type given"):
278
utils._get_mime_type_for_image(data)
279
assert utils._get_extension_for_image(data) is None
283
async def test_assetbytes_base64() -> None:
284
assert await utils._assetbytes_to_base64_data(None) is None
287
data = b"RIFFabcdWEBPwxyz"
288
expected = "data:image/webp;base64,UklGRmFiY2RXRUJQd3h5eg=="
289
for conv in (bytes, bytearray, memoryview):
290
assert await utils._assetbytes_to_base64_data(conv(data)) == expected
293
mock_asset = mock.Mock(disnake.Asset)
294
mock_asset.read = mock.AsyncMock(return_value=data)
296
assert await utils._assetbytes_to_base64_data(mock_asset) == expected
299
@pytest.mark.parametrize(
300
("after", "use_clock", "expected"),
306
(utils.MISSING, False, 7),
307
(utils.MISSING, True, 7),
310
@helpers.freeze_time()
311
def test_parse_ratelimit_header(after, use_clock, expected) -> None:
312
reset_time = utils.utcnow() + timedelta(seconds=7)
314
request = mock.Mock()
315
request.headers = {"X-Ratelimit-Reset": reset_time.timestamp()}
316
if after is not utils.MISSING:
317
request.headers["X-Ratelimit-Reset-After"] = after
319
assert utils._parse_ratelimit_header(request, use_clock=use_clock) == expected
322
@pytest.mark.parametrize(
330
async def test_maybe_coroutine(func: mock.Mock) -> None:
331
assert await utils.maybe_coroutine(func, 42, arg="uwu") is func.return_value
332
func.assert_called_once_with(42, arg="uwu")
335
@pytest.mark.parametrize("mock_type", [mock.Mock, mock.AsyncMock])
336
@pytest.mark.parametrize(
342
([False, True], False),
343
([True, False, True], False),
347
async def test_async_all(mock_type, gen, expected) -> None:
348
assert await utils.async_all(mock_type(return_value=x)() for x in gen) is expected
353
async def test_sane_wait_for(looptime) -> None:
357
return [asyncio.sleep(n) for n in times]
360
await utils.sane_wait_for(create(), timeout=100)
361
assert looptime == 50
364
tasks = [asyncio.create_task(c) for c in create()]
365
with pytest.raises(asyncio.TimeoutError):
366
await utils.sane_wait_for(tasks, timeout=40)
367
assert looptime == 90
369
assert [t.done() for t in tasks] == [True, False, True]
372
await asyncio.sleep(1000)
373
assert all(t.done() for t in tasks)
376
def test_get_slots() -> None:
378
__slots__ = ("a", "a2")
384
__slots__ = {"c": "uwu"}
389
assert list(utils.get_slots(D)) == ["a", "a2", "c", "xyz"]
392
@pytest.mark.parametrize(
400
timezone(timedelta(hours=-9)),
403
@pytest.mark.parametrize(("delta", "expected"), [(7, 7), (-100, 0)])
404
@helpers.freeze_time()
405
def test_compute_timedelta(tz, delta, expected) -> None:
406
dt = datetime.datetime.now()
407
if tz is not utils.MISSING:
408
dt = dt.astimezone(tz)
409
assert utils.compute_timedelta(dt + timedelta(seconds=delta)) == expected
412
@pytest.mark.parametrize(("delta", "expected"), [(0, 0), (42, 42), (-100, 0)])
415
@helpers.freeze_time()
416
async def test_sleep_until(looptime, delta, expected) -> None:
418
assert await utils.sleep_until(utils.utcnow() + timedelta(seconds=delta), o) is o
419
assert looptime == expected
422
def test_utcnow() -> None:
    """``utcnow`` returns a timezone-aware datetime whose tzinfo is UTC."""
    now = utils.utcnow()
    assert now.tzinfo == timezone.utc
426
def test_valid_icon_size() -> None:
    """Powers of two from 16 (2**4) to 4096 (2**12) are accepted; a sample of other sizes is rejected."""
    accepted = [1 << exponent for exponent in range(4, 13)]
    for size in accepted:
        assert utils.valid_icon_size(size)

    rejected = (0, 1, 2, 8, 24, 100, 2**20)
    for size in rejected:
        assert not utils.valid_icon_size(size)
434
@pytest.mark.parametrize(
    ("s", "expected"),
    [
        ("a一b", 4),
        ("abc", 3),
    ],
)
def test_string_width(s, expected) -> None:
    """``_string_width`` counts the wide character 一 as two columns and ASCII as one."""
    width = utils._string_width(s)
    assert width == expected
439
@pytest.mark.parametrize(
440
("url", "params", "expected"),
442
(mock.Mock(disnake.Invite, code="uwu"), {}, "uwu"),
444
("https://discord.com/disnake", {}, "https://discord.com/disnake"),
445
("https://discord.com/invite/disnake", {}, "disnake"),
446
("http://discord.gg/disnake?param=123%20456", {"param": "123 456"}, "disnake"),
447
("https://discordapp.com/invite/disnake?a=1&a=2", {"a": "1"}, "disnake"),
450
@pytest.mark.parametrize("with_params", [False, True])
451
def test_resolve_invite(url, params, expected, with_params) -> None:
452
res = utils.resolve_invite(url, with_params=with_params)
454
assert res == (expected, params)
456
assert res == expected
459
@pytest.mark.parametrize(
462
(mock.Mock(disnake.Template, code="uwu"), "uwu"),
464
("http://discord.com/disnake", "http://discord.com/disnake"),
465
("http://discord.new/disnake", "disnake"),
466
("https://discord.com/template/disnake", "disnake"),
467
("https://discordapp.com/template/disnake", "disnake"),
470
def test_resolve_template(url, expected) -> None:
471
assert utils.resolve_template(url) == expected
474
@pytest.mark.parametrize(
475
("text", "exp_remove", "exp_escape"),
480
"*hi* ~~a~ |aaa~*\\``\n`py x``` __uwu__ y",
481
"hi a aaa\npy x uwu y",
482
r"\*hi\* \~\~a\~ \|aaa\~\*\\\`\`" "\n" r"\`py x\`\`\` \_\_uwu\_\_ y",
485
"aaaaa\n> h\n>> abc \n>>> te*st_",
486
"aaaaa\nh\n>> abc \ntest",
487
"aaaaa\n\\> h\n>> abc \n\\>>> te\\*st\\_",
490
"*h*\n> [li|nk](~~url~~) xyz **https://google.com/stuff?uwu=owo",
491
"h\n xyz https://google.com/stuff?uwu=owo",
493
r"\*h\*" "\n" r"\> \[li|nk](~~url~~) xyz \*\*https://google.com/stuff?uwu=owo",
497
def test_markdown(text: str, exp_remove, exp_escape) -> None:
498
assert utils.remove_markdown(text, ignore_links=False) == exp_remove
499
assert utils.remove_markdown(text, ignore_links=True) == exp_remove
501
assert utils.escape_markdown(text, ignore_links=False) == exp_escape
502
assert utils.escape_markdown(text, ignore_links=True) == exp_escape
505
@pytest.mark.parametrize(
506
("text", "expected", "expected_ignore"),
509
"http://google.com/~test/hi_test ~~a~~",
510
"http://google.com/test/hitest a",
511
"http://google.com/~test/hi_test a",
514
"abc [link](http://test~test.com)\n>>> <http://endless.horse/_*>",
515
"abc \n<http://endless.horse/>",
516
"abc \n<http://endless.horse/_*>",
520
def test_markdown_links(text: str, expected, expected_ignore) -> None:
521
assert utils.remove_markdown(text, ignore_links=False) == expected
522
assert utils.remove_markdown(text, ignore_links=True) == expected_ignore
525
@pytest.mark.parametrize(
    ("text", "expected"),
    [
        ("@everyone hey look at this cat", "@\u200beveryone hey look at this cat"),
        ("test @here", "test @\u200bhere"),
        ("<@12341234123412341> hi", "<@\u200b12341234123412341> hi"),
        ("<@!12341234123412341>", "<@\u200b!12341234123412341>"),
        ("<@&12341234123412341>", "<@\u200b&12341234123412341>"),
    ],
)
def test_escape_mentions(text: str, expected) -> None:
    """``escape_mentions`` inserts a zero-width space (U+200B) right after ``@``.

    This is checked for ``@everyone``/``@here`` and for the ``<@...>``, ``<@!...>``
    and ``<@&...>`` mention forms.
    """
    escaped = utils.escape_mentions(text)
    assert escaped == expected
539
@pytest.mark.parametrize(
540
("docstring", "expected"),
544
("test abc", "test abc"),
558
"test\nhi\n\n\naaaaaaa\nxyz: abc",
567
"stuff\n-----+\nabc",
589
def test_parse_docstring_desc(docstring: Optional[str], expected) -> None:
593
f.__doc__ = docstring
594
assert utils.parse_docstring(f) == {
595
"description": expected,
597
"localization_key_name": None,
598
"localization_key_desc": None,
602
@pytest.mark.parametrize(
603
("docstring", "expected"),
612
other_something: :class:`int`
615
thingy: a very cool thingy
622
"something": {"name": "something", "description": "a value"},
624
"name": "other_something",
625
"description": "another value,\nwow",
627
"thingy": {"name": "thingy", "description": "a very cool thingy"},
649
def test_parse_docstring_param(docstring: str, expected) -> None:
653
f.__doc__ = docstring
655
k: {**v, "type": None, "localization_key_name": None, "localization_key_desc": None}
656
for k, v in expected.items()
658
assert utils.parse_docstring(f)["params"] == expected
661
def test_parse_docstring_localizations() -> None:
663
"""Does stuff. {{cool_key}}
667
p1: {{ PARAM_1 }} Probably a number.
669
Definitely a string {{ PARAM_X }}
672
assert utils.parse_docstring(f) == {
673
"description": "Does stuff.",
674
"localization_key_name": "cool_key_NAME",
675
"localization_key_desc": "cool_key_DESCRIPTION",
679
"description": "Probably a number.",
680
"localization_key_name": "PARAM_1_NAME",
681
"localization_key_desc": "PARAM_1_DESCRIPTION",
686
"description": "Definitely a string",
687
"localization_key_name": "PARAM_X_NAME",
688
"localization_key_desc": "PARAM_X_DESCRIPTION",
695
@pytest.mark.parametrize(
696
("it", "max_size", "expected"),
700
([0, 1, 2], 2, [[0, 1], [2]]),
701
([0, 1, 2, 3, 4, 5], 3, [[0, 1, 2], [3, 4, 5]]),
704
@pytest.mark.parametrize("sync", [False, True])
706
async def test_as_chunks(sync, it, max_size: int, expected) -> None:
708
assert list(utils.as_chunks(it, max_size)) == expected
715
assert [x async for x in utils.as_chunks(_it(), max_size)] == expected
718
@pytest.mark.parametrize("max_size", [-1, 0])
def test_as_chunks_size(max_size: int) -> None:
    """Calling ``as_chunks`` with a non-positive size raises ``ValueError`` immediately."""
    source = iter([])
    # No iteration happens here: the call itself must raise.
    with pytest.raises(ValueError, match=r"Chunk sizes must be greater than 0."):
        utils.as_chunks(source, max_size)
724
@pytest.mark.parametrize(
725
("params", "expected"),
728
([disnake.CommandInter, int, Optional[str]], (disnake.CommandInter, int, Optional[str])),
730
([float, Literal[1, 2, Literal[3, 4]], Literal["a", "bc"]], (float, 1, 2, 3, 4, "a", "bc")),
731
([Literal[1, 1, 2, 3, 3]], (1, 2, 3)),
734
def test_flatten_literal_params(params, expected) -> None:
735
assert utils.flatten_literal_params(params) == expected
741
@pytest.mark.parametrize(
    ("params", "expected"),
    [
        ([NoneType], (NoneType,)),
        ([NoneType, int, NoneType, float], (int, float, NoneType)),
    ],
)
def test_normalise_optional_params(params, expected) -> None:
    """``NoneType`` entries collapse into one trailing ``NoneType``; other types keep their order."""
    normalised = utils.normalise_optional_params(params)
    assert normalised == expected
749
@pytest.mark.parametrize(
750
("tp", "expected", "expected_cache"),
753
(None, NoneType, False),
756
(List[int], List[int], False),
757
(Dict[float, "List[yarl.URL]"], Dict[float, List[yarl.URL]], True),
758
(Literal[1, Literal[False], "hi"], Literal[1, False, "hi"], False),
760
(Union[timezone, float], Union[timezone, float], False),
761
(Optional[int], Optional[int], False),
762
(Union["tuple", None, int], Union[tuple, int, None], True),
764
("bool", bool, True),
765
("Tuple[dict, List[Literal[42, 99]]]", Tuple[dict, List[Literal[42, 99]]], True),
771
marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="syntax requires py3.10"),
775
def test_resolve_annotation(tp, expected, expected_cache) -> None:
777
result = utils.resolve_annotation(tp, globals(), locals(), cache)
778
assert result == expected
781
assert bool(cache) == expected_cache
783
if isinstance(tp, str):
784
assert utils.resolve_annotation(tp, globals(), locals(), cache) is result
787
def test_resolve_annotation_literal() -> None:
789
TypeError, match=r"Literal arguments must be of type str, int, bool, or NoneType."
791
utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {})
794
@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12")
795
class TestResolveAnnotationTypeAliasType:
796
def test_simple(self) -> None:
798
CoolList = TypeAliasType("CoolList", List[int])
799
assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int]
801
def test_generic(self) -> None:
804
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))
806
annotation = CoolList[int]
807
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int]
810
def test_forwardref_local(self) -> None:
812
IntOrStr = Union[int, str]
813
CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))
815
annotation = CoolList["IntOrStr"]
816
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr]
819
def test_forwardref_module(self) -> None:
820
resolved = utils.resolve_annotation(
821
utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {}
823
assert resolved == List[Union[int, str]]
826
def test_forwardref_mixed(self) -> None:
827
LocalIntOrStr = Union[int, str]
829
annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"]
830
assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr]
833
def test_forwardref_duplicate(self) -> None:
839
utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache)
846
utils.resolve_annotation(
847
utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache
853
@pytest.mark.parametrize(
854
("dt", "style", "expected"),
857
(1630245000.1234, "T", "<t:1630245000:T>"),
858
(datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc), "f", "<t:1630245000:f>"),
861
def test_format_dt(dt, style, expected) -> None:
862
assert utils.format_dt(dt, style) == expected
865
@pytest.fixture(scope="session")
866
def tmp_module_root(tmp_path_factory):
868
tmpdir = tmp_path_factory.mktemp("module_root")
869
for d in ["empty", "not_a_module", "mod/sub1/sub2"]:
870
(tmpdir / d).mkdir(parents=True)
873
"not_a_module/abc.py",
876
"mod/sub1/sub2/__init__.py",
877
"mod/sub1/sub2/abc.py",
883
@pytest.mark.parametrize(
884
("path", "expected"),
886
(".", ["test", "mod.ext"]),
887
("./", ["test", "mod.ext"]),
891
def test_search_directory(tmp_module_root, path, expected) -> None:
892
orig_cwd = os.getcwd()
894
os.chdir(tmp_module_root)
897
for p in [path, os.path.abspath(path)]:
898
assert sorted(utils.search_directory(p)) == sorted(expected)
903
@pytest.mark.parametrize(
906
("../../", r"Modules outside the cwd require a package to be specified"),
907
("nonexistent", r"Provided path '.*?nonexistent' does not exist"),
908
("test.py", r"Provided path '.*?test.py' is not a directory"),
911
def test_search_directory_exc(tmp_module_root, path, exc) -> None:
912
orig_cwd = os.getcwd()
914
os.chdir(tmp_module_root)
916
with pytest.raises(ValueError, match=exc):
917
list(utils.search_directory(tmp_module_root / path))
922
@pytest.mark.parametrize(
923
("locale", "expected"),
933
def test_as_valid_locale(locale, expected) -> None:
934
assert utils.as_valid_locale(locale) == expected
937
@pytest.mark.parametrize(
938
("values", "expected"),
942
(["one", "two"], "one plus two"),
943
(["one", "two", "three"], "one, two, plus three"),
944
(["one", "two", "three", "four"], "one, two, three, plus four"),
947
def test_humanize_list(values, expected) -> None:
948
assert utils.humanize_list(values, "plus") == expected
961
def wrap(self, *args, **kwargs):
962
return f(self, *args, **kwargs)
976
def cmethod(cls) -> None:
980
def smethod() -> None:
993
def decorated(self) -> None:
996
_lambda = lambda: None
999
@pytest.mark.parametrize(
1000
("function", "expected"),
1005
(_Clazz.func, True),
1006
(_Clazz().func, False),
1008
(_Clazz.rebind, False),
1009
(_Clazz().rebind, False),
1011
(_Clazz.cmethod, False),
1012
(_Clazz.smethod, True),
1014
(_Clazz.Nested.func, True),
1015
(_Clazz.Nested().func, False),
1017
(_toplevel(), False),
1018
(_Clazz().func(), False),
1019
(_Clazz.Nested().func(), False),
1021
(_Clazz.decorated, True),
1022
(_Clazz().decorated, False),
1024
(_Clazz._lambda, False),
1025
(_Clazz()._lambda, False),
1028
def test_signature_has_self_param(function, expected) -> None:
1029
assert utils.signature_has_self_param(function) == expected