# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, call

import pytest
import torch
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import (
    _FabricDataLoader,
    _FabricModule,
    _FabricOptimizer,
    _unwrap_compiled,
    _unwrap_objects,
    is_wrapped,
)
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader

from tests_fabric.helpers.runif import RunIf


def test_fabric_module_wraps():
    """Test that the wrapped module is accessible via the property."""
    module = Mock()
    assert _FabricModule(module, Mock()).module is module

    wrapped_module = Mock()
    original_module = Mock()
    assert _FabricModule(wrapped_module, Mock(), original_module=original_module).module is original_module


def test_fabric_module_attribute_lookup():
    """Test that attribute lookup passes through to the original module when possible."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1

        def method(self):
            return 2

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()

    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
    assert fabric_module.attribute == 1
    assert fabric_module.layer is original_module.layer
    assert fabric_module.forward.__self__.__class__ == _FabricModule

    with pytest.raises(AttributeError):
        _ = fabric_module.not_exists


def test_fabric_module_method_lookup():
    """Test that calling methods that invoke the module raises an error when a wrapper from a strategy is involved."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.submodule = torch.nn.Linear(2, 3)

        def forward(self, x):
            return x

        def method_without_module_invocation(self):
            return 100

        def method_with_submodule_invocation(self):
            self.submodule(torch.rand(2, 2))
            return 101

        def method_with_self_invocation(self):
            self(None)
            return 102

    class ModuleWrapper(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.wrapped = module

    # Regular case: forward_module == original_module -> no error
    original_module = OriginalModule()
    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.method_without_module_invocation() == 100

    # Special case: original module wrapped by forward module -> error if the method invokes the module
    original_module = OriginalModule()
    wrapped_module = ModuleWrapper(original_module)
    fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.method_without_module_invocation() == 100
    with pytest.raises(
        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
    ):
        assert fabric_module.method_with_submodule_invocation() == 101
    with pytest.raises(
        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
    ):
        assert fabric_module.method_with_self_invocation() == 102


def test_fabric_module_setattr():
    """Test that setattr sets attributes on the original module."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1
            self._x = None

        @property
        def x(self):
            return self._x

        @x.setter
        def x(self, value):
            self._x = value

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()
    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)

    # Check new attribute is set on original_module
    fabric_module.new_attribute = 100
    assert original_module.new_attribute == 100

    # Modify existing attribute on original_module
    fabric_module.attribute = 101
    assert original_module.attribute == 101

    # Check setattr of original_module
    fabric_module.x = 102
    assert original_module.x == 102

    # Check set submodule
    assert not hasattr(original_module, "linear")
    linear = torch.nn.Linear(2, 2)
    fabric_module.linear = linear
    assert hasattr(original_module, "linear")
    assert isinstance(original_module.linear, torch.nn.Module)
    assert linear in fabric_module.modules()
    assert linear in original_module.modules()


def test_fabric_module_state_dict_access():
    """Test that state_dict access passes through to the original module."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()

    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
    state_dict = fabric_module.state_dict()
    assert set(state_dict.keys()) == {"layer.weight", "layer.bias"}

    weight, bias = torch.rand(3, 2), torch.rand(3)
    fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias})
    assert torch.equal(fabric_module.layer.weight, weight)
    assert torch.equal(fabric_module.layer.bias, bias)

    if _TORCH_GREATER_EQUAL_2_1:
        # Can use additional `assign` argument in PyTorch >= 2.1
        with torch.device("meta"):
            original_module = OriginalModule()
        fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
        assert fabric_module.layer.weight.is_meta
        fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
        assert not fabric_module.layer.weight.is_meta


@pytest.mark.parametrize(
    ("precision", "input_type", "expected_type", "accelerator", "device_str"),
    [
        pytest.param(32, torch.float16, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float32, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float64, torch.float64, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.int, torch.int, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float32, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float64, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.long, torch.long, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(
            "bf16",
            torch.float32,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.float64,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.bool,
            torch.bool,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(32, torch.float32, torch.float32, "mps", "mps:0", marks=RunIf(mps=True)),
    ],
)
def test_fabric_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str):
    """Test that the FabricModule performs autocasting on the input tensors and during forward()."""
    fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
    device = torch.device(device_str)

    def check_autocast(forward_input):
        assert precision != 16 or torch.is_autocast_enabled()
        return forward_input

    module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
    fabric_module = _FabricModule(module, fabric._strategy).to(device)
    out = fabric_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
    assert module.call_args[0][0].dtype == expected_type
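    # The output is either kept in the input dtype or converted back to the default dtype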
    assert out.dtype == input_type or out.dtype == torch.get_default_dtype()


@pytest.mark.parametrize(
    "device_str",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_fabric_module_device_dtype_propagation(device_str, dtype):
    """Test that the FabricModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
    device = torch.device(device_str)

    class DeviceModule(_DeviceDtypeModuleMixin):
        pass

    device_module = DeviceModule()
    fabric_module = _FabricModule(device_module, Mock())
    fabric_module.to(device)
    assert device_module.device == device
    assert fabric_module.device == device

    fabric_module.to(dtype)
    assert device_module.dtype == dtype
    assert fabric_module.dtype == dtype


def test_fabric_dataloader_iterator():
    """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no automatic
    device placement)."""
    dataloader = DataLoader(range(5), batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader)
    assert len(fabric_dataloader) == len(dataloader) == 3

    iterator = iter(dataloader)
    fabric_iterator = iter(fabric_dataloader)

    assert torch.equal(next(iterator), next(fabric_iterator))
    assert torch.equal(next(iterator), next(fabric_iterator))
    assert torch.equal(next(iterator), next(fabric_iterator))

    with pytest.raises(StopIteration):
        next(iterator)

    with pytest.raises(StopIteration):
        next(fabric_iterator)


@pytest.mark.parametrize(
    ("src_device_str", "dest_device_str"),
    [
        ("cpu", "cpu"),
        pytest.param("cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cpu", "mps", marks=RunIf(mps=True)),
        pytest.param("mps", "cpu", marks=RunIf(mps=True)),
    ],
)
def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
    """Test that the FabricDataLoader moves data to the device in its iterator."""
    src_device = torch.device(src_device_str)
    dest_device = torch.device(dest_device_str)

    sample0 = torch.tensor(0, device=src_device)
    sample1 = torch.tensor(1, device=src_device)
    sample2 = {"data": torch.tensor(2, device=src_device)}
    sample3 = {"data": torch.tensor(3, device=src_device)}
    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader=dataloader, device=dest_device)
    iterator = iter(fabric_dataloader)

    batch0 = next(iterator)
    assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device))

    batch1 = next(iterator)
    assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


@pytest.mark.parametrize("use_batch_sampler", [False, True])
def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
    """Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
    dataset = range(3)
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
    sampler.set_epoch = Mock()

    if not use_batch_sampler:
        dataloader = DataLoader(dataset, sampler=sampler)
    else:
        batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    fabric_dataloader = _FabricDataLoader(dataloader)
    iterator_epoch_0 = iter(fabric_dataloader)
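    # Merely creating the iterator must not advance the epoch yet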
    sampler.set_epoch.assert_not_called()

    next(iterator_epoch_0)
    # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
    assert sampler.set_epoch.mock_calls == [call(0)]

    next(iterator_epoch_0)
    assert sampler.set_epoch.mock_calls == [call(0)]

    iterator_epoch_1 = iter(fabric_dataloader)
    assert sampler.set_epoch.mock_calls == [call(0)]

    next(iterator_epoch_1)
    # with every new iterator call, the epoch increases
    assert sampler.set_epoch.mock_calls == [call(0), call(1)]


def test_fabric_optimizer_wraps():
    """Test that the FabricOptimizer fully wraps the optimizer."""
    optimizer_cls = torch.optim.SGD
    optimizer = Mock(spec=optimizer_cls)
    fabric_optimizer = _FabricOptimizer(optimizer, Mock())
    assert fabric_optimizer.optimizer is optimizer
    assert isinstance(fabric_optimizer, optimizer_cls)
    assert isinstance(fabric_optimizer, _FabricOptimizer)
    assert type(fabric_optimizer).__name__ == "FabricSGD"


def test_fabric_optimizer_state_dict():
    """Test that the FabricOptimizer calls into the strategy to collect the state."""
    optimizer = Mock(spec=torch.optim.Adam)
    strategy = Mock()
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.state_dict()
    strategy.get_optimizer_state.assert_called_with(optimizer)


def test_fabric_optimizer_load_state_dict():
    """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its internal
    `__dict__`."""
    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.Adam(model.parameters())
    assert not optimizer.state  # a fresh optimizer has no state
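    # Run a backward pass and one step so the optimizer accumulates state to save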
    model(torch.rand(1)).backward()
    optimizer.step()
    assert optimizer.state
    state_dict = optimizer.state_dict()

    optimizer = torch.optim.Adam(model.parameters())  # fresh optimizer
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
    assert not fabric_optimizer.state  # a fresh optimizer has no state
    fabric_optimizer.load_state_dict(state_dict)
    assert fabric_optimizer.state
    assert fabric_optimizer.optimizer.state_dict()["state"] == state_dict["state"]


def test_fabric_optimizer_steps():
    """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
    optimizer = Mock(spec=torch.optim.Adam)
    strategy = Mock(spec=["optimizer_step"])
    strategy.optimizer_step.return_value = 123
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    step_output = fabric_optimizer.step()
    assert step_output == 123
    strategy.optimizer_step.assert_called_once_with(optimizer)

    strategy.reset_mock()

    # with closure as input
    closure = Mock()
    fabric_optimizer.step(closure=closure)
    strategy.optimizer_step.assert_called_once_with(optimizer, closure=closure)

    # with model as optimizer
    strategy = Mock(spec=["optimizer_step", "model"])
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.step()
    strategy.optimizer_step.assert_called_once_with(strategy.model)


def test_fabric_optimizer_zero_grad_kwargs():
    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
    # Test PyTorch's standard `.zero_grad()` signature
    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
        fabric_optimizer.zero_grad()
        zero_grad_mock.assert_called_with()
        fabric_optimizer.zero_grad(set_to_none=False)
        zero_grad_mock.assert_called_with(set_to_none=False)
        fabric_optimizer.zero_grad(set_to_none=True)
        zero_grad_mock.assert_called_with(set_to_none=True)

    # Test weird `.zero_grad()` signatures from other libraries
    custom_zero_grad = Mock()

    class CustomSGD(torch.optim.SGD):
        def zero_grad(self, set_grads_to_None=False):
            custom_zero_grad(set_grads_to_None=set_grads_to_None)

    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
    fabric_optimizer.zero_grad()
    custom_zero_grad.assert_called_with(set_grads_to_None=False)


@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_is_wrapped(compile):
    """Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
    assert not is_wrapped(None)

    # _FabricModule
    module = torch.nn.Linear(2, 2)
    assert not is_wrapped(module)
    wrapped = _FabricModule(module, Mock())
    assert is_wrapped(wrapped)

    # _FabricModule inside an OptimizedModule
    if compile:
        from torch._dynamo import OptimizedModule

        module = torch.nn.Linear(2, 2)
        wrapped = torch.compile(_FabricModule(module, Mock()))
        assert isinstance(wrapped, OptimizedModule)
        assert is_wrapped(wrapped)

    # _FabricOptimizer
    optimizer = torch.optim.Adam(module.parameters())
    assert not is_wrapped(optimizer)
    wrapped = _FabricOptimizer(optimizer, Mock())
    assert is_wrapped(wrapped)

    # _FabricDataLoader
    dataloader = DataLoader([1, 2, 3])
    assert not is_wrapped(dataloader)
    wrapped = _FabricDataLoader(dataloader)
    assert is_wrapped(wrapped)


@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_unwrap_objects(compile):
    # empty container
    assert _unwrap_objects({}) == {}

    # container with pure objects and wrapped objects
    module = torch.nn.Linear(1, 1)
    wrapped_module = _FabricModule(module, Mock())
    if compile:
        wrapped_module = torch.compile(wrapped_module)
    optimizer = torch.optim.Adam(module.parameters())
    wrapped_optimizer = _FabricOptimizer(optimizer, Mock())
    dataloader = DataLoader([1, 2, 3])
    wrapped_dataloader = _FabricDataLoader(dataloader)
    container = {
        "int": 1,
        "module": module,
        "wrapped_module": wrapped_module,
        "optimizer": optimizer,
        "wrapped_optimizer": wrapped_optimizer,
        "dataloader": dataloader,
        "wrapped_dataloader": wrapped_dataloader,
        "nested": [module, wrapped_module, optimizer, wrapped_optimizer, dataloader, wrapped_dataloader],
    }
    expected = {
        "int": 1,
        "module": module,
        "wrapped_module": wrapped_module._forward_module,
        "optimizer": optimizer,
        "wrapped_optimizer": optimizer,
        "dataloader": dataloader,
        "wrapped_dataloader": dataloader,
        "nested": [module, wrapped_module._forward_module, optimizer, optimizer, dataloader, dataloader],
    }
    assert _unwrap_objects(container) == expected


def test_step_method_redirection():
    """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
    module."""

    class DDP(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    class LightningModule(torch.nn.Module):
        def forward(self):
            return "forward_return"

        def training_step(self, arg, kwarg=None):
            assert self() == "forward_return"
            assert arg == "train_arg"
            assert kwarg == "train_kwarg"
            return "training_step_return"

        def validation_step(self, arg, kwarg=None):
            assert self() == "forward_return"
            assert arg == "val_arg"
            assert kwarg == "val_kwarg"
            return "validation_step_return"

        def normal_method(self):
            pass

    strategy = Mock()
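    # Wrap a real Precision plugin so `forward_context()` behaves normally while calls can still be asserted on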
    strategy.precision = Mock(wraps=Precision())
    original_module = LightningModule()
    forward_module = DDP(original_module)
    fabric_module = _FabricModule(forward_module=forward_module, strategy=strategy, original_module=original_module)

    # Regular methods on the original_module are visible and identical on the fabric_module ...
    assert fabric_module.normal_method.__wrapped__ == original_module.normal_method

    # ... but special methods like training_step get redirected to the forward_module
    assert fabric_module.training_step.__name__ == "call_forward_module"
    assert fabric_module.validation_step.__name__ == "call_forward_module"
    assert fabric_module.test_step.__name__ == "call_forward_module"
    assert fabric_module.predict_step.__name__ == "call_forward_module"

    with pytest.raises(AttributeError, match="has no attribute 'predict_step'"):
        # A special method that does not exist will raise its AttributeError when being called
        fabric_module.predict_step()

    # The forward method on the original module remains untouched
    assert original_module.forward.__name__ == "forward"

    # The special methods get redirected correctly to produce the expected output
    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
    assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
    strategy.precision.forward_context.assert_called()

    # The forward method remains untouched/unpatched after the special methods have been called
    assert original_module.forward.__name__ == "forward"

    # Special case: forward_module == original_module -> no special treatment applied
    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.training_step == original_module.training_step
    assert fabric_module.validation_step == original_module.validation_step


@RunIf(dynamo=True)
def test_unwrap_compiled():
    model = torch.nn.Linear(1, 1)

    with mock.patch("lightning.fabric.wrappers._TORCH_GREATER_EQUAL_2_0", False):
        unwrapped, compile_kwargs = _unwrap_compiled(model)
        assert unwrapped is model
        assert compile_kwargs is None

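    # The arguments passed to `torch.compile` get recorded on the compiled module as `_compile_kwargs`,
    # which `_unwrap_compiled` reads back (and errors on if missing), as asserted below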
    compiled = torch.compile(model, fullgraph=True, dynamic=True, disable=False)
    assert compiled._compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
    unwrapped, compile_kwargs = _unwrap_compiled(compiled)
    assert unwrapped is compiled._orig_mod
    assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}

    del compiled._compile_kwargs
    with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
        _unwrap_compiled(compiled)