import contextlib
import os
import random
from unittest import mock
from unittest.mock import Mock

import lightning.fabric
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.data import (
    AttributeDict,
    _get_dataloader_init_args_and_kwargs,
    _replace_dunder_methods,
    _replace_value_in_saved_args,
    _set_sampler_epoch,
    _update_dataloader,
    _WrapAttrTag,
    has_iterable_dataset,
    has_len,
    suggested_max_num_workers,
)
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning_utilities.test.warning import no_warning_call
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset
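# `RandomDataset` (map-style) and `RandomIterableDataset` (an `IterableDataset`) are
# small test helpers yielding random tensors, as exercised by the tests below.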


def test_has_iterable_dataset():
    assert has_iterable_dataset(DataLoader(RandomIterableDataset(1, 1)))

    assert not has_iterable_dataset(DataLoader(RandomDataset(1, 1)))

    class MockDatasetWithoutIterableDataset(RandomDataset):
        def __iter__(self):
            yield 1
            return self

    assert not has_iterable_dataset(DataLoader(MockDatasetWithoutIterableDataset(1, 1)))


def test_has_len():
    assert has_len(DataLoader(RandomDataset(1, 1)))

    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(DataLoader(RandomDataset(0, 0)))

    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))


def test_replace_dunder_methods_multiple_loaders_without_init():
    """If a class inherits from a class we are patching but does not define its own `__init__` method (the one we
    wrap), `hasattr(cls, "__old__init__")` can be True because of the parent class, yet the attribute cannot be
    deleted from the subclass because it is owned by the parent. Moreover, the error occurred only sometimes, since
    it depends on the order in which we iterate over the set of classes being patched.

    This test simulates the behavior by generating a sufficient number of dummy classes that do not define `__init__`
    and are children of `DataLoader`. We test that a) the context manager `_replace_dunder_methods` exits cleanly,
    and b) the mechanism checking for the presence of `__old__init__` works as expected.

    """
    classes = [DataLoader]
    for i in range(100):
        classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

    before = {cls: cls.__init__ for cls in classes}

    with _replace_dunder_methods(DataLoader, "dataset"):
        for cls in classes[1:]:  # First one is `DataLoader`
            assert "__old__init__" not in cls.__dict__
            assert hasattr(cls, "__old__init__")

        assert "__old__init__" in DataLoader.__dict__
        assert hasattr(DataLoader, "__old__init__")

    for cls in classes:
        assert before[cls] == cls.__init__


class MyBaseDataLoader(DataLoader):
    pass


class DataLoaderSubclass1(DataLoader):
    def __init__(self, attribute1, *args, **kwargs):
        self.at1 = attribute1
        super().__init__(*args, **kwargs)


class DataLoaderSubclass2(DataLoaderSubclass1):
    def __init__(self, attribute2, *args, **kwargs):
        self.at2 = attribute2
        super().__init__(attribute2 + "-2", *args, **kwargs)


class MyDataLoader(MyBaseDataLoader):
    def __init__(self, data: Tensor, *args, **kwargs):
        self.data = data
        super().__init__(range(data.size(0)), *args, **kwargs)


test3_data = torch.randn((10, 20))
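# `test3_data` is module-level so the `test3` parametrization below can assert
# identity (`is`) against the tensor saved by `MyDataLoader`.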


class PoptorchDataLoader(DataLoader):
    def __init__(self, options, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._options = options

    @property
    def options(self):
        return self._options


class IncompleteDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, **kwargs):
        batch_size = max(batch_size - 5, 0)
        super().__init__(dataset, batch_size=batch_size, **kwargs)


class WeirdDataLoader1(DataLoader):
    def __init__(self, arg1, arg2, **kwargs):
        self.arg1 = arg1
        super().__init__(arg2, **kwargs)


class WeirdDataLoader2(DataLoader):
    def __init__(self, data_part1, data_part2, **kwargs):
        data = list(data_part1) + list(data_part2)
        super().__init__(data, **kwargs)


class NoneDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class ChangingDataLoader(DataLoader):
    def __init__(self, dataset, **kwargs):
        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)


@pytest.mark.parametrize(
    ("cls", "args", "kwargs", "arg_names", "dataset", "checked_values"),
    [
        pytest.param(
            DataLoaderSubclass1,
            ("attribute1",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute1",),
            range(4),
            {"batch_size": 2, "at1": "attribute1"},
            id="test1",
        ),
        pytest.param(
            DataLoaderSubclass2,
            ("attribute2",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute2",),
            range(4),
            {"batch_size": 2, "at1": "attribute2-2", "at2": "attribute2"},
            id="test2",
        ),
        pytest.param(
            MyDataLoader,
            (test3_data,),
            {"batch_size": 2},
            ("data",),
            range(10),
            {"batch_size": 2, "data": test3_data},
            id="test3",
        ),
        pytest.param(PoptorchDataLoader, (123, [1]), {}, ("options",), [1], {"options": 123}, id="test4"),
        pytest.param(
            IncompleteDataLoader,
            (range(10),),
            {"batch_size": 10},
            ("dataset",),
            range(10),
            {"batch_size": 5},
            id="test5",
        ),
        pytest.param(
            WeirdDataLoader1,
            (10, range(10)),
            {"batch_size": 10},
            ("arg1", "arg2"),
            range(10),
            {"arg1": 10, "batch_size": 10},
            id="test6",
        ),
        pytest.param(
            WeirdDataLoader2,
            (range(10), range(10, 20)),
            {"batch_size": 10},
            ("data_part1", "data_part2"),
            list(range(20)),
            {"batch_size": 10},
            id="test7",
        ),
        pytest.param(NoneDataLoader, (None,), {}, (), None, {}, id="test8"),
        pytest.param(ChangingDataLoader, (range(5),), {}, ("dataset",), list(range(10)), {}, id="test9"),
    ],
)
def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = cls(*args, **kwargs)

    assert dataloader.__pl_saved_args == args
    assert dataloader.__pl_saved_kwargs == kwargs
    assert dataloader.__pl_saved_arg_names == arg_names
    assert dataloader.__pl_saved_default_kwargs == {}
    assert dataloader.__dataset == dataset

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value

    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    assert not hasattr(dataloader, "__pl_saved_kwargs")
    assert not hasattr(dataloader, "__pl_saved_arg_names")
    assert not hasattr(dataloader, "__pl_saved_args")
    assert not hasattr(dataloader, "__pl_saved_default_kwargs")
    assert not hasattr(dataloader, "__dataset")

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value


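# Conceptually, `_replace_dunder_methods(DataLoader, "dataset")` swaps
# `DataLoader.__init__` for a wrapper that records the call before delegating,
# roughly like the following simplified sketch (not the actual implementation):
#
#     def wrapped_init(self, *args, **kwargs):
#         self.__pl_saved_args = args      # later consumed by `_update_dataloader`
#         self.__pl_saved_kwargs = kwargs
#         old_init(self, *args, **kwargs)
#
# which is why the `__pl_saved_*` attributes exist only on instances created
# inside the context manager.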
def test_replace_dunder_methods_extra_kwargs():
    class LoaderSubclass(DataLoader):
        def __init__(self, dataset, *args, batch_size=10, **kwargs):
            super().__init__(dataset, *args, batch_size=batch_size, **kwargs)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = LoaderSubclass(range(10))

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
    assert dataloader.__dataset == range(10)


def test_replace_dunder_methods_attrs():
    """This test checks that attribute setting and deletion performed within `_replace_dunder_methods` is correctly
    replayed even after reinstantiation.

    It also exercises a custom `__setattr__`.

    """

    class Loader(DataLoader):
        def __setattr__(self, attr, val):
            if attr == "custom_arg":
                val = val + 2
            super().__setattr__(attr, val)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = Loader(range(10))
        dataloader.custom_arg = 5
        dataloader.my_arg = 10
        dataloader.another_arg = 100
        del dataloader.dataset
        with contextlib.suppress(AttributeError):
            del dataloader.abc_arg

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__dataset == range(10)
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")
    assert dataloader.__pl_attrs_record == [
        (("custom_arg", 5), _WrapAttrTag.SET),
        (("my_arg", 10), _WrapAttrTag.SET),
        (("another_arg", 100), _WrapAttrTag.SET),
        (("dataset",), _WrapAttrTag.DEL),
    ]

    dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")


def test_replace_dunder_methods_restore_methods():
    """This test checks whether all dunder methods are restored to their original versions."""

    class Init(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

    class SetAttr(DataLoader):
        def __setattr__(self, *args):
            return super().__setattr__(*args)

    class DelAttr(DataLoader):
        def __delattr__(self, *args):
            return super().__delattr__(*args)

    class InitAndSetAttr(Init, SetAttr):
        pass

    class InitAndDelAttr(Init, DelAttr):
        pass

    class SetAttrAndDelAttr(SetAttr, DelAttr):
        pass

    class AllDunder(Init, SetAttr, DelAttr):
        pass

    before = {}
    for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder):
        before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}

    with _replace_dunder_methods(DataLoader, "dataset"):
        pass

    for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder):
        assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}


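# `_replace_value_in_saved_args` returns a `(status, args, kwargs)` triple: `status`
# reports whether `replace_key` was found among the saved positional args, kwargs,
# or default kwargs, and the returned args/kwargs carry the replaced value
# (behavior as exercised by the cases below).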
@pytest.mark.parametrize(
    (
        "args",
        "kwargs",
        "default_kwargs",
        "arg_names",
        "replace_key",
        "replace_value",
        "expected_status",
        "expected_args",
        "expected_kwargs",
    ),
    [
        pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
        pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
        pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
        pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
        pytest.param(
            (1, 2, 3),
            {"a": 1},
            {"e": 5},
            ["b", "c", "d"],
            "e",
            2,
            True,
            (1, 2, 3),
            {"a": 1, "e": 2},
            id="default_kwargs",
        ),
    ],
)
def test_replace_value_in_args(
    args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
):
    assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
        expected_status,
        expected_args,
        expected_kwargs,
    )


def test_update_dataloader_typerror_custom_exception():
    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

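    # Without the context manager there are no saved constructor args, so the
    # update has to be reconstructed from the `__init__` signature; `foo`
    # occupying the positional slot of `dataset` is presumably what trips the
    # `MisconfigurationException` above. Instantiating inside the context
    # manager (below) records the args and makes the update succeed.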
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)


def test_custom_torch_batch_sampler():
    """This test asserts that a custom `BatchSampler`, carrying all the arguments required to reinstantiate the
    class properly, is reinstantiated correctly.

    It also asserts that during the reinstantiation the `__init__` wrapper is no longer in place, and therefore the
    `__pl_saved_{args,arg_names,kwargs}` attributes are not set.

    """

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler with an extra argument and a default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating the `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # Update the dataloader, which is what happens when the dataloaders are accessed.
    # This should not fail; it would have failed before support for custom args was added.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    # Assert that the `__init__` method is no longer replaced and everything is instantiated with the correct types
    batch_sampler = dataloader.batch_sampler

    assert isinstance(batch_sampler, MyBatchSampler)

    assert batch_sampler.extra_arg == "random_str"
    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
    assert not hasattr(batch_sampler, "__pl_saved_args")
    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")


def test_custom_torch_batch_sampler_doppelganger():
    """Test that we can reinstantiate a sampler that mimics PyTorch's BatchSampler even if it does not inherit from
    it.

    This is only possible if that sampler accepts the `batch_size` and `drop_last` arguments, and stores them
    as attributes.

    """

    class BatchSamplerDoppelganger:
        """A batch sampler that mimics `torch.utils.data.BatchSampler` but does not inherit from it."""

        def __init__(self, sampler, batch_size, drop_last):
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

        def __len__(self) -> int:
            return 4

    batch_sampler = BatchSamplerDoppelganger(sampler=Mock(), batch_size=2, drop_last=True)
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    new_sampler = Mock()
    dataloader = _update_dataloader(dataloader, sampler=new_sampler)

    batch_sampler = dataloader.batch_sampler
    assert isinstance(batch_sampler, BatchSamplerDoppelganger)
    assert batch_sampler.sampler == new_sampler


def test_custom_batch_sampler():
    """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

    class CustomBatchSampler:  # not inheriting from `BatchSampler`
        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

    batch_sampler = CustomBatchSampler()
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
        _ = _update_dataloader(dataloader, sampler=Mock())


def test_custom_batch_sampler_no_sampler():
    """Test that an appropriate error is raised when the custom `BatchSampler` does not support the `sampler`
    argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler, without a `sampler` argument.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within the `_replace_dunder_methods` context manager, simulating the `*_dataloader` hooks
        batch_sampler = MyBatchSampler("random_str")
    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that the passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # Assert that the error is raised
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler)


def test_dataloader_kwargs_replacement_with_iterable_dataset():
    """Test that DataLoader kwargs are not replaced when using an `IterableDataset`."""
    dataset = RandomIterableDataset(7, 100)
    dataloader = DataLoader(dataset, batch_size=32)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["sampler"] is None
    assert dl_kwargs["batch_sampler"] is None
    assert dl_kwargs["batch_size"] is dataloader.batch_size
    assert dl_kwargs["dataset"] is dataloader.dataset
    assert dl_kwargs["collate_fn"] is dataloader.collate_fn


def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Test that the comparison of attributes and default argument values works with arrays (whose truth value is
    ambiguous).

    Regression test for issue #15408.

    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            self.indices = np.random.rand(2, 2)  # an attribute we can't compare with ==

    dataloader = ArrayAttributeDataloader(dataset)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["indices"] is dataloader.indices


def test_set_sampler_epoch():
    # No samplers
    dataloader = Mock()
    dataloader.sampler = None
    dataloader.batch_sampler = None
    _set_sampler_epoch(dataloader, 55)

    # set_epoch not callable
    dataloader = Mock()
    dataloader.sampler.set_epoch = None
    dataloader.batch_sampler.set_epoch = None
    _set_sampler_epoch(dataloader, 55)

    # set_epoch callable
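    # (a bare `Mock` auto-creates attributes, so `sampler.set_epoch` and
    # `batch_sampler.sampler.set_epoch` exist here as callable child mocks)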
    dataloader = Mock()
    _set_sampler_epoch(dataloader, 55)
    dataloader.sampler.set_epoch.assert_called_once_with(55)
    dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)


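# The expectations below are consistent with keeping one CPU per process free for
# the main thread, roughly max(1, cpu_count // local_world_size - 1) (inferred
# from the table, not a documented formula).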
@pytest.mark.parametrize(
    ("cpu_count", "local_world_size", "expected"),
    [
        (0, 1, 1),
        (1, 1, 1),
        (2, 1, 2 - 1),
        (1, 2, 1),
        (2, 2, 1),
        (3, 2, 1),
        (4, 2, 2 - 1),
        (4, 3, 1),
        (4, 1, 4 - 1),
    ],
)
@pytest.mark.parametrize(
    "affinity",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
            ),
        ),
    ],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
    if affinity:
        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
    else:
        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
        cpu_count_mock.return_value = cpu_count

    assert suggested_max_num_workers(local_world_size) == expected


@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
    with pytest.raises(ValueError, match="should be >= 1"):
        suggested_max_num_workers(invalid)


@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
    """Test that our suggested number of workers doesn't trigger the DataLoader's warning about too many workers."""
    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)

    # The DataLoader runs a check in `DataLoader.check_worker_number_rationality`
    with pytest.warns(UserWarning, match="This DataLoader will create"):
        DataLoader(range(2), num_workers=(cpu_count + 1))
    with no_warning_call():
        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))


def test_state():
    # init via dict
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # init via kwargs
    inputs = {"key1": 1, "key2": "abc"}
    state = AttributeDict(**inputs)
    for key, value in inputs.items():
        assert getattr(state, key) == value

    # update via dict
    state = AttributeDict()
    state.update({"key1": 1})
    assert state.key1 == 1

    # update via setter
    state = AttributeDict({"key1": 1})
    state.key1 = 123
    assert state.key1 == 123

    with pytest.raises(AttributeError, match="has no attribute 'key3'"):
        _ = state.key3

    # delete attribute
    del state.key1
    assert "key1" not in state
    with pytest.raises(KeyError):
        del state.key3