# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import Mock, call

import pytest
import torch
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import (
    _FabricDataLoader,
    _FabricModule,
    _FabricOptimizer,
    _unwrap_compiled,
    _unwrap_objects,
    is_wrapped,
)
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader

from tests_fabric.helpers.runif import RunIf


def test_fabric_module_wraps():
    """Test that the wrapped module is accessible via the property."""
    module = Mock()
    assert _FabricModule(module, Mock()).module is module

    wrapped_module = Mock()
    original_module = Mock()
    assert _FabricModule(wrapped_module, Mock(), original_module=original_module).module is original_module


def test_fabric_module_attribute_lookup():
    """Test that attribute lookup passes through to the original module when possible."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1

        def method(self):
            return 2

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()

    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
    assert fabric_module.attribute == 1
    assert fabric_module.layer is original_module.layer
    assert fabric_module.forward.__self__.__class__ == _FabricModule

    with pytest.raises(AttributeError):
        _ = fabric_module.not_exists


def test_fabric_module_method_lookup():
    """Test that calling methods that invoke the module raises an error when a wrapper from a strategy is involved."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.submodule = torch.nn.Linear(2, 3)

        def forward(self, x):
            return x

        def method_without_module_invocation(self):
            return 100

        def method_with_submodule_invocation(self):
            self.submodule(torch.rand(2, 2))
            return 101

        def method_with_self_invocation(self):
            self(None)
            return 102

    class ModuleWrapper(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.wrapped = module

    # Regular case: forward_module == original_module -> no error
    original_module = OriginalModule()
    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.method_without_module_invocation() == 100

    # Special case: original module wrapped by forward module -> error if the method invokes the module
    original_module = OriginalModule()
    wrapped_module = ModuleWrapper(original_module)
    fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.method_without_module_invocation() == 100
    with pytest.raises(
        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
    ):
        assert fabric_module.method_with_submodule_invocation() == 101
    with pytest.raises(
        RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
    ):
        assert fabric_module.method_with_self_invocation() == 102


def test_fabric_module_setattr():
    """Test that setattr sets attributes on the original module."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1
            self._x = None

        @property
        def x(self):
            return self._x

        @x.setter
        def x(self, value):
            self._x = value

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()
    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)

    # Check new attribute is set on original_module
    fabric_module.new_attribute = 100
    assert original_module.new_attribute == 100

    # Modify existing attribute on original_module
    fabric_module.attribute = 101
    assert original_module.attribute == 101

    # Check setattr of original_module
    fabric_module.x = 102
    assert original_module.x == 102

    # Check set submodule
    assert not hasattr(original_module, "linear")
    linear = torch.nn.Linear(2, 2)
    fabric_module.linear = linear
    assert hasattr(original_module, "linear")
    assert isinstance(original_module.linear, torch.nn.Module)
    assert linear in fabric_module.modules()
    assert linear in original_module.modules()


def test_fabric_module_state_dict_access():
    """Test that state_dict access passes through to the original module."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()

    fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
    state_dict = fabric_module.state_dict()
    assert set(state_dict.keys()) == {"layer.weight", "layer.bias"}

    weight, bias = torch.rand(3, 2), torch.rand(3)
    fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias})
    assert torch.equal(fabric_module.layer.weight, weight)
    assert torch.equal(fabric_module.layer.bias, bias)

    if _TORCH_GREATER_EQUAL_2_1:
        # Can use additional `assign` argument in PyTorch >= 2.1
        with torch.device("meta"):
            original_module = OriginalModule()
        fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
        assert fabric_module.layer.weight.is_meta
        fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
        assert not fabric_module.layer.weight.is_meta


@pytest.mark.parametrize(
    ("precision", "input_type", "expected_type", "accelerator", "device_str"),
    [
        pytest.param(32, torch.float16, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float32, torch.float32, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.float64, torch.float64, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(32, torch.int, torch.int, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float32, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.float64, torch.float16, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(16, torch.long, torch.long, "gpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param(
            "bf16",
            torch.float32,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.float64,
            torch.bfloat16,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(
            "bf16",
            torch.bool,
            torch.bool,
            "gpu",
            "cuda:0",
            marks=RunIf(min_cuda_gpus=1, bf16_cuda=True),
        ),
        pytest.param(32, torch.float32, torch.float32, "mps", "mps:0", marks=RunIf(mps=True)),
    ],
)
def test_fabric_module_forward_conversion(precision, input_type, expected_type, accelerator, device_str):
    """Test that the FabricModule performs autocasting on the input tensors and during forward()."""
    fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
    device = torch.device(device_str)

    def check_autocast(forward_input):
        assert precision != 16 or torch.is_autocast_enabled()
        return forward_input

    module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
    fabric_module = _FabricModule(module, fabric._strategy).to(device)
    out = fabric_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
    assert module.call_args[0][0].dtype == expected_type
    assert out.dtype == input_type or out.dtype == torch.get_default_dtype()


@pytest.mark.parametrize(
    "device_str",
    [
        "cpu",
        pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("mps", marks=RunIf(mps=True)),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_fabric_module_device_dtype_propagation(device_str, dtype):
    """Test that the FabricModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""
    device = torch.device(device_str)

    class DeviceModule(_DeviceDtypeModuleMixin):
        pass

    device_module = DeviceModule()
    fabric_module = _FabricModule(device_module, Mock())
    fabric_module.to(device)
    assert device_module.device == device
    assert fabric_module.device == device

    fabric_module.to(dtype)
    assert device_module.dtype == dtype
    assert fabric_module.dtype == dtype


def test_fabric_dataloader_iterator():
    """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no automatic
    device placement)."""
    dataloader = DataLoader(range(5), batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader)
    assert len(fabric_dataloader) == len(dataloader) == 3

    iterator = iter(dataloader)
    fabric_iterator = iter(fabric_dataloader)

    assert torch.equal(next(iterator), next(fabric_iterator))
    assert torch.equal(next(iterator), next(fabric_iterator))
    assert torch.equal(next(iterator), next(fabric_iterator))

    with pytest.raises(StopIteration):
        next(iterator)

    with pytest.raises(StopIteration):
        next(fabric_iterator)


@pytest.mark.parametrize(
    ("src_device_str", "dest_device_str"),
    [
        ("cpu", "cpu"),
        pytest.param("cpu", "cuda:0", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cuda:0", "cpu", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("cpu", "mps", marks=RunIf(mps=True)),
        pytest.param("mps", "cpu", marks=RunIf(mps=True)),
    ],
)
def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
    """Test that the FabricDataLoader moves data to the device in its iterator."""
    src_device = torch.device(src_device_str)
    dest_device = torch.device(dest_device_str)

    sample0 = torch.tensor(0, device=src_device)
    sample1 = torch.tensor(1, device=src_device)
    sample2 = {"data": torch.tensor(2, device=src_device)}
    sample3 = {"data": torch.tensor(3, device=src_device)}
    dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
    fabric_dataloader = _FabricDataLoader(dataloader=dataloader, device=dest_device)
    iterator = iter(fabric_dataloader)

    batch0 = next(iterator)
    assert torch.equal(batch0, torch.tensor([0, 1], device=dest_device))

    batch1 = next(iterator)
    assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


@pytest.mark.parametrize("use_batch_sampler", [False, True])
def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
    """Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
    dataset = range(3)
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
    sampler.set_epoch = Mock()

    if not use_batch_sampler:
        dataloader = DataLoader(dataset, sampler=sampler)
    else:
        batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

    fabric_dataloader = _FabricDataLoader(dataloader)
    iterator_epoch_0 = iter(fabric_dataloader)
    sampler.set_epoch.assert_not_called()

    next(iterator_epoch_0)
    # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
    assert sampler.set_epoch.mock_calls == [call(0)]

    next(iterator_epoch_0)
    assert sampler.set_epoch.mock_calls == [call(0)]

    iterator_epoch_1 = iter(fabric_dataloader)
    assert sampler.set_epoch.mock_calls == [call(0)]

    next(iterator_epoch_1)
    # with every new iterator call, the epoch increases
    assert sampler.set_epoch.mock_calls == [call(0), call(1)]


def test_fabric_optimizer_wraps():
    """Test that the FabricOptimizer fully wraps the optimizer."""
    optimizer_cls = torch.optim.SGD
    optimizer = Mock(spec=optimizer_cls)
    fabric_optimizer = _FabricOptimizer(optimizer, Mock())
    assert fabric_optimizer.optimizer is optimizer
    assert isinstance(fabric_optimizer, optimizer_cls)
    assert isinstance(fabric_optimizer, _FabricOptimizer)
    assert type(fabric_optimizer).__name__ == "FabricSGD"


def test_fabric_optimizer_state_dict():
    """Test that the FabricOptimizer calls into the strategy to collect the state."""
    optimizer = Mock(spec=torch.optim.Adam)
    strategy = Mock()
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.state_dict()
    strategy.get_optimizer_state.assert_called_with(optimizer)


def test_fabric_optimizer_load_state_dict():
    """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its internal
    `__dict__`."""
    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.Adam(model.parameters())
    assert not optimizer.state  # a fresh optimizer has no state
    model(torch.rand(1)).backward()
    optimizer.step()
    assert optimizer.state
    state_dict = optimizer.state_dict()

    optimizer = torch.optim.Adam(model.parameters())  # fresh optimizer
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
    assert not fabric_optimizer.state  # a fresh optimizer has no state
    fabric_optimizer.load_state_dict(state_dict)
    assert fabric_optimizer.state
    assert fabric_optimizer.optimizer.state_dict()["state"] == state_dict["state"]


def test_fabric_optimizer_steps():
    """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
    optimizer = Mock(spec=torch.optim.Adam)
    strategy = Mock(spec=["optimizer_step"])
    strategy.optimizer_step.return_value = 123
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    step_output = fabric_optimizer.step()
    assert step_output == 123
    strategy.optimizer_step.assert_called_once_with(optimizer)

    strategy.reset_mock()

    # with closure as input
    closure = Mock()
    fabric_optimizer.step(closure=closure)
    strategy.optimizer_step.assert_called_once_with(optimizer, closure=closure)

    # with the strategy's model passed in place of the optimizer
    strategy = Mock(spec=["optimizer_step", "model"])
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
    fabric_optimizer.step()
    strategy.optimizer_step.assert_called_once_with(strategy.model)


def test_fabric_optimizer_zero_grad_kwargs():
    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
    # Test PyTorch's standard `.zero_grad()` signature
    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
        fabric_optimizer.zero_grad()
        zero_grad_mock.assert_called_with()
        fabric_optimizer.zero_grad(set_to_none=False)
        zero_grad_mock.assert_called_with(set_to_none=False)
        fabric_optimizer.zero_grad(set_to_none=True)
        zero_grad_mock.assert_called_with(set_to_none=True)

    # Test weird `.zero_grad()` signatures from other libraries
    custom_zero_grad = Mock()

    class CustomSGD(torch.optim.SGD):
        def zero_grad(self, set_grads_to_None=False):
            custom_zero_grad(set_grads_to_None=set_grads_to_None)

    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
    fabric_optimizer.zero_grad()
    custom_zero_grad.assert_called_with(set_grads_to_None=False)


@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_is_wrapped(compile):
    """Test that the `is_wrapped` utility recognizes when an object was wrapped by Fabric."""
    assert not is_wrapped(None)

    # _FabricModule
    module = torch.nn.Linear(2, 2)
    assert not is_wrapped(module)
    wrapped = _FabricModule(module, Mock())
    assert is_wrapped(wrapped)

    # _FabricModule inside an OptimizedModule
    if compile:
        from torch._dynamo import OptimizedModule

        module = torch.nn.Linear(2, 2)
        wrapped = torch.compile(_FabricModule(module, Mock()))
        assert isinstance(wrapped, OptimizedModule)
        assert is_wrapped(wrapped)

    # _FabricOptimizer
    optimizer = torch.optim.Adam(module.parameters())
    assert not is_wrapped(optimizer)
    wrapped = _FabricOptimizer(optimizer, Mock())
    assert is_wrapped(wrapped)

    # _FabricDataLoader
    dataloader = DataLoader([1, 2, 3])
    assert not is_wrapped(dataloader)
    wrapped = _FabricDataLoader(dataloader)
    assert is_wrapped(wrapped)


@pytest.mark.parametrize("compile", [False, pytest.param(True, marks=RunIf(dynamo=True))])
def test_unwrap_objects(compile):
    # empty container
    assert _unwrap_objects({}) == {}

    # container with pure objects and wrapped objects
    module = torch.nn.Linear(1, 1)
    wrapped_module = _FabricModule(module, Mock())
    if compile:
        wrapped_module = torch.compile(wrapped_module)
    optimizer = torch.optim.Adam(module.parameters())
    wrapped_optimizer = _FabricOptimizer(optimizer, Mock())
    dataloader = DataLoader([1, 2, 3])
    wrapped_dataloader = _FabricDataLoader(dataloader)
    container = {
        "int": 1,
        "module": module,
        "wrapped_module": wrapped_module,
        "optimizer": optimizer,
        "wrapped_optimizer": wrapped_optimizer,
        "dataloader": dataloader,
        "wrapped_dataloader": wrapped_dataloader,
        "nested": [module, wrapped_module, optimizer, wrapped_optimizer, dataloader, wrapped_dataloader],
    }
    expected = {
        "int": 1,
        "module": module,
        "wrapped_module": wrapped_module._forward_module,
        "optimizer": optimizer,
        "wrapped_optimizer": optimizer,
        "dataloader": dataloader,
        "wrapped_dataloader": dataloader,
        "nested": [module, wrapped_module._forward_module, optimizer, optimizer, dataloader, dataloader],
    }
    assert _unwrap_objects(container) == expected


def test_step_method_redirection():
    """Test that the FabricModule redirects the special `LightningModule.*_step` methods through the forward-
    module."""

    class DDP(torch.nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

    class LightningModule(torch.nn.Module):
        def forward(self):
            return "forward_return"

        def training_step(self, arg, kwarg=None):
            assert self() == "forward_return"
            assert arg == "train_arg"
            assert kwarg == "train_kwarg"
            return "training_step_return"

        def validation_step(self, arg, kwarg=None):
            assert self() == "forward_return"
            assert arg == "val_arg"
            assert kwarg == "val_kwarg"
            return "validation_step_return"

        def normal_method(self):
            pass

    strategy = Mock()
    strategy.precision = Mock(wraps=Precision())
    original_module = LightningModule()
    forward_module = DDP(original_module)
    fabric_module = _FabricModule(forward_module=forward_module, strategy=strategy, original_module=original_module)

    # Regular methods on the original_module are visible and identical on the fabric_module ...
    assert fabric_module.normal_method.__wrapped__ == original_module.normal_method

    # ... but special methods like training_step get redirected to the forward_module
    assert fabric_module.training_step.__name__ == "call_forward_module"
    assert fabric_module.validation_step.__name__ == "call_forward_module"
    assert fabric_module.test_step.__name__ == "call_forward_module"
    assert fabric_module.predict_step.__name__ == "call_forward_module"

    with pytest.raises(AttributeError, match="has no attribute 'predict_step'"):
        # A special method that does not exist will raise its AttributeError when being called
        fabric_module.predict_step()

    # The forward method on the original module remains untouched
    assert original_module.forward.__name__ == "forward"

    # The special methods get redirected correctly to produce the expected output
    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
    assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
    assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
    strategy.precision.forward_context.assert_called()

    # The forward method remains untouched/unpatched after the special methods have been called
    assert original_module.forward.__name__ == "forward"

    # Special case: forward_module == original_module -> no special treatment applied
    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
    assert fabric_module.training_step == original_module.training_step
    assert fabric_module.validation_step == original_module.validation_step


@RunIf(dynamo=True)
def test_unwrap_compiled():
    model = torch.nn.Linear(1, 1)

    with mock.patch("lightning.fabric.wrappers._TORCH_GREATER_EQUAL_2_0", False):
        unwrapped, compile_kwargs = _unwrap_compiled(model)
    assert unwrapped is model
    assert compile_kwargs is None

    compiled = torch.compile(model, fullgraph=True, dynamic=True, disable=False)
    assert compiled._compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
    unwrapped, compile_kwargs = _unwrap_compiled(compiled)
    assert unwrapped is compiled._orig_mod
    assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}

    del compiled._compile_kwargs
    with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
        _unwrap_compiled(compiled)