pytorch
304 lines · 10.8 KB
1# Owner(s): ["oncall: distributed"]
2
3import contextlib4import itertools5import sys6from dataclasses import dataclass7from typing import Any, Dict, List, Optional, Tuple8
9import torch10from torch import distributed as dist11from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP12from torch.distributed.fsdp.fully_sharded_data_parallel import (13BackwardPrefetch,14ShardingStrategy,15)
16from torch.testing._internal.common_distributed import skip_if_lt_x_gpu17from torch.testing._internal.common_fsdp import (18CUDAInitMode,19FSDPInitMode,20FSDPTest,21TransformerWithSharedParams,22)
23from torch.testing._internal.common_utils import (24instantiate_parametrized_tests,25parametrize,26run_tests,27TEST_WITH_DEV_DBG_ASAN,28)
29
# Every test in this module requires torch.distributed; exit cleanly (code 0)
# instead of failing when it is unavailable.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# dev-asan builds have known issues with torch + multiprocessing spawn (which
# the FSDP test harness relies on), so skip the whole module there as well.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
41
42@dataclass
43class _GradAccConfig:44"""45This configures how gradients are accumulated in :meth:`_test_grad_acc`.
46Each instance of this class represents ``num_iters``-many consecutive
47iterations, where the ``no_sync()`` context manager is used or not as given
48by ``use_no_sync``.
49
50Attributes:
51use_no_sync (bool): Indicates whether to use the ``no_sync()`` context
52manager as the way to accumulate gradients.
53num_iters (int): Number of iterations to accumulate gradients.
54"""
55
56use_no_sync: bool57num_iters: int58
59def __repr__(self) -> str:60# Override to remove any spaces in the string to appease the internal61# build's test name parser62return f"(use_no_sync={self.use_no_sync}," f"num_iters={self.num_iters})"63
64
@dataclass
class _GradAccConfigs:
    """
    Thin wrapper around a :class:`list` of :class:`_GradAccConfig` instances,
    existing solely to override :meth:`__repr__` with a space-free form.
    """

    configs: List[_GradAccConfig]

    def __repr__(self) -> str:
        # The internal build's test name parser cannot handle spaces, so emit
        # a space-free representation instead of the dataclass default.
        joined = ",".join(repr(cfg) for cfg in self.configs)
        return f"[{joined}]"
79
class TestGradAcc(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via both its
    ``no_sync()`` context manager and without the context manager."""

    @property
    def world_size(self) -> int:
        # Two ranks are enough to exercise sharding + gradient reduction.
        return 2

    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
            sharding_strategy (ShardingStrategy): Forwarded to the FSDP
                constructor's ``sharding_strategy``.
            use_orig_params (bool): Forwarded to the FSDP constructor's
                ``use_orig_params``.
        """
        # Initialize the FSDP model and optimizer
        fsdp_kwargs = {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "sharding_strategy": sharding_strategy,
            "use_orig_params": use_orig_params,
        }
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
            deterministic=True,
            add_bn=False,  # disable BN since the test uses varying batch sizes
        )
        device = torch.device("cuda")
        optim = torch.optim.SGD(
            fsdp_model.parameters(),
            lr=0.01,
            momentum=0.9,
        )

        # Generate the sequence of batches, each containing the same data
        # but permuted
        def permute_tensor(x: torch.Tensor):
            # Shuffle all elements while preserving the original shape.
            return x.view(-1)[torch.randperm(x.numel())].view_as(x)

        batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
        batches: List[Tuple[torch.Tensor, ...]] = [batch]
        num_iters_to_acc = sum(config.num_iters for config in configs)
        for _ in range(num_iters_to_acc - 1):
            batches.append(tuple(permute_tensor(t) for t in batch))
        # Sanity-check pairwise distinctness: identical batches would make the
        # accumulation comparison vacuous.
        for batch1, batch2 in itertools.combinations(batches, r=2):
            for t1, t2 in zip(batch1, batch2):
                assert not torch.all(
                    t1 == t2
                ), "Check the test to make sure that batches are distinct"

        # Concatenate the batches along the given batch dimension
        concat_batch: Tuple[torch.Tensor, ...] = tuple(
            torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
        )

        # Establish reference gradients using the concatenated batch
        fsdp_model.zero_grad()
        output = fsdp_model(*concat_batch)
        ref_loss = fsdp_model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [
            p.grad.detach().clone()
            for p in fsdp_model.parameters()
            if p.grad is not None
        ]

        # Compute and accumulate the gradients
        fsdp_model.zero_grad()
        losses = []
        batch_idx = 0
        for config in configs:
            # Accumulate inside `no_sync()` when requested; otherwise run the
            # iterations normally (nullcontext) so gradients still accumulate
            # across iterations without skipping reduction.
            sync_context = (
                fsdp_model.no_sync() if config.use_no_sync else contextlib.nullcontext()
            )
            with sync_context:
                for _ in range(config.num_iters):
                    if batch_idx == num_iters_to_acc - 1:
                        break  # always sync on the last iteration
                    batch = batches[batch_idx]
                    batch_idx += 1
                    output = fsdp_model(*batch)
                    loss = fsdp_model.module.get_loss(batch, output)
                    loss.backward()
                    losses.append(loss)
        # Run the final batch outside any `no_sync()` context so that the last
        # iteration always synchronizes gradients.
        output = fsdp_model(*batches[-1])
        loss = fsdp_model.module.get_loss(batches[-1], output)
        loss.backward()
        losses.append(loss)
        acc_loss = sum(losses)
        acc_grads = [
            p.grad.detach().clone()
            for p in fsdp_model.parameters()
            if p.grad is not None
        ]

        # Compare the losses and gradients
        torch.testing.assert_close(ref_loss, acc_loss)
        self.assertEqual(len(ref_grads), len(acc_grads))
        for ref_grad, acc_grad in zip(ref_grads, acc_grads):
            self.assertEqual(ref_grad.device, acc_grad.device)
            self.assertEqual(ref_grad.size(), acc_grad.size())
            self.assertEqual(ref_grad.dtype, acc_grad.dtype)
            torch.testing.assert_close(ref_grad, acc_grad)

        # Check that the optimizer step does not error
        optim.step()

    def _get_subtest_config(self) -> Dict[str, List[Any]]:
        """Returns a subtest configuration that subtests prefetching."""
        return {
            "backward_prefetch": [
                None,
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ],
            "sharding_strategy": [
                ShardingStrategy.FULL_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.NO_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "configs",
        [
            _GradAccConfigs(
                [
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                ]
            ),
            _GradAccConfigs(
                [
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                    _GradAccConfig(use_no_sync=True, num_iters=3),
                    _GradAccConfig(use_no_sync=False, num_iters=3),
                ]
            ),
        ],
    )
    @parametrize("use_orig_params", [False, True])
    def test_grad_acc(
        self,
        configs: _GradAccConfigs,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation without parameter CPU offloading.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests both interleaving starting with (and ending with, resp.)
        inside versus outside ``no_sync()`` to ensure that initial conditions
        (and final conditions, resp.) do not affect the correctness.
        """
        subtest_config = self._get_subtest_config()
        subtest_config["cpu_offload"] = [CPUOffload(offload_params=False)]
        self.run_subtests(
            subtest_config,
            self._test_grad_acc,
            batch_dim=1,
            configs=configs.configs,
            use_orig_params=use_orig_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_grad_acc_cpu_offload(
        self,
        use_orig_params: bool,
    ):
        """
        Tests gradient accumulation with parameter CPU offloading.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading.
        """
        # Only test `no_sync` since outside `no_sync()` is not supported with
        # parameter CPU offloading
        configs = _GradAccConfigs([_GradAccConfig(use_no_sync=True, num_iters=3)])
        subtest_config = self._get_subtest_config()
        subtest_config["cpu_offload"] = [CPUOffload(offload_params=True)]
        self.run_subtests(
            subtest_config,
            self._test_grad_acc,
            batch_dim=1,
            configs=configs.configs,
            use_orig_params=use_orig_params,
        )
300
# Expand the @parametrize decorators into concrete test methods on the class.
instantiate_parametrized_tests(TestGradAcc)

if __name__ == "__main__":
    run_tests()