disnake

Форк
0
/
test_utils.py 
1029 строк · 29.2 Кб
1
# SPDX-License-Identifier: MIT
2

3
import asyncio
4
import datetime
5
import functools
6
import inspect
7
import os
8
import sys
9
import warnings
10
from dataclasses import dataclass
11
from datetime import timedelta, timezone
12
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union
13
from unittest import mock
14

15
import pytest
16
import yarl
17

18
import disnake
19
from disnake import utils
20

21
from . import helpers, utils_helper_module
22

23
if TYPE_CHECKING:
24
    from typing_extensions import TypeAliasType
25
elif sys.version_info >= (3, 12):
26
    # non-3.12 tests shouldn't be using this
27
    from typing import TypeAliasType
28

29

30
def test_missing() -> None:
    """``MISSING`` never compares equal (not even to itself) and is falsy."""
    sentinel = utils.MISSING
    assert sentinel != sentinel
    assert bool(sentinel) is False
33

34

35
def test_cached_property() -> None:
    """``cached_property`` caches per instance and keeps the wrapped docstring."""

    class Test:
        @utils.cached_property
        def prop(self) -> object:
            """Does things"""
            return object()

    instance = Test()
    first = instance.prop
    assert first is instance.prop

    descriptor = Test.prop
    assert descriptor.__doc__ == "Does things"
    assert isinstance(descriptor, utils.cached_property)
46

47

48
def test_cached_slot_property() -> None:
    """``cached_slot_property`` caches into the named slot and keeps the docstring."""

    class Test:
        __slots__ = ("_cs_prop",)

        @utils.cached_slot_property("_cs_prop")
        def prop(self) -> object:
            """Does things"""
            return object()

    instance = Test()
    first = instance.prop
    assert first is instance.prop

    descriptor = Test.prop
    assert descriptor.__doc__ == "Does things"
    assert isinstance(descriptor, utils.CachedSlotProperty)
61

62

63
def test_parse_time() -> None:
    """``parse_time`` passes ``None`` through and parses ISO 8601 timestamps."""
    expected = datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc)

    assert utils.parse_time(None) is None
    assert utils.parse_time("2021-08-29T13:50:00+00:00") == expected
68

69

70
def test_copy_doc() -> None:
    """``copy_doc`` copies both the docstring and the signature of the source function."""

    def func(num: int, *, arg: str) -> float:
        """Returns the best number"""
        ...

    @utils.copy_doc(func)
    def func2(*args: Any, **kwargs: Any) -> Any:
        ...

    source_signature = inspect.signature(func)
    copied_signature = inspect.signature(func2)

    assert func2.__doc__ == func.__doc__
    assert source_signature == copied_signature
81

82

83
@mock.patch.object(warnings, "warn")
@pytest.mark.parametrize(
    ("instead", "msg"),
    [(None, "stuff is deprecated."), ("other", "stuff is deprecated, use other instead.")],
)
def test_deprecated(mock_warn: mock.Mock, instead, msg) -> None:
    """A ``deprecated`` function warns exactly once and still returns its result."""

    @utils.deprecated(instead)
    def stuff(num: int) -> int:
        return num

    result = stuff(42)

    assert result == 42
    mock_warn.assert_called_once_with(msg, stacklevel=3, category=DeprecationWarning)
95

96

97
@pytest.mark.parametrize(
    ("expected", "perms", "guild", "redirect", "scopes", "disable_select"),
    [
        (
            {"scope": "bot"},
            utils.MISSING,
            utils.MISSING,
            utils.MISSING,
            utils.MISSING,
            False,
        ),
        (
            {
                "scope": "bot applications.commands",
                "permissions": "42",
                "guild_id": "9999",
                "response_type": "code",
                "redirect_uri": "http://endless.horse",
                "disable_guild_select": "true",
            },
            disnake.Permissions(42),
            disnake.Object(9999),
            "http://endless.horse",
            ["bot", "applications.commands"],
            True,
        ),
    ],
)
def test_oauth_url(expected, perms, guild, redirect, scopes, disable_select) -> None:
    """``oauth_url`` encodes the given options as query parameters next to ``client_id``."""
    options = {
        "permissions": perms,
        "guild": guild,
        "redirect_uri": redirect,
        "scopes": scopes,
        "disable_guild_select": disable_select,
    }
    generated = utils.oauth_url(1234, **options)

    query = dict(yarl.URL(generated).query)
    assert query == {"client_id": "1234", **expected}
135

136

137
@pytest.mark.parametrize(
    ("num", "expected"),
    [
        (0, datetime.datetime(2015, 1, 1, tzinfo=timezone.utc)),
        (881536165478499999, datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc)),
        (10000000000000000000, datetime.datetime(2090, 7, 20, 17, 49, 51, tzinfo=timezone.utc)),
    ],
)
def test_snowflake_time(num: int, expected) -> None:
    """``snowflake_time`` decodes the creation time embedded in a snowflake."""
    created_at = utils.snowflake_time(num)
    # sub-second precision is irrelevant for these fixtures
    truncated = created_at.replace(microsecond=0)
    assert truncated == expected
147

148

149
@pytest.mark.parametrize(
    ("dt", "expected"),
    [
        (datetime.datetime(2015, 1, 1, tzinfo=timezone.utc), 0),
        (datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc), 881536165478400000),
    ],
)
def test_time_snowflake(dt, expected) -> None:
    """``time_snowflake`` yields the lowest (or, with ``high=True``, highest) ID for a time."""
    lowest = utils.time_snowflake(dt)
    assert lowest == expected

    highest = utils.time_snowflake(dt, high=True)
    assert lowest < highest

    # the highest ID of one millisecond is immediately followed by the lowest of the next
    next_lowest = utils.time_snowflake(dt + timedelta(milliseconds=1))
    assert highest + 1 == next_lowest
163

164

165
def test_find() -> None:
    """``find`` returns the first element matching the predicate, or ``None``."""

    def is_42(value):
        return value == 42

    assert utils.find(is_42, []) is None
    assert utils.find(is_42, [42]) == 42
    assert utils.find(is_42, [1, 2, 42, 3, 4]) == 42

    def has_id_42(obj):
        return obj.id == 42

    objects = [disnake.Object(i) for i in (1, 42, 42, 2)]
    # identity check: the *first* match must be returned, not a later equal one
    assert utils.find(has_id_42, objects) is objects[1]
174

175

176
def test_get() -> None:
    """``get`` matches attributes (including nested ``a__value`` lookups) against kwargs."""

    @dataclass
    class A:
        value: int

    @dataclass
    class B:
        value: int
        a: A

    items = [B(123, A(42))]

    # unknown attribute names surface as AttributeError
    with pytest.raises(AttributeError):
        utils.get(items, something=None)

    # single-element list
    assert utils.get(items, value=123) == items[0]
    assert utils.get(items, a__value=42) == items[0]
    assert utils.get(items, value=111111) is None

    # several elements; the first match wins
    items.extend([B(456, A(42)), B(789, A(99999))])
    assert utils.get(items, value=789) == items[2]
    assert utils.get(items, a__value=42) == items[0]

    # several keyword arguments must all match the same element
    assert utils.get(items, value=456, a__value=42) is items[1]
    assert utils.get(items, value=789, a__value=42) is None
202

203

204
@pytest.mark.parametrize(
    ("it", "expected"),
    [
        ([1, 1, 1], [1]),
        ([3, 2, 1, 2], [3, 2, 1]),
        ([1, 2], [1, 2]),
        ([2, 1], [2, 1]),
    ],
)
def test_unique(it, expected) -> None:
    """``_unique`` drops duplicates while keeping first-occurrence order."""
    deduplicated = utils._unique(it)
    assert deduplicated == expected
215

216

217
@pytest.mark.parametrize(
    ("data", "expected"),
    [
        ({}, None),
        ({"a": 42}, None),
        ({"key": None}, None),
        ({"key": 42}, 42),
        ({"key": "42"}, 42),
    ],
)
def test_get_as_snowflake(data, expected) -> None:
    """``_get_as_snowflake`` reads ``data["key"]`` as an int; missing or null gives ``None``."""
    value = utils._get_as_snowflake(data, "key")
    assert value == expected
229

230

231
def test_maybe_cast() -> None:
    """``_maybe_cast`` applies the converter unless the value is ``MISSING``."""

    def increment(value):
        return value + 1

    fallback = object()

    # MISSING -> the default (``None`` when no default is passed)
    assert utils._maybe_cast(utils.MISSING, increment) is None
    assert utils._maybe_cast(utils.MISSING, increment, fallback) is fallback

    # any other value -> converted, default ignored
    assert utils._maybe_cast(42, increment) == 43
    assert utils._maybe_cast(42, increment, fallback) == 43
240

241

242
@pytest.mark.parametrize(
    ("data", "expected_mime", "expected_ext"),
    [
        (b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A", "image/png", ".png"),
        (b"\xFF\xD8\xFFxxxJFIF", "image/jpeg", ".jpg"),
        (b"\xFF\xD8\xFFxxxExif", "image/jpeg", ".jpg"),
        (b"\xFF\xD8\xFFxxxxxxxxxxxx", "image/jpeg", ".jpg"),
        (b"xxxxxxJFIF", "image/jpeg", ".jpg"),
        (b"\x47\x49\x46\x38\x37\x61", "image/gif", ".gif"),
        (b"\x47\x49\x46\x38\x39\x61", "image/gif", ".gif"),
        (b"RIFFxxxxWEBP", "image/webp", ".webp"),
    ],
)
def test_mime_type_valid(data, expected_mime, expected_ext) -> None:
    """Known image headers are detected, with or without trailing bytes."""
    # trailing data after the header must not affect detection
    for candidate in (data, data + b"\xFF"):
        assert utils._get_mime_type_for_image(candidate) == expected_mime
        assert utils._get_extension_for_image(candidate) == expected_ext

    # ...but shifting the header by one byte must break it
    shifted = b"\xFF" + data
    with pytest.raises(ValueError, match=r"Unsupported image type given"):
        utils._get_mime_type_for_image(shifted)
    assert utils._get_extension_for_image(shifted) is None
264

265

266
@pytest.mark.parametrize(
    "data",
    [
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\xFF",  # invalid png end
        b"\x47\x49\x46\x38\x38\x61",  # invalid gif version
        b"RIFFxxxxAAAA",
        b"AAAAxxxxWEBP",
        b"",
    ],
)
def test_mime_type_invalid(data) -> None:
    """Unrecognised data raises in the mime lookup and yields no file extension."""
    with pytest.raises(ValueError, match=r"Unsupported image type given"):
        utils._get_mime_type_for_image(data)

    extension = utils._get_extension_for_image(data)
    assert extension is None
280

281

282
@pytest.mark.asyncio
async def test_assetbytes_base64() -> None:
    """``_assetbytes_to_base64_data`` turns bytes-likes and assets into a ``data:`` URI.

    ``None`` is passed through unchanged.
    """
    assert await utils._assetbytes_to_base64_data(None) is None

    # test bytes; a RIFF....WEBP header is detected as image/webp (see the mime tests),
    # and the payload is the standard base64 encoding of ``data``
    data = b"RIFFabcdWEBPwxyz"
    expected = "data:image/webp;base64,UklGRmFiY2RXRUJQd3h5eg=="
    for conv in (bytes, bytearray, memoryview):
        assert await utils._assetbytes_to_base64_data(conv(data)) == expected

    # test assets: their contents come from the awaited ``read()``
    mock_asset = mock.Mock(disnake.Asset)
    mock_asset.read = mock.AsyncMock(return_value=data)

    assert await utils._assetbytes_to_base64_data(mock_asset) == expected
    mock_asset.read.assert_awaited_once_with()
297

298

299
@pytest.mark.parametrize(
    ("after", "use_clock", "expected"),
    [
        # use reset_after
        (42, False, 42),
        # use timestamp
        (42, True, 7),
        (utils.MISSING, False, 7),
        (utils.MISSING, True, 7),
    ],
)
@helpers.freeze_time()
def test_parse_ratelimit_header(after, use_clock, expected) -> None:
    """``_parse_ratelimit_header`` picks ``Reset-After`` or the absolute ``Reset`` header."""
    reset_at = utils.utcnow() + timedelta(seconds=7)

    headers = {"X-Ratelimit-Reset": reset_at.timestamp()}
    if after is not utils.MISSING:
        headers["X-Ratelimit-Reset-After"] = after
    request = mock.Mock(headers=headers)

    delay = utils._parse_ratelimit_header(request, use_clock=use_clock)
    assert delay == expected
320

321

322
@pytest.mark.parametrize(
    "func",
    [
        mock.Mock(),
        mock.AsyncMock(),
    ],
)
@pytest.mark.asyncio
async def test_maybe_coroutine(func: mock.Mock) -> None:
    """``maybe_coroutine`` forwards arguments and returns the result for sync and async callables."""
    result = await utils.maybe_coroutine(func, 42, arg="uwu")

    assert result is func.return_value
    func.assert_called_once_with(42, arg="uwu")
333

334

335
@pytest.mark.parametrize("mock_type", [mock.Mock, mock.AsyncMock])
@pytest.mark.parametrize(
    ("gen", "expected"),
    [
        ([], True),
        ([True], True),
        ([False], False),
        ([False, True], False),
        ([True, False, True], False),
    ],
)
@pytest.mark.asyncio
async def test_async_all(mock_type, gen, expected) -> None:
    """``async_all`` accepts a mix of plain values and awaitables, like ``all``."""

    def results():
        # produced lazily, one call per element, exactly like a generator expression
        for value in gen:
            yield mock_type(return_value=value)()

    assert await utils.async_all(results()) is expected
349

350

351
@pytest.mark.looptime
@pytest.mark.asyncio
async def test_sane_wait_for(looptime) -> None:
    """``sane_wait_for`` waits for all awaitables, or times out without cancelling them."""
    durations = [10, 50, 25]

    def make_sleeps():
        return [asyncio.sleep(seconds) for seconds in durations]

    # no timeout: total elapsed time equals the longest sleep
    await utils.sane_wait_for(make_sleeps(), timeout=100)
    assert looptime == 50

    # timeout: the 50s sleep outlasts the 40s limit
    pending = [asyncio.create_task(coro) for coro in make_sleeps()]
    with pytest.raises(asyncio.TimeoutError):
        await utils.sane_wait_for(pending, timeout=40)
    assert looptime == 90
    assert [task.done() for task in pending] == [True, False, True]

    # the unfinished task was not cancelled and keeps running to completion
    await asyncio.sleep(1000)
    assert all(task.done() for task in pending)
374

375

376
def test_get_slots() -> None:
    """``get_slots`` collects slot names across the class hierarchy, bases first.

    Covers tuple, empty, dict and plain-string ``__slots__`` declarations.
    """

    class A:
        __slots__ = ("a", "a2")

    class B:
        __slots__ = ()

    class C(A):
        __slots__ = {"c": "uwu"}

    class D(B, C):
        __slots__ = "xyz"  # noqa: PLC0205  # this is intentional

    collected = list(utils.get_slots(D))
    assert collected == ["a", "a2", "c", "xyz"]
390

391

392
@pytest.mark.parametrize(
    "tz",
    [
        # naive datetime
        utils.MISSING,
        # aware datetime
        None,
        timezone.utc,
        timezone(timedelta(hours=-9)),
    ],
)
@pytest.mark.parametrize(("delta", "expected"), [(7, 7), (-100, 0)])
@helpers.freeze_time()
def test_compute_timedelta(tz, delta, expected) -> None:
    """``compute_timedelta`` returns seconds until a (naive or aware) time, never negative."""
    now = datetime.datetime.now()  # noqa: DTZ005
    target = now if tz is utils.MISSING else now.astimezone(tz)

    remaining = utils.compute_timedelta(target + timedelta(seconds=delta))
    assert remaining == expected
410

411

412
@pytest.mark.parametrize(("delta", "expected"), [(0, 0), (42, 42), (-100, 0)])
@pytest.mark.looptime
@pytest.mark.asyncio
@helpers.freeze_time()
async def test_sleep_until(looptime, delta, expected) -> None:
    """``sleep_until`` sleeps until the given time (not at all if it has passed) and returns ``result``."""
    sentinel = object()
    wake_at = utils.utcnow() + timedelta(seconds=delta)

    assert await utils.sleep_until(wake_at, sentinel) is sentinel
    assert looptime == expected
420

421

422
def test_utcnow() -> None:
    """``utcnow`` returns an aware datetime in UTC."""
    now = utils.utcnow()
    assert now.tzinfo == timezone.utc
424

425

426
def test_valid_icon_size() -> None:
    """Powers of two from 16 to 4096 are accepted; a sample of other sizes is rejected."""
    accepted = [2**exponent for exponent in range(4, 13)]
    rejected = [0, 1, 2, 8, 24, 100, 2**20]

    for size in accepted:
        assert utils.valid_icon_size(size)

    for size in rejected:
        assert not utils.valid_icon_size(size)
432

433

434
@pytest.mark.parametrize(("s", "expected"), [("a一b", 4), ("abc", 3)])
def test_string_width(s, expected) -> None:
    """``_string_width`` counts wide (e.g. CJK) characters as two columns."""
    width = utils._string_width(s)
    assert width == expected
437

438

439
@pytest.mark.parametrize(
    ("url", "params", "expected"),
    [
        (mock.Mock(disnake.Invite, code="uwu"), {}, "uwu"),
        ("uwu", {}, "uwu"),
        ("https://discord.com/disnake", {}, "https://discord.com/disnake"),
        ("https://discord.com/invite/disnake", {}, "disnake"),
        ("http://discord.gg/disnake?param=123%20456", {"param": "123 456"}, "disnake"),
        ("https://discordapp.com/invite/disnake?a=1&a=2", {"a": "1"}, "disnake"),
    ],
)
@pytest.mark.parametrize("with_params", [False, True])
def test_resolve_invite(url, params, expected, with_params) -> None:
    """``resolve_invite`` extracts the invite code, optionally along with query parameters."""
    resolved = utils.resolve_invite(url, with_params=with_params)
    wanted = (expected, params) if with_params else expected
    assert resolved == wanted
457

458

459
@pytest.mark.parametrize(
    ("url", "expected"),
    [
        (mock.Mock(disnake.Template, code="uwu"), "uwu"),
        ("uwu", "uwu"),
        ("http://discord.com/disnake", "http://discord.com/disnake"),
        ("http://discord.new/disnake", "disnake"),
        ("https://discord.com/template/disnake", "disnake"),
        ("https://discordapp.com/template/disnake", "disnake"),
    ],
)
def test_resolve_template(url, expected) -> None:
    """``resolve_template`` extracts the template code from objects and known URL forms."""
    code = utils.resolve_template(url)
    assert code == expected
472

473

474
@pytest.mark.parametrize(
    ("text", "exp_remove", "exp_escape"),
    [
        (
            # this is obviously not valid markdown for the most part,
            # it's just meant to test several combinations
            "*hi* ~~a~ |aaa~*\\``\n`py x``` __uwu__ y",
            "hi a aaa\npy x uwu y",
            r"\*hi\* \~\~a\~ \|aaa\~\*\\\`\`" "\n" r"\`py x\`\`\` \_\_uwu\_\_ y",
        ),
        (
            "aaaaa\n> h\n>> abc \n>>> te*st_",
            "aaaaa\nh\n>> abc \ntest",
            "aaaaa\n\\> h\n>> abc \n\\>>> te\\*st\\_",
        ),
        (
            "*h*\n> [li|nk](~~url~~) xyz **https://google.com/stuff?uwu=owo",
            "h\n xyz https://google.com/stuff?uwu=owo",
            # NOTE: currently doesn't escape inside `[x](y)`, should that be changed?
            r"\*h\*" "\n" r"\> \[li|nk](~~url~~) xyz \*\*https://google.com/stuff?uwu=owo",
        ),
    ],
)
def test_markdown(text: str, exp_remove, exp_escape) -> None:
    """Removing/escaping markdown gives the same result regardless of ``ignore_links`` here."""
    for ignore_links in (False, True):
        assert utils.remove_markdown(text, ignore_links=ignore_links) == exp_remove

    for ignore_links in (False, True):
        assert utils.escape_markdown(text, ignore_links=ignore_links) == exp_escape
503

504

505
@pytest.mark.parametrize(
    ("text", "expected", "expected_ignore"),
    [
        (
            "http://google.com/~test/hi_test ~~a~~",
            "http://google.com/test/hitest a",
            "http://google.com/~test/hi_test a",
        ),
        (
            "abc [link](http://test~test.com)\n>>> <http://endless.horse/_*>",
            "abc \n<http://endless.horse/>",
            "abc \n<http://endless.horse/_*>",
        ),
    ],
)
def test_markdown_links(text: str, expected, expected_ignore) -> None:
    """With ``ignore_links=True``, markdown characters inside URLs are left alone."""
    stripped = utils.remove_markdown(text, ignore_links=False)
    stripped_keep_links = utils.remove_markdown(text, ignore_links=True)

    assert stripped == expected
    assert stripped_keep_links == expected_ignore
523

524

525
@pytest.mark.parametrize(
    ("text", "expected"),
    [
        ("@everyone hey look at this cat", "@\u200beveryone hey look at this cat"),
        ("test @here", "test @\u200bhere"),
        ("<@12341234123412341> hi", "<@\u200b12341234123412341> hi"),
        ("<@!12341234123412341>", "<@\u200b!12341234123412341>"),
        ("<@&12341234123412341>", "<@\u200b&12341234123412341>"),
    ],
)
def test_escape_mentions(text: str, expected) -> None:
    """``escape_mentions`` inserts a zero-width space after ``@`` in mentions."""
    escaped = utils.escape_mentions(text)
    assert escaped == expected
537

538

539
@pytest.mark.parametrize(
    ("docstring", "expected"),
    [
        (None, ""),
        ("", ""),
        ("test abc", "test abc"),
        (
            """
            test
            hi


            aaaaaaa
            xyz: abc

            stuff
            -----
            something
            """,
            "test\nhi\n\n\naaaaaaa\nxyz: abc",
        ),
        # other chars
        (
            """
            stuff
            -----+
            abc
            """,
            "stuff\n-----+\nabc",
        ),
        # additional spaces in front of line
        (
            """
             stuff
            -----
            abc
            """,
            "stuff\n-----\nabc",
        ),
        # invalid underline length
        (
            """
            stuff
            ----
            abc
            """,
            "stuff\n----\nabc",
        ),
    ],
)
def test_parse_docstring_desc(docstring: Optional[str], expected) -> None:
    """The description stops at the first valid section header; no params are found."""

    def f() -> None:
        ...

    f.__doc__ = docstring

    expected_result = {
        "description": expected,
        "params": {},
        "localization_key_name": None,
        "localization_key_desc": None,
    }
    assert utils.parse_docstring(f) == expected_result
600

601

602
@pytest.mark.parametrize(
    ("docstring", "expected"),
    [
        (
            """
            This does stuff.

            Parameters
            ----------
            something: a value
            other_something: :class:`int`
                another value,
                wow
            thingy: a very cool thingy

            Returns
            -------
            Nothing.
            """,
            {
                "something": {"name": "something", "description": "a value"},
                "other_something": {
                    "name": "other_something",
                    "description": "another value,\nwow",
                },
                "thingy": {"name": "thingy", "description": "a very cool thingy"},
            },
        ),
        # invalid underline length
        (
            """
            Parameters
            ---------
            something: abc
            """,
            {},
        ),
        # missing next line
        (
            """
            Parameters
            ----------
            """,
            {},
        ),
    ],
)
def test_parse_docstring_param(docstring: str, expected) -> None:
    """Parameters are parsed only from a correctly underlined ``Parameters`` section."""

    def f() -> None:
        ...

    f.__doc__ = docstring

    # fields the fixtures above leave implicit are always ``None`` here
    implicit = {"type": None, "localization_key_name": None, "localization_key_desc": None}
    full_expected = {name: {**info, **implicit} for name, info in expected.items()}

    # only the params are compared; the description is covered elsewhere
    assert utils.parse_docstring(f)["params"] == full_expected
659

660

661
def test_parse_docstring_localizations() -> None:
    """``{{KEY}}`` markers become ``KEY_NAME``/``KEY_DESCRIPTION`` and are stripped from text."""

    def f() -> None:
        """Does stuff. {{cool_key}}

        Parameters
        ----------
        p1: {{ PARAM_1 }} Probably a number.
        p2: str
            Definitely a string {{   PARAM_X }}
        """

    expected_params = {
        "p1": {
            "name": "p1",
            "description": "Probably a number.",
            "localization_key_name": "PARAM_1_NAME",
            "localization_key_desc": "PARAM_1_DESCRIPTION",
            "type": None,
        },
        "p2": {
            "name": "p2",
            "description": "Definitely a string",
            "localization_key_name": "PARAM_X_NAME",
            "localization_key_desc": "PARAM_X_DESCRIPTION",
            "type": None,
        },
    }
    expected = {
        "description": "Does stuff.",
        "localization_key_name": "cool_key_NAME",
        "localization_key_desc": "cool_key_DESCRIPTION",
        "params": expected_params,
    }

    assert utils.parse_docstring(f) == expected
693

694

695
@pytest.mark.parametrize(
    ("it", "max_size", "expected"),
    [
        ([], 3, []),
        ([0], 3, [[0]]),
        ([0, 1, 2], 2, [[0, 1], [2]]),
        ([0, 1, 2, 3, 4, 5], 3, [[0, 1, 2], [3, 4, 5]]),
    ],
)
@pytest.mark.parametrize("sync", [False, True])
@pytest.mark.asyncio
async def test_as_chunks(sync, it, max_size: int, expected) -> None:
    """``as_chunks`` splits sync and async iterables into lists of at most ``max_size``."""
    if sync:
        chunks = list(utils.as_chunks(it, max_size))
    else:

        async def source():
            for item in it:
                yield item

        chunks = [chunk async for chunk in utils.as_chunks(source(), max_size)]

    assert chunks == expected
716

717

718
@pytest.mark.parametrize("max_size", [-1, 0])
def test_as_chunks_size(max_size: int) -> None:
    """Non-positive chunk sizes are rejected immediately."""
    empty = iter([])
    with pytest.raises(ValueError, match=r"Chunk sizes must be greater than 0."):
        utils.as_chunks(empty, max_size)
722

723

724
@pytest.mark.parametrize(
    ("params", "expected"),
    [
        ([], ()),
        ([disnake.CommandInter, int, Optional[str]], (disnake.CommandInter, int, Optional[str])),
        # check flattening + deduplication (both of these are done automatically in 3.9.1+)
        ([float, Literal[1, 2, Literal[3, 4]], Literal["a", "bc"]], (float, 1, 2, 3, 4, "a", "bc")),
        ([Literal[Literal[1, 1, 2, 3, 3]]][0:1] if False else [Literal[1, 1, 2, 3, 3]], (1, 2, 3)),
    ],
)
def test_flatten_literal_params(params, expected) -> None:
    """``flatten_literal_params`` expands (nested) ``Literal`` values and deduplicates them."""
    flattened = utils.flatten_literal_params(params)
    assert flattened == expected
736

737

738
# the class of the ``None`` singleton, i.e. ``type(None)``
NoneType = None.__class__
739

740

741
@pytest.mark.parametrize(
    ("params", "expected"),
    [([NoneType], (NoneType,)), ([NoneType, int, NoneType, float], (int, float, NoneType))],
)
def test_normalise_optional_params(params, expected) -> None:
    """``NoneType`` is deduplicated and moved to the end of the parameter tuple."""
    normalised = utils.normalise_optional_params(params)
    assert normalised == expected
747

748

749
@pytest.mark.parametrize(
    ("tp", "expected", "expected_cache"),
    [
        # simple types
        (None, NoneType, False),
        (int, int, False),
        # complex types
        (List[int], List[int], False),
        (Dict[float, "List[yarl.URL]"], Dict[float, List[yarl.URL]], True),
        (Literal[1, Literal[False], "hi"], Literal[1, False, "hi"], False),
        # unions
        (Union[timezone, float], Union[timezone, float], False),
        (Optional[int], Optional[int], False),
        (Union["tuple", None, int], Union[tuple, int, None], True),
        # forward refs
        ("bool", bool, True),
        ("Tuple[dict, List[Literal[42, 99]]]", Tuple[dict, List[Literal[42, 99]]], True),
        # 3.10 union syntax
        pytest.param(
            "int | float",
            Union[int, float],
            True,
            marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="syntax requires py3.10"),
        ),
    ],
)
def test_resolve_annotation(tp, expected, expected_cache) -> None:
    """Annotations resolve to the expected types; forward refs populate and reuse the cache."""
    ref_cache = {}
    resolved = utils.resolve_annotation(tp, globals(), locals(), ref_cache)
    assert resolved == expected

    # only annotations containing forward refs should have touched the cache
    assert bool(ref_cache) == expected_cache

    # resolving a forward ref a second time must hit the cache (same object back)
    if isinstance(tp, str):
        again = utils.resolve_annotation(tp, globals(), locals(), ref_cache)
        assert again is resolved
785

786

787
def test_resolve_annotation_literal() -> None:
    """``Literal`` values of unsupported types are rejected."""
    message = r"Literal arguments must be of type str, int, bool, or NoneType."
    with pytest.raises(TypeError, match=message):
        utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {})  # type: ignore
792

793

794
@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12")
class TestResolveAnnotationTypeAliasType:
    """``resolve_annotation`` support for ``TypeAliasType`` (PEP 695) aliases.

    Aliases are built with ``TypeAliasType(...)`` directly, since the ``type``
    statement is 3.12+ syntax.
    """

    def test_simple(self) -> None:
        """Non-generic alias; equivalent to ``type CoolList = List[int]``."""
        CoolList = TypeAliasType("CoolList", List[int])
        resolved = utils.resolve_annotation(CoolList, globals(), locals(), {})
        assert resolved == List[int]

    def test_generic(self) -> None:
        """Subscripted generic alias; equivalent to ``type CoolList[T] = List[T]; CoolList[int]``."""
        T = TypeVar("T")
        CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

        resolved = utils.resolve_annotation(CoolList[int], globals(), locals(), {})
        assert resolved == List[int]

    def test_forwardref_local(self) -> None:
        """Alias and forward-referenced argument both live in the local scope."""
        T = TypeVar("T")
        IntOrStr = Union[int, str]  # looked up by name through locals()
        CoolList = TypeAliasType("CoolList", List[T], type_params=(T,))

        resolved = utils.resolve_annotation(CoolList["IntOrStr"], globals(), locals(), {})
        assert resolved == List[IntOrStr]

    def test_forwardref_module(self) -> None:
        """Alias and forward-referenced argument both live in another module."""
        annotation = utils_helper_module.ListWithForwardRefAlias
        resolved = utils.resolve_annotation(annotation, globals(), locals(), {})
        assert resolved == List[Union[int, str]]

    def test_forwardref_mixed(self) -> None:
        """Alias from another module, forward-referenced argument from the local scope."""
        LocalIntOrStr = Union[int, str]  # looked up by name through locals()

        annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"]
        resolved = utils.resolve_annotation(annotation, globals(), locals(), {})
        assert resolved == List[LocalIntOrStr]

    def test_forwardref_duplicate(self) -> None:
        """Equal forward-ref names from different scopes must not share a cache entry."""
        DuplicateAlias = int  # looked up by name through locals()
        cache = {}

        # first, `DuplicateAlias` resolves to the local `int`...
        first = utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache)
        assert first == List[int]

        # ...then the globalns changes and `DuplicateAlias` means something else; this must
        # not resolve to `List[int]` despite {"DuplicateAlias": int} being in the cache
        second = utils.resolve_annotation(
            utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache
        )
        assert second == List[str]
851

852

853
@pytest.mark.parametrize(
    ("dt", "style", "expected"),
    [
        (0, "F", "<t:0:F>"),
        (1630245000.1234, "T", "<t:1630245000:T>"),
        (datetime.datetime(2021, 8, 29, 13, 50, 0, tzinfo=timezone.utc), "f", "<t:1630245000:f>"),
    ],
)
def test_format_dt(dt, style, expected) -> None:
    """``format_dt`` renders numbers and datetimes as ``<t:seconds:style>`` markup."""
    rendered = utils.format_dt(dt, style)
    assert rendered == expected
863

864

865
@pytest.fixture(scope="session")
def tmp_module_root(tmp_path_factory):
    """Create a temporary directory tree for the ``search_directory`` tests.

    The tree contains an empty directory, a directory of ``.py`` files without
    ``__init__.py``, a top-level ``test.py``, and a package ``mod`` with a nested
    sub-package. All files are empty.
    """
    root = tmp_path_factory.mktemp("module_root")

    directories = ("empty", "not_a_module", "mod/sub1/sub2")
    files = (
        "test.py",
        "not_a_module/abc.py",
        "mod/__init__.py",
        "mod/ext.py",
        "mod/sub1/sub2/__init__.py",
        "mod/sub1/sub2/abc.py",
    )

    for directory in directories:
        (root / directory).mkdir(parents=True)
    for file in files:
        (root / file).touch()

    return root
881

882

883
@pytest.mark.parametrize(
    ("path", "expected"),
    [
        (".", ["test", "mod.ext"]),
        ("./", ["test", "mod.ext"]),
        ("empty/", []),
    ],
)
def test_search_directory(tmp_module_root, path, expected) -> None:
    """``search_directory`` finds the same modules for relative and absolute paths."""
    previous_cwd = os.getcwd()
    try:
        os.chdir(tmp_module_root)

        # abspath is computed after chdir, so both forms point at the same directory
        for candidate in (path, os.path.abspath(path)):
            found = sorted(utils.search_directory(candidate))
            assert found == sorted(expected)
    finally:
        os.chdir(previous_cwd)
901

902

903
@pytest.mark.parametrize(
    ("path", "exc"),
    [
        ("../../", r"Modules outside the cwd require a package to be specified"),
        ("nonexistent", r"Provided path '.*?nonexistent' does not exist"),
        ("test.py", r"Provided path '.*?test.py' is not a directory"),
    ],
)
def test_search_directory_exc(tmp_module_root, path, exc) -> None:
    """Invalid paths passed to ``search_directory`` raise ``ValueError``."""
    previous_cwd = os.getcwd()
    try:
        os.chdir(tmp_module_root)

        with pytest.raises(ValueError, match=exc):
            # consume the result in case it is produced lazily
            list(utils.search_directory(tmp_module_root / path))
    finally:
        os.chdir(previous_cwd)
920

921

922
@pytest.mark.parametrize(
    ("locale", "expected"),
    [
        ("abc", None),
        ("en-US", "en-US"),
        ("en_US", "en-US"),
        ("de", "de"),
        ("de-DE", "de"),
        ("de_DE", "de"),
    ],
)
def test_as_valid_locale(locale, expected) -> None:
    """``as_valid_locale`` normalises separators and falls back to the language code."""
    normalised = utils.as_valid_locale(locale)
    assert normalised == expected
935

936

937
@pytest.mark.parametrize(
    ("values", "expected"),
    [
        ([], "<none>"),
        (["one"], "one"),
        (["one", "two"], "one plus two"),
        (["one", "two", "three"], "one, two, plus three"),
        (["one", "two", "three", "four"], "one, two, three, plus four"),
    ],
)
def test_humanize_list(values, expected) -> None:
    """``humanize_list`` joins items with commas and the given word before the last one."""
    sentence = utils.humanize_list(values, "plus")
    assert sentence == expected
949

950

951
# used for `test_signature_has_self_param`
952
def _toplevel():
953
    def inner() -> None:
954
        ...
955

956
    return inner
957

958

959
def decorator(f):
    """Wrap ``f`` in a pass-through that takes ``self`` explicitly.

    ``functools.wraps`` copies ``f``'s name and qualname onto the wrapper and
    exposes ``f`` as ``__wrapped__``.
    """

    @functools.wraps(f)
    def _forward(self, *args, **kwargs):
        return f(self, *args, **kwargs)

    return _forward
965

966

967
# used for `test_signature_has_self_param`
class _Clazz:
    """Fixture class exposing many kinds of callables for ``signature_has_self_param``."""

    # a module-level function stored as a class attribute; its qualname stays `_toplevel`
    rebind = _toplevel

    # a lambda defined in the class body
    _lambda = lambda: None

    def func(self):
        """Regular method returning a nested (self-less) function."""

        def inner() -> None:
            ...

        return inner

    @classmethod
    def cmethod(cls) -> None:
        """Class method."""

    @staticmethod
    def smethod() -> None:
        """Static method."""

    @decorator
    def decorated(self) -> None:
        """Method wrapped by `decorator`, which applies ``functools.wraps``."""

    class Nested:
        """Nested class; its methods get a two-level qualname."""

        def func(self):
            """Method of the nested class returning a nested function."""

            def inner() -> None:
                ...

            return inner
997

998

999
@pytest.mark.parametrize(
    ("function", "expected"),
    [
        # top-level function
        (_toplevel, False),
        # methods in class
        (_Clazz.func, True),
        (_Clazz().func, False),
        # unfortunately doesn't work
        (_Clazz.rebind, False),
        (_Clazz().rebind, False),
        # classmethod/staticmethod isn't supported, but checked to ensure consistency
        (_Clazz.cmethod, False),
        (_Clazz.smethod, True),
        # nested class methods
        (_Clazz.Nested.func, True),
        (_Clazz.Nested().func, False),
        # inner methods
        (_toplevel(), False),
        (_Clazz().func(), False),
        (_Clazz.Nested().func(), False),
        # decorated method
        (_Clazz.decorated, True),
        (_Clazz().decorated, False),
        # lambda (class-level)
        (_Clazz._lambda, False),
        (_Clazz()._lambda, False),
    ],
)
def test_signature_has_self_param(function, expected) -> None:
    """``signature_has_self_param`` reports whether an unbound method still expects ``self``."""
    has_self = utils.signature_has_self_param(function)
    assert has_self == expected
1030

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.