# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
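    # NOTE: "Pure FP16" here means the module's parameters themselves are cast
    # to `torch.float16` (via `.half()` / `use_pure_fp16=True`), rather than
    # relying only on FSDP's `MixedPrecision` casts.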
    @property
    def world_size(self):
        # Test fails due to inaccuracies when using more than 4 GPUs
        return min(4, super().world_size)

    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
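        # `_test_fsdp_parity` is the shared `FSDPTest` helper; it is expected
        # to train the FSDP-wrapped model and check it for parity against a
        # non-FSDP reference model.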
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model,
        possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_NEVER,
            {},
        )
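        # `NO_FSDP` + `CUDA_NEVER` are assumed to yield an unwrapped module
        # still on CPU; the `device_id` kwarg below lets FSDP move it to the
        # current GPU during wrapping.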
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": torch.cuda.current_device(),
            "mixed_precision": mixed_precision,
        }
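        # Exercise both orderings of the FP16 cast relative to FSDP
        # construction: cast to half before wrapping vs. after wrapping,
        # depending on the subtest flag.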
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(torch.device("cuda"))
        )
        out = fsdp_model(*inp)
        out.sum().backward()
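        # A single forward/backward pass populates gradients so that both the
        # flat parameter's and the per-parameter gradient dtypes can be
        # checked below.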

        # Check handle dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif (
                mixed_precision.reduce_dtype is None
                and mixed_precision.param_dtype is not None
            ):
                # Special case: infer reduce dtype from parameter dtype
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)

        # Check parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


instantiate_parametrized_tests(TestPureFP16)

if __name__ == "__main__":
    run_tests()