# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
    @property
    def world_size(self):
        # Test fails due to inaccuracies when using more than 4 GPUs
        return min(4, super().world_size)

    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model
        possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
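        # Build the model without FSDP wrapping and without moving it to CUDA
        # so that this test controls when the cast to FP16 happens relative to
        # FSDP initialization (see `to_half_before_fsdp_init`).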
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_NEVER,
            {},
        )
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": torch.cuda.current_device(),
            "mixed_precision": mixed_precision,
        }
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(torch.device("cuda"))
        )
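        # One forward/backward pass materializes gradients so that gradient
        # dtypes can be checked below.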
        out = fsdp_model(*inp)
        out.sum().backward()

        # Check handle dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif (
                mixed_precision.reduce_dtype is None
                and mixed_precision.param_dtype is not None
            ):
                # Special case: infer reduce dtype from parameter dtype
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)

        # Check parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


instantiate_parametrized_tests(TestPureFP16)
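
# Note: to run this file directly (it needs at least 2 CUDA GPUs), something
# like `python test/distributed/fsdp/test_fsdp_pure_fp16.py` should work from
# a PyTorch checkout; the exact path is an assumption and may vary by version.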

if __name__ == "__main__":
    run_tests()