pytorch / test_fsdp_apply.py (102 lines, 3.3 KB)
# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2

    @torch.no_grad()
    def _init_linear_weights(self, m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)

    def check_weights(self, fsdp, expected_tensor_fn, check):
        with FSDP.summon_full_params(fsdp, recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected, f"Got {param} but expected {expected}")

    def _check_apply(self, fsdp):
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on a
        non-FSDP-root nested FSDP-wrapped model."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(nested_wrapped_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on an
        FSDP-wrapped transformer model with shared parameters."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(transformer)

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """Tests that calling ``apply()`` on an FSDP instance inside the
        ``summon_full_params()`` context raises an error."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        with transformer.summon_full_params(transformer):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)


if __name__ == "__main__":
    run_tests()
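

# ---------------------------------------------------------------------------
# The sketch below is NOT part of test_fsdp_apply.py. It is a minimal,
# standalone illustration of the behavior the test exercises: FSDP's apply()
# runs the given function on unsharded (full) parameters and reshards them
# afterwards, which is also why calling apply() inside summon_full_params()
# is rejected. The toy model, process-group setup, and launch command are
# assumptions, e.g.:
#   torchrun --nproc_per_node=2 fsdp_apply_sketch.py
import os


@torch.no_grad()
def init_linear_weights(m):
    # Same in-place initializer pattern as TestApply._init_linear_weights.
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        m.bias.fill_(1.0)


def _sketch_main():
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda())

    # apply() temporarily unshards the flat parameters, visits every
    # submodule, and reshards afterwards, so the in-place fill is reflected
    # in each rank's local shard.
    model.apply(init_linear_weights)

    # Verify against the full (unsharded) parameters, mirroring
    # TestApply.check_weights above.
    with FSDP.summon_full_params(model):
        for m in model.modules():
            if type(m) == nn.Linear:
                assert torch.all(m.weight == 1.0)
                assert torch.all(m.bias == 1.0)

    dist.destroy_process_group()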