# Owner(s): ["oncall: distributed"]

import contextlib
import itertools
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


@dataclass
class _GradAccConfig:
    """
    This configures how gradients are accumulated in :meth:`_test_grad_acc`.
    Each instance of this class represents ``num_iters``-many consecutive
    iterations, where the ``no_sync()`` context manager is used or not as given
    by ``use_no_sync``.

    Attributes:
        use_no_sync (bool): Indicates whether to use the ``no_sync()`` context
            manager as the way to accumulate gradients.
        num_iters (int): Number of iterations to accumulate gradients.
    """

    use_no_sync: bool
    num_iters: int

    def __repr__(self) -> str:
        # Override to remove any spaces in the string to appease the internal
        # build's test name parser
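        # e.g., "(use_no_sync=True,num_iters=3)"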
        return f"(use_no_sync={self.use_no_sync}," f"num_iters={self.num_iters})"


@dataclass
class _GradAccConfigs:
    """
    This wraps a :class:`list` of :class:`_GradAccConfig` instances with the
    sole purpose of overriding :meth:`__repr__` to remove spaces.
    """

    configs: List[_GradAccConfig]

    def __repr__(self) -> str:
        # Override to remove any spaces in the string to appease the internal
        # build's test name parser
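        # e.g., "[(use_no_sync=True,num_iters=3),(use_no_sync=False,num_iters=3)]"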
        return "[" + ",".join(config.__repr__() for config in self.configs) + "]"


class TestGradAcc(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via both its
    ``no_sync()`` context manager and without the context manager."""

    @property
    def world_size(self) -> int:
        return 2

    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
116
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
117
                point to prefetch the next layer's full parameters during the
118
                backward pass, if at all.
119
        """
120
        # Initialize the FSDP model and optimizer
        fsdp_kwargs = {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "sharding_strategy": sharding_strategy,
            "use_orig_params": use_orig_params,
        }
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
            deterministic=True,
            add_bn=False,  # disable BN since the test uses varying batch sizes
        )
        device = torch.device("cuda")
        optim = torch.optim.SGD(
            fsdp_model.parameters(),
            lr=0.01,
            momentum=0.9,
        )

        # Generate the sequence of batches, each containing the same data
        # but permuted
        def permute_tensor(x: torch.Tensor):
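            # Flatten, shuffle the elements via a random permutation, and
            # restore the original shape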
            return x.view(-1)[torch.randperm(x.numel())].view_as(x)

        batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
        batches: List[Tuple[torch.Tensor, ...]] = [batch]
        num_iters_to_acc = sum(config.num_iters for config in configs)
        for _ in range(num_iters_to_acc - 1):
            batches.append(tuple(permute_tensor(t) for t in batch))
        for batch1, batch2 in itertools.combinations(batches, r=2):
            for t1, t2 in zip(batch1, batch2):
                assert not torch.all(
                    t1 == t2
                ), "Check the test to make sure that batches are distinct"

        # Concatenate the batches along the given batch dimension
        concat_batch: Tuple[torch.Tensor, ...] = tuple(
            torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
        )

        # Establish reference gradients using the concatenated batch
        fsdp_model.zero_grad()
        output = fsdp_model(*concat_batch)
        ref_loss = fsdp_model.module.get_loss(concat_batch, output)
        ref_loss.backward()
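        # Clone the reference gradients since ``zero_grad()`` below resets
        # ``p.grad``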
        ref_grads = [
            p.grad.detach().clone()
            for p in fsdp_model.parameters()
            if p.grad is not None
        ]

        # Compute and accumulate the gradients
        fsdp_model.zero_grad()
        losses = []
        batch_idx = 0
        for config in configs:
            sync_context = (
                fsdp_model.no_sync() if config.use_no_sync else contextlib.nullcontext()
            )
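            # Inside ``no_sync()``, FSDP skips gradient synchronization, so
            # gradients accumulate locally across iterations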
            with sync_context:
                for _ in range(config.num_iters):
                    if batch_idx == num_iters_to_acc - 1:
                        break  # always sync on the last iteration
                    batch = batches[batch_idx]
                    batch_idx += 1
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                    losses.append(loss)
        output = fsdp_model(*batches[-1])
        loss = fsdp_model.module.get_loss(batches[-1], output)
        loss.backward()
        losses.append(loss)
        acc_loss = sum(losses)
        acc_grads = [
            p.grad.detach().clone()
            for p in fsdp_model.parameters()
            if p.grad is not None
        ]

        # Compare the losses and gradients
        torch.testing.assert_close(ref_loss, acc_loss)
        self.assertEqual(len(ref_grads), len(acc_grads))
        for ref_grad, acc_grad in zip(ref_grads, acc_grads):
            self.assertEqual(ref_grad.device, acc_grad.device)
            self.assertEqual(ref_grad.size(), acc_grad.size())
            self.assertEqual(ref_grad.dtype, acc_grad.dtype)
            torch.testing.assert_close(ref_grad, acc_grad)

        # Check that the optimizer step does not error
        optim.step()

    def _get_subtest_config(self) -> Dict[str, List[Any]]:
        """Returns a subtest configuration that subtests backward prefetching
        and sharding strategies."""
        return {
            "backward_prefetch": [
                None,
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ],
            "sharding_strategy": [
                ShardingStrategy.FULL_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.NO_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "configs",
        [
            _GradAccConfigs(
                [
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                ]
            ),
            _GradAccConfigs(
                [
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                ]
            ),
        ],
    )
    @parametrize("use_orig_params", [False, True])
    def test_grad_acc(
        self,
        configs: _GradAccConfigs,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation without parameter CPU offloading.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests both interleaving starting with (and ending with, resp.)
        inside versus outside ``no_sync()`` to ensure that initial conditions
        (and final conditions, resp.) do not affect the correctness.
        """
        subtest_config = self._get_subtest_config()
        subtest_config["cpu_offload"] = [CPUOffload(offload_params=False)]
        self.run_subtests(
            subtest_config,
            self._test_grad_acc,
            batch_dim=1,
            configs=configs.configs,
            use_orig_params=use_orig_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_grad_acc_cpu_offload(
        self,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation with parameter CPU offloading.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading.
        """
        # Only test `no_sync` since outside `no_sync()` is not supported with
        # parameter CPU offloading
        configs = _GradAccConfigs([_GradAccConfig(use_no_sync=True, num_iters=3)])
        subtest_config = self._get_subtest_config()
        subtest_config["cpu_offload"] = [CPUOffload(offload_params=True)]
        self.run_subtests(
            subtest_config,
            self._test_grad_acc,
            batch_dim=1,
            configs=configs.configs,
            use_orig_params=use_orig_params,
        )


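# Generate the concrete test variants for each ``@parametrize`` combination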
instantiate_parametrized_tests(TestGradAcc)

if __name__ == "__main__":
    run_tests()
