# Owner(s): ["oncall: distributed"]

import sys
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import (
    requires_nccl,
    requires_nccl_version,
    skip_but_pass_in_sandcastle_if,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
# bfloat16 is only supported by CUDA 11+
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.cuda is not None or torch.version.hip is not None
)

class Net(nn.Module):
    def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
        # to ensure determinism
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        super().__init__()

        if has_wrapping:
            self.net = FSDP(
                nn.Sequential(
                    nn.Linear(8, 16),
                    nn.ReLU(),
                    FSDP(
                        nn.Linear(16, 8),
                        device_id=torch.cuda.current_device(),
                        sharding_strategy=sharding_strategy,
                        mixed_precision=mixed_precision,
                    ),
                ),
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
            )
        else:
            self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

        self.out = nn.Linear(8, 4)

    def forward(self, x):
        return self.out(F.relu(self.net(x)))

class DummyState:
    __slots__ = ["process_group", "noise"]

    def __init__(self, process_group: dist.ProcessGroup, noise: int):
        self.process_group = process_group
        self.noise = noise

class DummyHook:
    def dummy_hook_for_no_shard_fsdp(self, state: DummyState, grad: torch.Tensor):
        """
        This communication hook is for illustration and testing purposes only.
        It is used during FSDP ``NO_SHARD`` training. It adds some noise to the
        provided ``grad`` parameter and uses ``all_reduce`` to communicate the
        full, flattened, unsharded gradient.
        """
        grad.add_(state.noise)
        dist.all_reduce(grad, group=state.process_group)
    def custom_reduce_scatter(self, output, input, group=None):
        """
        This function is for illustrative purposes only.
        It is meant to implement a custom reduce-scatter
        of a flattened tensor to all processes in a group.
        Currently a no-op.
        """
        pass
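    # Note: this is intentionally a no-op. The dummy hooks in this file are only ever
    # registered by the tests below and are never invoked during actual communication.
    # A real sharded hook might instead call, e.g.,
    # ``dist.reduce_scatter_tensor(output, input, group=group)`` (an illustrative
    # assumption, not something these tests verify).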
    def dummy_hook_for_sharded_fsdp(
        self, state: DummyState, grad: torch.Tensor, output: torch.Tensor
    ):
        """
        This communication hook is for illustration and testing purposes only.
        It is used during FSDP ``FULL_SHARD`` or ``SHARD_GRAD_OP`` training.
        It adds some noise to the provided ``grad`` parameter, uses
        ``reduce_scatter`` for gradient communication, and stores the sharded
        gradient in ``output``.
        """
        grad.add_(state.noise)
        self.custom_reduce_scatter(output, grad, group=state.process_group)
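
# A sketch of how a custom communication hook like the ones above is attached in user
# code (this mirrors what the tests below do): the hook is registered on the *root* FSDP
# instance together with its state object, e.g.
#
#   state = DummyState(process_group=_get_default_group(), noise=1)
#   fsdp_model.register_comm_hook(state, DummyHook.dummy_hook_for_no_shard_fsdp)
#
# where ``fsdp_model`` is assumed to be a root FSDP-wrapped module.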

class TestCommunicationHooks(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_behavior(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests the behavior and correctness of FSDP's default communication hook.
        This test creates a simple linear net whose weight has shape ``N x 1``,
        where ``N`` is the number of workers, so in the sharded cases each worker
        gets one element of the weight parameter. It checks that, after backward,
        each worker holds the expected value in its chunk of the gradient (or, for
        ``NO_SHARD``, that the whole gradient on every worker equals the expected
        value).

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        out_dim = self.world_size
        net = torch.nn.Linear(1, out_dim, bias=False)
        inpt = torch.tensor([self.rank]).float().cuda(self.rank)

        net_default_hook = FSDP(
            net,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
        ).to(self.rank)

        # Check that by default, `_comm_hook` is None
        for entry in FSDP.fsdp_modules(net_default_hook):
            self.assertEqual(entry._comm_hook, None)

        for _ in range(4):
            # Clear gradients
            net_default_hook.zero_grad()
            loss = net_default_hook(inpt).sum()
            loss.backward()

            # Each worker's local gradient on the weight equals its rank, so the
            # default (averaging) hook should reduce every element to the mean
            # over all ranks.
            grad = net_default_hook.params[0].grad
            expected_grad = (
                sum(i for i in range(dist.get_world_size())) / dist.get_world_size()
            )
            # Verify the default hook produces the expected gradients
            self.assertEqual(
                grad[0].item(),
                expected_grad,
                msg=f"Expected hook grad of {expected_grad} but got {grad[0].item()}",
            )
    def _get_submodules(self, fsdp_net):
        return [
            submodule
            for submodule in FSDP.fsdp_modules(fsdp_net)
            if not submodule.check_is_root()
        ]

    def _init_model(self, core, sharding_strategy, mixed_precision=None):
        device = torch.device("cuda")
        return FSDP(
            core,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        ).to(device)
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_default_communication_hook_initialization(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests FSDP's communication hook interface behavior.

        Arguments:
            has_wrapping (bool): Configures wrapping of a module.
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        # Initialize a model
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )

        # Check that by default, `_comm_hook` is None
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, None)

        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )

        fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)
        # Check that we can't register a comm hook twice
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)

        # Check that the dummy hook was registered for the root and all submodules, if any
        for fsdp_module in FSDP.fsdp_modules(fsdp_model_with_hook):
            self.assertEqual(fsdp_module._comm_hook, dummy_hook)
            self.assertEqual(fsdp_module._comm_hook_state, dummy_state)
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_non_root(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests that an FSDP communication hook cannot be registered on a non-root
        submodule. Runs with the ``NO_SHARD``, ``FULL_SHARD``, and
        ``SHARD_GRAD_OP`` strategies.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        # Create a list of non-root submodules to test
        submodules = self._get_submodules(fsdp_model_with_hook)
        # Check that an assertion is raised when registering a comm hook on a non-root instance
        with self.assertRaisesRegex(
            AssertionError,
            "^register_comm_hook can only be called on a root instance.$",
        ):
            submodules[1].register_comm_hook(dummy_state, dummy_hook)
    @skip_if_lt_x_gpu(2)
    def test_registering_hook_hybrid_strategy(self):
        for sharding_strategy in (
            ShardingStrategy.HYBRID_SHARD,
            ShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            model = Net(False, None, None).cuda()
            fsdp_model = FSDP(
                model,
                auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
                sharding_strategy=sharding_strategy,
            )
            dummy_state = DummyState(process_group=None, noise=1234)
            dummy_hook = DummyHook.dummy_hook_for_sharded_fsdp
            with self.assertRaisesRegex(
                AssertionError,
                "Communication hook is not supported for hybrid strategies",
            ):
                fsdp_model.register_comm_hook(dummy_state, dummy_hook)
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_registering_hook_submodules(
        self, sharding_strategy: Optional[ShardingStrategy]
    ):
        """
        Tests hook registration when a hook has already been assigned to a
        non-root submodule: registering on the root must then fail. Runs with
        the ``NO_SHARD``, ``FULL_SHARD``, and ``SHARD_GRAD_OP`` strategies.

        Arguments:
            sharding_strategy (Optional[ShardingStrategy]): Configures the FSDP algorithm.
        """
        fsdp_model_with_hook = self._init_model(
            Net(has_wrapping=True, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        dummy_state = DummyState(process_group=None, noise=1234)
        dummy_hook = (
            DummyHook.dummy_hook_for_no_shard_fsdp
            if sharding_strategy != ShardingStrategy.NO_SHARD
            else DummyHook.dummy_hook_for_sharded_fsdp
        )
        submodules = self._get_submodules(fsdp_model_with_hook)

        # Simulate a registration of a hook on a submodule
        submodules[1]._comm_hook = dummy_hook
        # Check that an error is raised when some submodules already have a non-default hook assigned
        with self.assertRaisesRegex(
            AssertionError, "^A communication hook is already registered$"
        ):
            fsdp_model_with_hook.register_comm_hook(dummy_state, dummy_hook)
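    # Helper for the low-precision hook tests below: runs one training step with the
    # given compression hook and, separately, with an equivalent
    # MixedPrecision(reduce_dtype=dtype) configuration, then checks that the resulting
    # gradients match.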
    def _check_low_precision_hook(
        self, state, hook, sharding_strategy, dtype, has_wrapping
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        fsdp_with_hook = self._init_model(
            Net(has_wrapping=has_wrapping, sharding_strategy=sharding_strategy),
            sharding_strategy=sharding_strategy,
        )
        fsdp_with_hook.register_comm_hook(state, hook)

        mp_only_grad = MixedPrecision(reduce_dtype=dtype)
        fsdp_with_mp = self._init_model(
            Net(
                has_wrapping=has_wrapping,
                sharding_strategy=sharding_strategy,
                mixed_precision=mp_only_grad,
            ),
            sharding_strategy=sharding_strategy,
            mixed_precision=mp_only_grad,
        )

        optim_hook = torch.optim.SGD(fsdp_with_hook.parameters(), lr=0.1)
        optim_mp = torch.optim.SGD(fsdp_with_mp.parameters(), lr=0.1)

        in_data = torch.rand(16, 8).cuda()
        fsdp_with_hook.train()
        fsdp_with_mp.train()
        loss_hook = fsdp_with_hook(in_data).sum()
        loss_mp = fsdp_with_mp(in_data).sum()
        loss_hook.backward()
        # Make sure grads were cast to the parameter's precision
        self.assertEqual(fsdp_with_hook.params[0].grad.dtype, state.parameter_type)
        loss_mp.backward()
        optim_hook.step()
        optim_mp.step()

        dist.barrier()

        for hook_param, mp_param in zip(
            fsdp_with_hook.parameters(), fsdp_with_mp.parameters()
        ):
            self.assertEqual(hook_param.grad, mp_param.grad)
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_fp16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.fp16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.float16, has_wrapping
        )

    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    @parametrize("has_wrapping", [True, False])
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.NO_SHARD,
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
        ],
    )
    def test_bf16_hook(
        self, has_wrapping: bool, sharding_strategy: Optional[ShardingStrategy]
    ):
        state = default_hooks.LowPrecisionState(process_group=_get_default_group())
        hook = default_hooks.bf16_compress_hook

        self._check_low_precision_hook(
            state, hook, sharding_strategy, torch.bfloat16, has_wrapping
        )

instantiate_parametrized_tests(TestCommunicationHooks)

if __name__ == "__main__":
    run_tests()