pytorch-lightning
676 lines · 23.8 KB
1import contextlib2import os3import random4from unittest import mock5from unittest.mock import Mock6
7import lightning.fabric8import numpy as np9import pytest10import torch11from lightning.fabric.utilities.data import (12AttributeDict,13_get_dataloader_init_args_and_kwargs,14_replace_dunder_methods,15_replace_value_in_saved_args,16_set_sampler_epoch,17_update_dataloader,18_WrapAttrTag,19has_iterable_dataset,20has_len,21suggested_max_num_workers,22)
23from lightning.fabric.utilities.exceptions import MisconfigurationException24from lightning_utilities.test.warning import no_warning_call25from torch import Tensor26from torch.utils.data import BatchSampler, DataLoader, RandomSampler27
28from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset29
30
def test_has_iterable_dataset():
    """An `IterableDataset` is detected; map-style datasets are not, even if they define `__iter__`."""

    class MapStyleWithIter(RandomDataset):
        # A map-style dataset that merely defines `__iter__` - still not an `IterableDataset` subclass.
        def __iter__(self):
            yield 1
            return self

    iterable_loader = DataLoader(RandomIterableDataset(1, 1))
    assert has_iterable_dataset(iterable_loader)

    map_style_loader = DataLoader(RandomDataset(1, 1))
    assert not has_iterable_dataset(map_style_loader)

    assert not has_iterable_dataset(DataLoader(MapStyleWithIter(1, 1)))
43
def test_has_len():
    """`has_len` is true for sized loaders (with a warning at length 0) and false for iterable ones."""
    sized_loader = DataLoader(RandomDataset(1, 1))
    assert has_len(sized_loader)

    # A zero-length dataset still "has len", but a warning is emitted.
    empty_loader = DataLoader(RandomDataset(0, 0))
    with pytest.warns(UserWarning, match="`DataLoader` returned 0 length."):
        assert has_len(empty_loader)

    # Iterable datasets have no length.
    assert not has_len(DataLoader(RandomIterableDataset(1, 1)))
52
def test_replace_dunder_methods_multiple_loaders_without_init():
    """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__`
    method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent
    class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error occurred
    only sometimes because it depends on the order in which we are iterating over a set of classes we are patching.

    This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__`
    and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_method` exits cleanly, and
    b) the mechanism checking for presence of `__old__init__` works as expected.

    """
    # Build a random inheritance chain of 100 `DataLoader` subclasses, none defining `__init__`.
    classes = [DataLoader]
    for i in range(100):
        classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {}))

    # Snapshot the original `__init__` of every class so restoration can be verified.
    before = {cls: cls.__init__ for cls in classes}

    with _replace_dunder_methods(DataLoader, "dataset"):
        for cls in classes[1:]:  # First one is `DataLoader`
            # Subclasses only *inherit* the patched method; it must not appear in their own `__dict__`.
            assert "__old__init__" not in cls.__dict__
            assert hasattr(cls, "__old__init__")

        # Only the patched base class owns the backup of the original `__init__`.
        assert "__old__init__" in DataLoader.__dict__
        assert hasattr(DataLoader, "__old__init__")

    # After the context exits, every class must have its original `__init__` back.
    for cls in classes:
        assert before[cls] == cls.__init__
81
class MyBaseDataLoader(DataLoader):
    # Plain subclass used as an intermediate base class in the reinstantiation tests below.
    pass
85
class DataLoaderSubclass1(DataLoader):
    # Adds one extra positional argument before forwarding everything else to `DataLoader`.
    # The attribute name `at1` is asserted by the parametrized reinstantiation test.
    def __init__(self, attribute1, *args, **kwargs):
        self.at1 = attribute1
        super().__init__(*args, **kwargs)
91
class DataLoaderSubclass2(DataLoaderSubclass1):
    # Two-level subclass: transforms its own argument before passing it up as `attribute1`.
    def __init__(self, attribute2, *args, **kwargs):
        self.at2 = attribute2
        super().__init__(attribute2 + "-2", *args, **kwargs)
97
98class MyDataLoader(MyBaseDataLoader):99def __init__(self, data: Tensor, *args, **kwargs):100self.data = data101super().__init__(range(data.size(0)), *args, **kwargs)102
103
# Shared tensor fixture for the "test3" parametrization case; identity (`is`) is asserted there.
test3_data = torch.randn((10, 20))
106
class PoptorchDataLoader(DataLoader):
    # Stores its extra argument *after* calling `super().__init__` and exposes it via a property.
    def __init__(self, options, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._options = options

    @property
    def options(self):
        return self._options
116
class IncompleteDataLoader(DataLoader):
    # Mutates an argument (batch_size) before forwarding it, so the saved value and the
    # effective value intentionally differ.
    def __init__(self, dataset, batch_size, **kwargs):
        batch_size = max(batch_size - 5, 0)
        super().__init__(dataset, batch_size=batch_size, **kwargs)
122
class WeirdDataLoader1(DataLoader):
    # The dataset is the *second* positional argument here, not the first.
    def __init__(self, arg1, arg2, **kwargs):
        self.arg1 = arg1
        super().__init__(arg2, **kwargs)
128
class WeirdDataLoader2(DataLoader):
    # The dataset is assembled from two constructor arguments.
    def __init__(self, data_part1, data_part2, **kwargs):
        data = list(data_part1) + list(data_part2)
        super().__init__(data, **kwargs)
134
class NoneDataLoader(DataLoader):
    # Pure pass-through subclass; used to exercise a `None` dataset.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
139
class ChangingDataLoader(DataLoader):
    # Extends the dataset in `__init__`, so the final dataset differs from the passed one.
    def __init__(self, dataset, **kwargs):
        super().__init__(list(dataset) + list(range(5, 10)), **kwargs)
144
@pytest.mark.parametrize(
    ("cls", "args", "kwargs", "arg_names", "dataset", "checked_values"),
    [
        pytest.param(
            DataLoaderSubclass1,
            ("attribute1",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute1",),
            range(4),
            {"batch_size": 2, "at1": "attribute1"},
            id="test1",
        ),
        pytest.param(
            DataLoaderSubclass2,
            ("attribute2",),
            {"dataset": range(4), "batch_size": 2},
            ("attribute2",),
            range(4),
            {"batch_size": 2, "at1": "attribute2-2", "at2": "attribute2"},
            id="test2",
        ),
        pytest.param(
            MyDataLoader,
            (test3_data,),
            {"batch_size": 2},
            ("data",),
            range(10),
            {"batch_size": 2, "data": test3_data},
            id="test3",
        ),
        pytest.param(PoptorchDataLoader, (123, [1]), {}, ("options",), [1], {"options": 123}, id="test4"),
        pytest.param(
            IncompleteDataLoader,
            (range(10),),
            {"batch_size": 10},
            ("dataset",),
            range(10),
            {"batch_size": 5},
            id="test5",
        ),
        pytest.param(
            WeirdDataLoader1,
            (10, range(10)),
            {"batch_size": 10},
            ("arg1", "arg2"),
            range(10),
            {"arg1": 10, "batch_size": 10},
            id="test6",
        ),
        pytest.param(
            WeirdDataLoader2,
            (range(10), range(10, 20)),
            {"batch_size": 10},
            ("data_part1", "data_part2"),
            list(range(20)),
            {"batch_size": 10},
            id="test7",
        ),
        pytest.param(NoneDataLoader, (None,), {}, (), None, {}, id="test8"),
        pytest.param(ChangingDataLoader, (range(5),), {}, ("dataset",), list(range(10)), {}, id="test9"),
    ],
)
def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values):
    """Constructor args of patched `DataLoader` subclasses are captured and replayed by `_update_dataloader`.

    `checked_values` maps attribute names to the values expected both before and after reinstantiation.
    """
    # Instantiate under the patch so `__pl_saved_*` bookkeeping attributes get recorded.
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = cls(*args, **kwargs)

    assert dataloader.__pl_saved_args == args
    assert dataloader.__pl_saved_kwargs == kwargs
    assert dataloader.__pl_saved_arg_names == arg_names
    assert dataloader.__pl_saved_default_kwargs == {}
    assert dataloader.__dataset == dataset

    assert dataloader.dataset == dataset

    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            # Tensors compare elementwise with `==`, so identity is the meaningful check.
            assert dataloader_value is value
        else:
            assert dataloader_value == value

    # Reinstantiation happens outside the patch; the bookkeeping must not survive it.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    assert isinstance(dataloader, cls)
    assert not hasattr(dataloader, "__pl_saved_kwargs")
    assert not hasattr(dataloader, "__pl_saved_arg_names")
    assert not hasattr(dataloader, "__pl_saved_args")
    assert not hasattr(dataloader, "__pl_saved_default_kwargs")
    assert not hasattr(dataloader, "__dataset")

    assert dataloader.dataset == dataset

    # The checked attributes must be preserved through reinstantiation.
    for key, value in checked_values.items():
        dataloader_value = getattr(dataloader, key)
        if isinstance(dataloader_value, Tensor):
            assert dataloader_value is value
        else:
            assert dataloader_value == value
244
def test_replace_dunder_methods_extra_kwargs():
    """Keyword arguments with defaults that the caller did not pass are captured separately."""

    class LoaderSubclass(DataLoader):
        def __init__(self, dataset, *args, batch_size=10, **kwargs):
            super().__init__(dataset, *args, batch_size=batch_size, **kwargs)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = LoaderSubclass(range(10))

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    # `batch_size` was not passed explicitly, so it lands in the *default* kwargs record.
    assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10}
    assert dataloader.__dataset == range(10)
259
def test_replace_dunder_methods_attrs():
    """This test checks that all the calls from setting and deleting attributes within `_replace_dunder_methods` are
    correctly preserved even after reinstantiation.

    It also includes a custom `__setattr__`.

    """

    class Loader(DataLoader):
        # Custom `__setattr__` that transforms one specific attribute before storing it.
        def __setattr__(self, attr, val):
            if attr == "custom_arg":
                val = val + 2
            super().__setattr__(attr, val)

    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = Loader(range(10))
        dataloader.custom_arg = 5
        dataloader.my_arg = 10
        dataloader.another_arg = 100
        del dataloader.dataset
        # Deleting a non-existent attribute must not be recorded (and must not blow up).
        with contextlib.suppress(AttributeError):
            del dataloader.abc_arg

    assert dataloader.__pl_saved_args == (range(10),)
    assert dataloader.__pl_saved_kwargs == {}
    assert dataloader.__pl_saved_arg_names == ("dataset",)
    assert dataloader.__dataset == range(10)
    # 5 + 2 via the custom `__setattr__` above.
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")
    # The record stores *pre-transformation* values, in call order.
    assert dataloader.__pl_attrs_record == [
        (("custom_arg", 5), _WrapAttrTag.SET),
        (("my_arg", 10), _WrapAttrTag.SET),
        (("another_arg", 100), _WrapAttrTag.SET),
        (("dataset",), _WrapAttrTag.DEL),
    ]

    # Reinstantiation replays the recorded sets/deletes on the fresh instance.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert dataloader.custom_arg == 7
    assert dataloader.my_arg == 10
    assert dataloader.another_arg == 100
    assert not hasattr(dataloader, "dataset")
304
def test_replace_dunder_methods_restore_methods():
    """Check that every patched dunder method is restored to its original version on context exit."""

    class Init(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

    class SetAttr(DataLoader):
        def __setattr__(self, *args):
            return super().__setattr__(*args)

    class DelAttr(DataLoader):
        def __delattr__(self, *args):
            return super().__delattr__(*args)

    class InitAndSetAttr(Init, SetAttr):
        pass

    class InitAndDelAttr(Init, DelAttr):
        pass

    class SetAttrAndDelAttr(SetAttr, DelAttr):
        pass

    class AllDunder(Init, SetAttr, DelAttr):
        pass

    subclasses = (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder)

    def snapshot(cls):
        # Capture the three dunders that `_replace_dunder_methods` may patch.
        return {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__}

    before = {cls: snapshot(cls) for cls in subclasses}

    # Enter and exit the patch without instantiating anything.
    with _replace_dunder_methods(DataLoader, "dataset"):
        pass

    for cls in subclasses:
        assert before[cls] == snapshot(cls)
342
@pytest.mark.parametrize(
    (
        "args",
        "kwargs",
        "default_kwargs",
        "arg_names",
        "replace_key",
        "replace_value",
        "expected_status",
        "expected_args",
        "expected_kwargs",
    ),
    [
        pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"),
        pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"),
        pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"),
        pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"),
        pytest.param(
            (1, 2, 3),
            {"a": 1},
            {"e": 5},
            ["b", "c", "d"],
            "e",
            2,
            True,
            (1, 2, 3),
            {"a": 1, "e": 2},
            id="default_kwargs",
        ),
    ],
)
def test_replace_value_in_args(
    args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs
):
    """`_replace_value_in_saved_args` swaps a value in positional args, kwargs, or default kwargs.

    The boolean status reports whether the key was found anywhere; args/kwargs are returned (possibly) updated.
    """
    assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == (
        expected_status,
        expected_args,
        expected_kwargs,
    )
383
def test_update_dataloader_typerror_custom_exception():
    """`_update_dataloader` raises a helpful error when a subclass's `__init__` conflicts with injected arguments."""

    class BadStandaloneGoodHookImpl(DataLoader):
        def __init__(self, foo, *args, **kwargs):
            self.foo = foo
            # positional conflict with `dataset`
            super().__init__(foo, *args, **kwargs)

    # Without the patch, reinstantiation cannot resolve the `dataset` conflict.
    dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"):
        _update_dataloader(dataloader, dataloader.sampler)

    # With the patch active at construction time, the saved args make reinstantiation possible.
    with _replace_dunder_methods(DataLoader, "dataset"):
        dataloader = BadStandaloneGoodHookImpl([1, 2, 3])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, BadStandaloneGoodHookImpl)

    class BadImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            self.randomize = randomize
            # keyword conflict with `shuffle`
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = BadImpl(False, [])
    with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"):
        _update_dataloader(dataloader, dataloader.sampler)

    class GoodImpl(DataLoader):
        def __init__(self, randomize, *args, **kwargs):
            # fixed implementation, kwargs are filtered
            self.randomize = randomize or kwargs.pop("shuffle", False)
            super().__init__(*args, shuffle=randomize, **kwargs)

    dataloader = GoodImpl(False, [])
    new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
    assert isinstance(new_dataloader, GoodImpl)
420
def test_custom_torch_batch_sampler():
    """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to properly
    reinstantiate the class, is invoked properly.

    It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore
    not setting `__pl_saved_{args,arg_names,kwargs}` attributes.

    """

    class MyBatchSampler(BatchSampler):
        # Custom Batch sampler with extra argument and default value
        def __init__(self, sampler, extra_arg, drop_last=True):
            self.extra_arg = extra_arg
            super().__init__(sampler, 10, drop_last)

    sampler = RandomSampler(range(10))
    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler(sampler, "random_str")

    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str")
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg")
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True}

    # updating dataloader, what happens on access of the dataloaders.
    # This should not fail, and would fail before support for custom args.
    dataloader = _update_dataloader(dataloader, dataloader.sampler)

    # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types
    batch_sampler = dataloader.batch_sampler

    assert isinstance(batch_sampler, MyBatchSampler)

    assert batch_sampler.extra_arg == "random_str"
    assert not hasattr(batch_sampler, "__pl_saved_kwargs")
    assert not hasattr(batch_sampler, "__pl_saved_arg_names")
    assert not hasattr(batch_sampler, "__pl_saved_args")
    assert not hasattr(batch_sampler, "__pl_saved_default_kwargs")
464
def test_custom_torch_batch_sampler_doppelganger():
    """Test we can reinstantiate a sampler that mimics PyTorch's BatchSampler even if it does not inherit from it.

    This is only possible if that sampler accepts the `batch_size` and `drop_last` arguments, and stores them
    as attributes.

    """

    class BatchSamplerDoppelganger:
        """A batch sampler that mimics `torch.utils.data.BatchSampler` but does not inherit from it."""

        def __init__(self, sampler, batch_size, drop_last):
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last

        def __iter__(self):
            # Infinite iterator; never consumed to exhaustion by this test.
            while True:
                yield [0, 1, 2, 3]

        def __len__(self) -> int:
            return 4

    batch_sampler = BatchSamplerDoppelganger(sampler=Mock(), batch_size=2, drop_last=True)
    dataloader = DataLoader(range(100), batch_sampler=batch_sampler)
    new_sampler = Mock()
    dataloader = _update_dataloader(dataloader, sampler=new_sampler)

    # The doppelganger type survives reinstantiation and receives the injected sampler.
    batch_sampler = dataloader.batch_sampler
    assert isinstance(batch_sampler, BatchSamplerDoppelganger)
    assert batch_sampler.sampler == new_sampler
497
def test_custom_batch_sampler():
    """Test that a custom (non-PyTorch) batch sampler requires the user to set `use_distributed_sampler=False`."""

    class CustomBatchSampler:  # not inheriting from `BatchSampler`
        def __iter__(self):
            while True:
                yield [0, 1, 2, 3]

    loader = DataLoader(range(100), batch_sampler=CustomBatchSampler())

    # Injecting a sampler into an unknown batch-sampler type is not supported.
    with pytest.raises(TypeError, match=r"can't inject a \(distributed\) sampler into your batch sampler"):
        _ = _update_dataloader(loader, sampler=Mock())
511
def test_custom_batch_sampler_no_sampler():
    """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler argument."""

    class MyBatchSampler(BatchSampler):
        # Custom batch sampler, without sampler argument.
        def __init__(self, extra_arg):
            self.extra_arg = extra_arg
            super().__init__(RandomSampler(range(10)), 10, False)

    with _replace_dunder_methods(BatchSampler):
        # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks
        batch_sampler = MyBatchSampler("random_str")
    dataloader = DataLoader(range(10), batch_sampler=batch_sampler)

    # assert that passed information got saved
    assert dataloader.batch_sampler.__pl_saved_args == ("random_str",)
    assert dataloader.batch_sampler.__pl_saved_kwargs == {}
    assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",)
    assert dataloader.batch_sampler.__pl_saved_default_kwargs == {}

    # Assert that error is raised: there is no `sampler` argument to inject the new sampler into.
    with pytest.raises(TypeError, match="sampler into the batch sampler"):
        _ = _update_dataloader(dataloader, dataloader.sampler)
536
def test_dataloader_kwargs_replacement_with_iterable_dataset():
    """Test that DataLoader kwargs are not replaced when using Iterable Dataset."""
    loader = DataLoader(RandomIterableDataset(7, 100), batch_size=32)
    _, replaced_kwargs = _get_dataloader_init_args_and_kwargs(loader, loader.sampler)

    # No sampler machinery is injected for iterable datasets.
    assert replaced_kwargs["sampler"] is None
    assert replaced_kwargs["batch_sampler"] is None

    # The remaining kwargs are taken verbatim from the original loader.
    for name in ("batch_size", "dataset", "collate_fn"):
        assert replaced_kwargs[name] is getattr(loader, name)
548
def test_dataloader_kwargs_replacement_with_array_default_comparison():
    """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous).

    Regression test for issue #15408.

    """
    dataset = RandomDataset(5, 100)

    class ArrayAttributeDataloader(DataLoader):
        # `indices` is accepted but deliberately ignored; the attribute is set to an array
        # whose `==` comparison yields an array, not a bool.
        def __init__(self, indices=None, **kwargs):
            super().__init__(dataset)
            self.indices = np.random.rand(2, 2)  # an attribute we can't compare with ==

    dataloader = ArrayAttributeDataloader(dataset)
    _, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
    assert dl_kwargs["indices"] is dataloader.indices
566
def test_set_sampler_epoch():
    """`_set_sampler_epoch` is a no-op when `set_epoch` is missing or not callable, and forwards the epoch otherwise."""
    epoch = 55

    # No samplers at all - nothing to call, must not raise.
    no_samplers = Mock()
    no_samplers.sampler = None
    no_samplers.batch_sampler = None
    _set_sampler_epoch(no_samplers, epoch)

    # `set_epoch` attributes exist but are not callable - must not raise.
    not_callable = Mock()
    not_callable.sampler.set_epoch = None
    not_callable.batch_sampler.set_epoch = None
    _set_sampler_epoch(not_callable, epoch)

    # Callable `set_epoch` on both the sampler and the batch sampler's inner sampler.
    with_samplers = Mock()
    _set_sampler_epoch(with_samplers, epoch)
    with_samplers.sampler.set_epoch.assert_called_once_with(epoch)
    with_samplers.batch_sampler.sampler.set_epoch.assert_called_once_with(epoch)
586
@pytest.mark.parametrize(
    ("cpu_count", "local_world_size", "expected"),
    [
        (0, 1, 1),
        (1, 1, 1),
        (2, 1, 2 - 1),
        (1, 2, 1),
        (2, 2, 1),
        (3, 2, 1),
        (4, 2, 2 - 1),
        (4, 3, 1),
        (4, 1, 4 - 1),
    ],
)
@pytest.mark.parametrize(
    "affinity",
    [
        False,
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not hasattr(os, "sched_getaffinity"), reason="OS does not support restricting CPU cores"
            ),
        ),
    ],
)
@mock.patch("lightning.fabric.utilities.data.os.cpu_count")
def test_suggested_max_num_workers(cpu_count_mock, affinity, cpu_count, local_world_size, expected, monkeypatch):
    """`suggested_max_num_workers` divides the CPU budget by the local world size, keeping one CPU spare.

    Exercised both with and without `os.sched_getaffinity` being available, since the implementation
    prefers the affinity mask over `os.cpu_count` when it exists.
    """
    if affinity:
        # Simulate a process restricted to `cpu_count` cores.
        monkeypatch.setattr(lightning.fabric.utilities.data.os, "sched_getaffinity", lambda _: list(range(cpu_count)))
    else:
        # Force the fallback to `os.cpu_count` (mocked below).
        monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    cpu_count_mock.return_value = cpu_count

    assert suggested_max_num_workers(local_world_size) == expected
623
@pytest.mark.parametrize("invalid", [-1, 0])
def test_suggested_max_num_workers_input_validation(invalid):
    """A non-positive local world size is rejected with a `ValueError`."""
    with pytest.raises(ValueError, match="should be >= 1"):
        suggested_max_num_workers(invalid)
629
@pytest.mark.parametrize("cpu_count", [1, 2, 3])
@pytest.mark.parametrize("local_world_size", [1, 2, 3])
def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size, cpu_count, monkeypatch):
    """Test that our suggestion for num workers doesn't trigger a warning in the DataLoader for too many workers."""
    # Make both Lightning and PyTorch see the same (mocked) CPU count and no affinity mask.
    monkeypatch.delattr(lightning.fabric.utilities.data.os, "sched_getaffinity", raising=False)
    monkeypatch.delattr(torch.utils.data.dataloader.os, "sched_getaffinity", raising=False)
    monkeypatch.setattr(lightning.fabric.utilities.data.os, "cpu_count", lambda: cpu_count)
    monkeypatch.setattr(torch.utils.data.dataloader.os, "cpu_count", lambda: cpu_count)

    # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
    with pytest.warns(UserWarning, match="This DataLoader will create"):
        DataLoader(range(2), num_workers=(cpu_count + 1))
    with no_warning_call():
        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
645
def test_state():
    """`AttributeDict` exposes dict entries as attributes for init, update, assignment, and deletion."""
    expected = {"key1": 1, "key2": "abc"}

    # Construction from a dict and from keyword arguments behaves identically.
    for state in (AttributeDict(expected), AttributeDict(**expected)):
        for key, value in expected.items():
            assert getattr(state, key) == value

    # `update` with a dict.
    state = AttributeDict()
    state.update({"key1": 1})
    assert state.key1 == 1

    # Attribute-style assignment.
    state = AttributeDict({"key1": 1})
    state.key1 = 123
    assert state.key1 == 123

    # Reading a missing key raises AttributeError.
    with pytest.raises(AttributeError, match="has no attribute 'key3'"):
        _ = state.key3

    # Deletion removes the key; deleting a missing key raises KeyError.
    del state.key1
    assert "key1" not in state
    with pytest.raises(KeyError):
        del state.key3