# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
import os
import sys
import unittest
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import (
    _FSDP_SKIP_WRITEBACK_CHECK,
    _FSDP_USE_FULL_PREC_IN_EVAL,
)
from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPUseOrigParamsMultipleParamGroups(FSDPTest):
    """Tests multiple parameter groups."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_param_groups(self, model: nn.Module) -> List[Dict[str, Any]]:
        """
        Constructs separate parameter groups for weights, biases, and other
        parameters.
        """
        param_groups = [
            {"params": [], "weight_decay": 0.1, "lr": 1e-2},
            {"params": [], "weight_decay": 0.01, "lr": 1e-3},
            {"params": []},
        ]
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                param_groups[0]["params"].append(param)
            elif "bias" in param_name:
                param_groups[1]["params"].append(param)
            else:
                param_groups[2]["params"].append(param)
        return param_groups
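
    # A minimal sketch (editorial, not part of the test suite) of why
    # `use_orig_params=True` matters for this grouping: with
    # `use_orig_params=False`, FSDP only exposes flattened `FlatParameter`s,
    # so per-FQN parameter groups like the ones above cannot be formed. The
    # module and hyperparameters here are illustrative assumptions.
    #
    #   model = FSDP(nn.Linear(5, 5).cuda(), use_orig_params=True)
    #   weights = [p for n, p in model.named_parameters() if "weight" in n]
    #   biases = [p for n, p in model.named_parameters() if "bias" in n]
    #   optim = torch.optim.AdamW(
    #       [{"params": weights, "weight_decay": 0.1}, {"params": biases}]
    #   )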

    def _get_optim(
        self,
        model: nn.Module,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
    ) -> torch.optim.Optimizer:
        """
        Constructs an ``optim_class`` optimizer with three parameter groups,
        one for weights, one for biases, and one for everything else, each
        with different weight decay and learning rates.
        """
        param_groups = self._get_param_groups(model)
        return optim_class(param_groups, lr=5e-3, foreach=multi_tensor)

    def _get_ddp_transformer(self, find_unused_params: bool) -> DDP:
        """Returns a transformer with shared parameters wrapped with DDP."""
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(
            model,
            device_ids=[self.rank],
            find_unused_parameters=find_unused_params,
        )
        return ddp_model

    def _get_fsdp_transformer_and_optim(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer]:
        """
        Returns a transformer with shared parameters wrapped with FSDP and a
        corresponding optimizer.
        """
        # Each transformer layer has multiple linear layers, so this policy, in
        # combination with the parameter group construction, ensures different
        # hyperparameter settings within one `FlatParameter`
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": backward_prefetch,
            "cpu_offload": cpu_offload,
        }
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            cuda_init_mode,
            deterministic=True,
        )
        if init_optim_before_wrap:
            fsdp_optim = self._get_optim(model, optim_class, multi_tensor)
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
        else:
            fsdp_model = FSDP(model, self.process_group, **fsdp_kwargs)
            fsdp_optim = self._get_optim(fsdp_model, optim_class, multi_tensor)
        if (
            cuda_init_mode == CUDAInitMode.CUDA_AFTER
            and not fsdp_model.cpu_offload.offload_params
        ):
            fsdp_model = fsdp_model.cuda()
        return fsdp_model, fsdp_optim
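
    # A rough sketch (editorial; it reuses the same private handle APIs that
    # `_test_multiple_optimizers` below exercises) of what the wrapping above
    # produces: each `TransformerEncoderLayer`/`TransformerDecoderLayer`
    # becomes its own FSDP unit, so each layer's weights and biases land in
    # one `FlatParameter`:
    #
    #   for fsdp_module in FSDP.fsdp_modules(fsdp_model):
    #       handle = fsdp_module._handle
    #       if handle:
    #           print(handle.flat_param._fqns)  # mixes "weight" and "bias" FQNs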

    def _check_train_parity(
        self,
        ddp_model: DDP,
        ddp_optim: torch.optim.Optimizer,
        fsdp_model: FSDP,
        fsdp_optim: torch.optim.Optimizer,
        set_to_none: bool,
        num_iters: int = 10,
    ):
        """Checks training parity between DDP and FSDP."""
        device = torch.device("cuda")
        for i in range(num_iters):
            iter_losses = []
            for model, optim in ((ddp_model, ddp_optim), (fsdp_model, fsdp_optim)):
                module = model.module
                # Test two different `zero_grad()` timings
                if i % 2 == 0:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-forward
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                if i % 2 == 1:
                    optim.zero_grad(set_to_none=set_to_none)  # pre-backward
                module.run_backward(loss)
                # Perform the DDP optimizer step on CPU to match FSDP if needed
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(torch.device("cpu"))
                optim.step()
                if model is ddp_model and fsdp_model.cpu_offload.offload_params:
                    model.to(device)
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

    def _check_ddp_fsdp_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(), fsdp_model.named_parameters()
            ):
                # Allow for FSDP prefixes
                self.assertEqual(n1, clean_tensor_name(n2))
                torch.testing.assert_close(p1, p2)

    def _get_sharding_strategy_from_str(
        self, sharding_strategy_str: str
    ) -> ShardingStrategy:
        if sharding_strategy_str == "no_shard":
            sharding_strategy = ShardingStrategy.NO_SHARD
        elif sharding_strategy_str == "shard_grad_op":
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        elif sharding_strategy_str == "full_shard":
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise ValueError(f"Invalid string: {sharding_strategy_str}")
        return sharding_strategy

    @skip_if_lt_x_gpu(2)
    def test_fsdp_compile(self):
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "skip_fsdp_guards": [True, False],
            },
            self._test_fsdp_compile,
        )

    def _test_fsdp_compile(
        self, sharding_strategy: ShardingStrategy, skip_fsdp_guards: bool
    ):
        torch._dynamo.config.skip_fsdp_guards = skip_fsdp_guards
        fsdp_kwargs = {
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
            "use_orig_params": True,
            "sharding_strategy": sharding_strategy,
            "backward_prefetch": BackwardPrefetch.BACKWARD_PRE,
            "cpu_offload": CPUOffload(False),
        }
        base_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        ref_model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        model = FSDP(copy.deepcopy(base_model), self.process_group, **fsdp_kwargs)
        model = torch.compile(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        for i in range(10):
            losses = []
            inp = ref_model.get_input(torch.device("cuda"))
            for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                _optim.zero_grad()
                loss = _model(*inp).sum()
                losses.append(loss)
                loss.backward()
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings.
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        self.run_subtests(
            {
                "cuda_init_mode": [
                    CUDAInitMode.CUDA_BEFORE,
                    CUDAInitMode.CUDA_AFTER,
                ],
                "init_optim_before_wrap": [False, True],
                "optim_class": [torch.optim.AdamW],
                "multi_tensor": [False, True],
                "set_to_none": [False, True],
                "backward_prefetch": [
                    None,
                    BackwardPrefetch.BACKWARD_PRE,
                    BackwardPrefetch.BACKWARD_POST,
                ],
                "skip_writeback_check": [False, True],
            },
            self._test_diff_hyperparams,
            cpu_offload=CPUOffload(offload_params=False),
            sharding_strategy=sharding_strategy,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "sharding_strategy_str",
        ["no_shard", "shard_grad_op", "full_shard"],
    )
    def test_diff_hyperparams_cpu_offload(self, sharding_strategy_str: str):
        """
        Tests FSDP parity with DDP when using multiple parameter groups with
        different hyperparameter settings and CPU offloading enabled. This is
        separate from :meth:`test_diff_hyperparams` because CPU offloading has
        issues with some specific subtest orderings (e.g. running with
        ``offload_params=False`` followed by ``True`` but not vice versa).
        """
        sharding_strategy = self._get_sharding_strategy_from_str(sharding_strategy_str)
        for skip_writeback_check in (False, True):
            self._test_diff_hyperparams(
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                init_optim_before_wrap=False,
                optim_class=torch.optim.Adam,
                multi_tensor=False,
                set_to_none=False,
                backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
                cpu_offload=CPUOffload(offload_params=True),
                sharding_strategy=sharding_strategy,
                skip_writeback_check=skip_writeback_check,
            )

    def _test_diff_hyperparams(
        self,
        cuda_init_mode: CUDAInitMode,
        init_optim_before_wrap: bool,
        optim_class: Type[torch.optim.Optimizer],
        multi_tensor: bool,
        set_to_none: bool,
        backward_prefetch: Optional[BackwardPrefetch],
        cpu_offload: CPUOffload,
        sharding_strategy: ShardingStrategy,
        skip_writeback_check: bool,
    ):
        """
        Args:
            init_optim_before_wrap (bool): If ``True``, initializes the
                FSDP optimizer before wrapping the model with FSDP; otherwise,
                initializes the FSDP optimizer after wrapping the model with
                FSDP. We permit both forms of initialization to give users
                flexibility.
        """
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            return  # not supported
        if skip_writeback_check:
            os.environ[_FSDP_SKIP_WRITEBACK_CHECK] = "1"
        ddp_model = self._get_ddp_transformer(find_unused_params=False)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=cuda_init_mode,
            init_optim_before_wrap=init_optim_before_wrap,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=backward_prefetch,
            cpu_offload=cpu_offload,
        )
        self._check_train_parity(
            ddp_model, ddp_optim, fsdp_model, fsdp_optim, set_to_none
        )

    @skip_if_lt_x_gpu(2)
    def test_diff_trainability(self):
        """
        Tests FSDP parity with DDP when using multiple parameter groups and
        freezing the parameters in one parameter group.
        """
        self.run_subtests(
            {
                "multi_tensor": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_diff_trainability,
        )

    def _test_diff_trainability(
        self,
        multi_tensor: bool,
        sharding_strategy: ShardingStrategy,
    ):
        optim_class = torch.optim.Adam
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_optim = self._get_optim(ddp_model, optim_class, multi_tensor)
        fsdp_model, fsdp_optim = self._get_fsdp_transformer_and_optim(
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=optim_class,
            multi_tensor=multi_tensor,
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        # Freeze all biases (which happen to be in the same parameter group)
        for param_name, param in ddp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        for param_name, param in fsdp_model.named_parameters():
            if "bias" in param_name:
                param.requires_grad_(False)
        self._check_train_parity(ddp_model, ddp_optim, fsdp_model, fsdp_optim, False)

    @skip_if_lt_x_gpu(2)
    def test_multiple_optimizers(self):
        """
        Tests using two optimizers where only one sets gradients to ``None``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ]
            },
            self._test_multiple_optimizers,
        )

    def _test_multiple_optimizers(self, sharding_strategy: ShardingStrategy):
        ddp_model = self._get_ddp_transformer(find_unused_params=True)
        ddp_param_groups = self._get_param_groups(ddp_model)
        assert len(ddp_param_groups) == 3, f"{len(ddp_param_groups)}"
        (
            fsdp_model,
            _,
        ) = self._get_fsdp_transformer_and_optim(  # ignore returned optimizer
            cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
            init_optim_before_wrap=False,
            optim_class=torch.optim.Adam,  # ignored
            multi_tensor=False,  # ignored
            sharding_strategy=sharding_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            cpu_offload=None,
        )
        fsdp_param_groups = self._get_param_groups(fsdp_model)
        assert len(fsdp_param_groups) == 3, f"{len(fsdp_param_groups)}"
        ddp_optims = []
        fsdp_optims = []
        # For the transformer model, every parameter is either a weight or a
        # bias, so we only use the first two parameter groups. Moreover, we use
        # Adam and AdamW in particular since they both use bias correction
        # dependent on the step, which is incremented even if a parameter has a
        # zero gradient but not if the gradient is `None`. This is to test that
        # we are differentiating between a zero and `None` gradient correctly.
        optim_ctors = [
            functools.partial(torch.optim.Adam, lr=5e-3),
            functools.partial(torch.optim.AdamW, lr=1e-2),
        ]
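
        # Illustrative sketch (editorial, not executed here) of the zero vs.
        # `None` gradient distinction described above: Adam increments its
        # per-parameter `step` for a zero gradient but skips parameters whose
        # `.grad` is `None`, which changes the bias correction on later steps.
        #
        #   p = nn.Parameter(torch.ones(1))
        #   opt = torch.optim.Adam([p])
        #   p.grad = torch.zeros_like(p)
        #   opt.step()  # `state["step"]` becomes 1 even though the grad is 0
        #   opt.zero_grad(set_to_none=True)
        #   opt.step()  # no-op: `p.grad is None`, so `step` stays 1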

        for optim_ctor, ddp_param_group, fsdp_param_group in zip(
            optim_ctors,
            ddp_param_groups[:2],
            fsdp_param_groups[:2],
        ):
            ddp_optims.append(optim_ctor(ddp_param_group["params"]))
            fsdp_optims.append(optim_ctor(fsdp_param_group["params"]))
        device = torch.device("cuda")

        # Check that there exists a `FlatParameter` that has both a weight and
        # a bias in this rank's shard
        has_both = False
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            handle = fsdp_module._handle
            if not handle:
                continue
            flat_param = handle.flat_param
            assert flat_param._params is not None
            has_weight = False
            has_bias = False
            for param, fqn in zip(flat_param._params, flat_param._fqns):
                if "weight" in fqn and param.numel() > 0:
                    has_weight = True
                elif "bias" in fqn and param.numel() > 0:
                    has_bias = True
            has_both |= has_weight and has_bias
        assert has_both, (
            f"Rank {self.rank} does not have a `FlatParameter` with both a "
            "weight and a bias in its shard, meaning that this test is vacuous"
        )

        # Run one iteration to generate gradients
        def run_iter():
            iter_losses = []
            for model, optims in ((ddp_model, ddp_optims), (fsdp_model, fsdp_optims)):
                module = model.module
                inp = module.get_input(device)
                output = model(*inp)
                loss = module.get_loss(inp, output).to(device)
                iter_losses.append(loss)
                module.run_backward(loss)
                for optim in optims:
                    optim.step()
            torch.testing.assert_close(iter_losses[0], iter_losses[1])
            iter_losses.clear()
            self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        run_iter()

        # Only set the weights' gradients to None
        ddp_optims[0].zero_grad(set_to_none=True)
        fsdp_optims[0].zero_grad(set_to_none=True)
        inp = ddp_model.module.get_input(device)
        ddp_output = ddp_model(*inp)
        fsdp_output = fsdp_model(*inp)

        # Check that FSDP correctly exposes gradients even after forward
        # (namely, `None` for weights and non-`None` for biases)
        if sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES:
            # Skip the check since we do not expose the gradients after forward
            # for these strategies
            return
        for (ddp_n, ddp_p), (fsdp_n, fsdp_p) in zip(
            ddp_model.module.named_parameters(),
            fsdp_model.named_parameters(),
        ):
            self.assertEqual(ddp_n, clean_tensor_name(fsdp_n))
            if fsdp_p.numel() == 0:
                # Not in this rank's shard
                self.assertTrue(fsdp_p.grad is None)
                continue
            if ddp_p.grad is None:
                self.assertTrue(fsdp_p.grad is None)
            else:
                self.assertEqual(ddp_p.flatten(), fsdp_p.flatten())
                self.assertEqual(ddp_p.grad.flatten(), fsdp_p.grad.flatten())
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Finish the iteration (backward pass and optimizer step)
        ddp_loss = ddp_model.module.get_loss(inp, ddp_output).to(device)
        fsdp_loss = fsdp_model.module.get_loss(inp, fsdp_output).to(device)
        ddp_model.module.run_backward(ddp_loss)
        fsdp_model.module.run_backward(fsdp_loss)
        for optim in itertools.chain(ddp_optims, fsdp_optims):
            optim.step()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)

        # Run one more iteration to confirm bias corrections are correct
        run_iter()
        self._check_ddp_fsdp_param_parity(ddp_model, fsdp_model)


class TestFSDPUseOrigParamsUnshardReshard(FSDPTest):
    """Tests the unshard/reshard flow."""

    @property
    def world_size(self) -> int:
        return 2

    def _get_fsdp_models_and_optims(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ) -> Tuple[FSDP, torch.optim.Optimizer, FSDP, torch.optim.Optimizer]:
        """
        Returns a pair of (FSDP model, optimizer) for ``use_orig_params=False``
        and ``True``, respectively.
        """
        LR = 1e-2
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "cpu_offload": cpu_offload,
            "use_orig_params": False,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
        fsdp_kwargs["use_orig_params"] = True
        fsdp_model_orig_params = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs=fsdp_kwargs,
            deterministic=True,
        )
        optim_orig_params = torch.optim.Adam(
            fsdp_model_orig_params.parameters(), foreach=False, lr=LR
        )
        return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params

    def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:
        """Checks that two FSDP instances have the same model parameters."""
        with FSDP.summon_full_params(fsdp1), FSDP.summon_full_params(fsdp2):
            for (n1, p1), (n2, p2) in zip(
                fsdp1.named_parameters(),
                fsdp2.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    def _get_fsdp_parity_subtest_config(self):
        return {
            "sharding_strategy": [
                ShardingStrategy.NO_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
                ShardingStrategy.FULL_SHARD,
            ],
        }

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_multiple_forward(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running multiple forward passes before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_multiple_forward,
            cpu_offload=cpu_offload,
        )

    @skip_if_lt_x_gpu(2)
    def _test_multiple_forward(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            inp1 = fsdp_model.get_input(device)
            _inp2 = fsdp_model.get_input(device)
            inp2 = tuple(
                t + torch.ones_like(t) for t in _inp2
            )  # make different from `inp1`
            # For these loss lists: elem 0 is baseline; elem 1 is test
            losses1 = []
            losses2 = []
            losses = []
            for _model, _optim in (fsdp_model, optim), (
                fsdp_model_orig_params,
                optim_orig_params,
            ):
                _optim.zero_grad()
                loss1 = _model(*inp1)
                losses1.append(loss1)
                loss2 = _model(*inp2)
                losses2.append(loss2)
                loss = (loss1 + loss2).sum()
                losses.append(loss)
                _model.run_backward(loss)
                _optim.step()
            self.assertEqual(losses1[0], losses1[1])
            self.assertEqual(losses2[0], losses2[1])
            self.assertEqual(losses[0], losses[1])
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("offload_params", [False, True])
    def test_summon_between_two_forwards(self, offload_params: bool):
        """
        Tests that ``use_orig_params=True`` has parity with ``False`` when
        running a forward pass, :meth:`summon_full_params()`, and another
        forward pass before a backward pass.
        """
        cpu_offload = CPUOffload(offload_params=offload_params)
        self.run_subtests(
            self._get_fsdp_parity_subtest_config(),
            self._test_summon_between_two_forwards,
            cpu_offload=cpu_offload,
        )

    def _test_summon_between_two_forwards(
        self,
        sharding_strategy: ShardingStrategy,
        cpu_offload: CPUOffload,
    ):
        (
            fsdp_model,
            optim,
            fsdp_model_orig_params,
            optim_orig_params,
        ) = self._get_fsdp_models_and_optims(sharding_strategy, cpu_offload)
        device = torch.device("cuda")
        for _ in range(3):
            optim.zero_grad()
            optim_orig_params.zero_grad()

            inp1 = fsdp_model.get_input(device)
            loss1 = fsdp_model(*inp1)
            loss_orig_params1 = fsdp_model_orig_params(*inp1)
            self.assertEqual(loss1, loss_orig_params1)

            # Calls into `summon_full_params()`
            self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)

            inp2 = fsdp_model.get_input(device)
            loss2 = fsdp_model(*inp2)
            loss_orig_params2 = fsdp_model_orig_params(*inp2)
            self.assertEqual(loss2, loss_orig_params2)

            loss = (loss1 + loss2).sum()
            loss_orig_params = (loss_orig_params1 + loss_orig_params2).sum()
            fsdp_model.run_backward(loss)
            fsdp_model_orig_params.run_backward(loss_orig_params)
            optim.step()
            optim_orig_params.step()
        self._check_fsdp_parameter_parity(fsdp_model, fsdp_model_orig_params)


class TestFSDPUseOrigParamsParamAccess(FSDPTest):
    """Tests original parameter access."""

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code the expected FSDP
        # sharding to check sharded parameter parity
        return 2

    @skip_if_lt_x_gpu(2)
    def test_access_params_after_forward(self):
        """
        Tests accessing the original parameters after the forward but before
        the backward. Notably, this is not supported when
        ``use_orig_params=False``. However, for ``True``, FSDP exposes the
        (flattened) sharded original parameters, making it possible.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.NO_SHARD,
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                ],
            },
            self._test_access_params_after_forward,
        )

    def _test_access_params_after_forward(
        self,
        sharding_strategy: ShardingStrategy,
    ):
        # NOTE: This test needs to be changed if the FSDP sharding algorithm
        # changes. It is still valuable until such a change to sanity check the
        # `use_orig_params=True` implementation.
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(42)
                # 5 * 5 = 25 numel -> pad to 26 -> 13 on each rank
                self.lin1 = nn.Linear(5, 5, bias=False)
                # 5 * 7 + (1) + 7 = 43 numel -> pad to 44 -> 22 on each rank,
                # where the (1) is from intra-`FlatParameter` alignment padding
                # 22 of weight on rank 0; 13 of weight, 1 alignment padding,
                # and 7 of bias on rank 1
                self.lin2 = nn.Linear(5, 7)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                z = self.lin1(x)
                z = nn.functional.relu(z)
                z = self.lin2(z)
                return z

            def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
                return (torch.randn((2, 5)).to(device),)

            def get_loss(self, inp, out):
                return out.sum()
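
        # Worked arithmetic (editorial) for the shard layout described in the
        # model definition above, assuming world size 2 and FSDP's current
        # flatten-then-chunk sharding:
        #   lin1 `FlatParameter`: 25 numel -> pad to 26 -> rank 0 holds
        #     elements [0, 13), rank 1 holds [13, 26)
        #   lin2 `FlatParameter`: 35 (weight) + 1 (alignment) + 7 (bias) = 43
        #     -> pad to 44 -> rank 0 holds [0, 22) (all weight); rank 1 holds
        #     [22, 44): 13 of weight, 1 alignment padding, 7 of bias, 1 pad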

        def check_parameter_parity(
            ddp_model: DDP, fsdp_model: FSDP, between_fwd_and_bwd: bool
        ):
            assert self.rank in (
                0,
                1,
            ), f"Expects world size of 2 but got {self.world_size}"
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, clean_tensor_name(n2))
                if sharding_strategy == ShardingStrategy.NO_SHARD:
                    # For `NO_SHARD`, do nothing since the original parameters
                    # are unflattened
                    pass
                elif (
                    between_fwd_and_bwd
                    and sharding_strategy in NO_RESHARD_AFTER_FORWARD_STRATEGIES
                ):
                    # For no reshard after forward strategies, do nothing since
                    # FSDP did not use sharded views after forward
                    pass
                # Otherwise, case on the parameter (see the model definition)
                elif n1 == "lin1.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:13]
                    elif self.rank == 1:
                        p1 = p1.flatten()[13:]
                elif n1 == "lin2.weight":
                    if self.rank == 0:
                        p1 = p1.flatten()[:22]
                    elif self.rank == 1:
                        p1 = p1.flatten()[22:]
                elif n1 == "lin2.bias":
                    if self.rank == 0:
                        p1 = torch.empty(0, device=p1.device)
                    elif self.rank == 1:
                        p1 = p1.flatten()
                torch.testing.assert_close(p1, p2)

        ddp_model = DDP(Model().cuda(), device_ids=[self.rank])
        fsdp_model = FSDP(
            Model().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
            use_orig_params=True,
        )
        LR = 1e-2
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
        device = torch.device("cuda")

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)

        ddp_loss = ddp_model.module.get_loss(inp, ddp_out)
        fsdp_loss = fsdp_model.get_loss(inp, fsdp_out)
        ddp_loss.backward()
        fsdp_loss.backward()
        ddp_optim.step()
        fsdp_optim.step()
        check_parameter_parity(ddp_model, fsdp_model, False)

        inp = fsdp_model.get_input(device)
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        check_parameter_parity(ddp_model, fsdp_model, True)


class TestFSDPUseOrigParamsWriteback(FSDPTest):
    """Tests parameter and gradient writeback."""

    class Model(nn.Module):
        def __init__(self, device: torch.device):
            super().__init__()
            torch.manual_seed(42)
            self.lin1 = nn.Linear(5, 5, bias=True, device=device)
            self.lin2 = nn.Linear(5, 7, bias=True, device=device)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            z = self.lin1(x)
            z = nn.functional.relu(z)
            z = self.lin2(z)
            return z

        def get_input(self, device: torch.device) -> Tuple[torch.Tensor, ...]:
            return (torch.randn((2, 5)).to(device),)

        def get_loss(self, inp, out):
            return out.sum()

    @property
    def world_size(self):
        # Force a world size of 2 since the tests hard-code the expected FSDP
        # sharding
        return 2

    def _check_param_parity(self, ddp_model: DDP, fsdp_model: FSDP):
        with FSDP.summon_full_params(fsdp_model):
            for (n1, p1), (n2, p2) in zip(
                ddp_model.module.named_parameters(),
                fsdp_model.named_parameters(),
            ):
                self.assertEqual(n1, n2)
                torch.testing.assert_close(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_param_writeback(self):
        """Tests that changes to the original parameters are written back."""
        self.run_subtests(
            {
                "change_first_weight": [True, False],  # first vs. second `weight`
                "change_data": [True, False],  # change `.data` vs. variable itself
            },
            self._test_param_writeback,
        )

    def _test_param_writeback(self, change_first_weight: bool, change_data: bool):
        def transform_param(param: nn.Parameter) -> nn.Parameter:
            return nn.Parameter(torch.ones_like(param) * 2)

        # Check that the writeback propagates
        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight:
            if change_data:
                ddp.lin1.weight.data = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight.data = transform_param(fsdp.lin1.weight)
            else:
                ddp.lin1.weight = transform_param(ddp.lin1.weight)
                fsdp.lin1.weight = transform_param(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.data = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight.data = transform_param(fsdp.lin2.weight)
            else:
                ddp.lin2.weight = transform_param(ddp.lin2.weight)
                fsdp.lin2.weight = transform_param(fsdp.lin2.weight)
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback
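
    # Editorial sketch of the mechanism under test (hedged; inferred from the
    # behavior exercised above rather than from FSDP internals): with
    # `use_orig_params=True`, the original parameters are views into the
    # sharded `FlatParameter`. Reassigning the variable or its `.data` breaks
    # that view, and FSDP writes the new tensor back into the flat parameter
    # the next time it prepares the parameters, e.g.:
    #
    #   fsdp.lin1.weight = nn.Parameter(torch.zeros_like(fsdp.lin1.weight))
    #   with FSDP.summon_full_params(fsdp_model):
    #       pass  # the writeback happens here, as in `_check_param_parity`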

    @skip_if_lt_x_gpu(2)
    def test_grad_writeback(self):
        """
        Tests that changes to the original parameters' gradients are written
        back.
        """
        self.run_subtests(
            {
                "change_first_weight_grad": [False, True],
                "change_data": [False, True],  # change `.data` vs. variable itself
                "set_to_none": [False, True],
            },
            self._test_grad_writeback,
        )

    def _test_grad_writeback(
        self,
        change_first_weight_grad: bool,
        change_data: bool,
        set_to_none: bool,
    ):
        if change_data and set_to_none:
            return  # not well-defined

        def transform_grad(param: nn.Parameter) -> Optional[torch.Tensor]:
            return None if set_to_none else torch.ones_like(param) * 2

        ddp_model = DDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            device_ids=[self.rank],
        )
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        LR = 1e-2
        # TODO: If we add `summon_full_params(with_grads=True)`, then replace
        # the following. For now, we use the optimizer step as a surrogate for
        # checking that gradients were written back.
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)

        # Generate an initial gradient
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()

        # Change the gradient through the original parameters
        ddp = ddp_model.module  # for brevity
        fsdp = fsdp_model.module
        if change_first_weight_grad:
            if change_data:
                ddp.lin1.weight.grad.data = transform_grad(ddp.lin1.weight)
                if fsdp.lin1.weight.grad is not None:
                    fsdp.lin1.weight.grad.data = transform_grad(fsdp.lin1.weight)
            else:
                ddp.lin1.weight.grad = transform_grad(ddp.lin1.weight)
                fsdp.lin1.weight.grad = transform_grad(fsdp.lin1.weight)
        else:
            if change_data:
                ddp.lin2.weight.grad.data = transform_grad(ddp.lin2.weight)
                if fsdp.lin2.weight.grad is not None:
                    fsdp.lin2.weight.grad.data = transform_grad(fsdp.lin2.weight)
            else:
                ddp.lin2.weight.grad = transform_grad(ddp.lin2.weight)
                fsdp.lin2.weight.grad = transform_grad(fsdp.lin2.weight)
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

        # Intentionally do not zero the gradient to check writeback
        inp = fsdp_model.get_input(torch.device("cuda"))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        ddp_optim.step()
        fsdp_optim.step()
        self._check_param_parity(ddp_model, fsdp_model)  # triggers a writeback

    @skip_if_lt_x_gpu(2)
    def test_writeback_shape_mismatch(self):
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")),
            use_orig_params=True,
        )
        # Check that writing back with mismatched shape errors
        fsdp = fsdp_model.module  # for brevity
        assert self.rank in (0, 1), f"Expects world size of 2 but got {self.world_size}"
        with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
            # Change the gradient to a new one with 1 added to each dimension
            # to force a shape mismatch when writing back
            if self.rank == 0:
                # Change `lin1.weight.grad` since it exists on rank 0
                lin1_weight_shape = list(fsdp.lin1.weight.shape)
                for dim_index in range(len(lin1_weight_shape)):
                    lin1_weight_shape[dim_index] += 1
                fsdp.lin1.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                    )
                )
                fsdp.lin1.weight.grad = torch.randn(
                    torch.Size(lin1_weight_shape), device=fsdp.lin1.weight.device
                )
            elif self.rank == 1:
                # Change `lin2.weight.grad` since it exists (partially) on rank 1
                lin2_weight_shape = list(fsdp.lin2.weight.shape)
                for dim_index in range(len(lin2_weight_shape)):
                    lin2_weight_shape[dim_index] += 1
                fsdp.lin2.weight = nn.Parameter(
                    torch.randn(
                        torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                    )
                )
                fsdp.lin2.weight.grad = torch.randn(
                    torch.Size(lin2_weight_shape), device=fsdp.lin2.weight.device
                )
            with FSDP.summon_full_params(fsdp_model):  # triggers a writeback
                ...

    @skip_if_lt_x_gpu(2)
    def test_writeback_between_fwd_and_bwd_for_no_reshard_raises(self):
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "use_orig_params": True,
        }
        fsdp_wrapper = functools.partial(FSDP, **fsdp_kwargs)

        # Test changing the parameter storage to no longer be a view into the
        # flat parameter
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1.weight.data = fsdp_model.lin1.weight.clone()
        assert_msg = (
            "FSDP does not support changing the parameters between forward and backward"
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

        # Test changing the parameter variable itself
        fsdp_model = fsdp_wrapper(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda"))
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        loss = fsdp_model(*inp).sum()
        fsdp_model.lin1._fsdp_wrapped_module.weight = nn.Parameter(
            fsdp_model.lin1.weight.clone()
        )
        with self.assertRaisesRegex(AssertionError, assert_msg):
            loss.backward()

    @skip_if_lt_x_gpu(2)
    def test_no_reshard_and_mixed_precision(self):
        """
        Tests that writeback does not falsely get triggered for a few
        configurations (exercising the sharded view skipping logic):
        - Train forward -> full-precision unshard -> train forward
        - Train forward -> eval forward
        - Train forward/backward -> eval forward -> model checkpoint
        """
        self.run_subtests(
            {"use_full_prec_in_eval": [False, True]},
            self._test_no_reshard_and_mixed_precision,
        )

    def _test_no_reshard_and_mixed_precision(self, use_full_prec_in_eval: bool):
        if use_full_prec_in_eval:
            os.environ[_FSDP_USE_FULL_PREC_IN_EVAL] = "1"
        fsdp_kwargs = {
            "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
            "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
            "mixed_precision": MixedPrecision(param_dtype=torch.float16),
            "use_orig_params": True,
        }

        # Train forward -> full-precision unshard -> train forward
        fsdp_model = FSDP(
            TestFSDPUseOrigParamsWriteback.Model(torch.device("cuda")), **fsdp_kwargs
        )
        inp = fsdp_model.get_input(torch.device("cuda"))
        fsdp_model(*inp)
        with FSDP.summon_full_params(fsdp_model):
            ...
        fsdp_model(*inp).sum()

        # Train forward -> eval forward
        fsdp_model.train()
        fsdp_model(*inp)
        fsdp_model.eval()
        fsdp_model(*inp)

        # Train forward/backward -> eval forward -> model checkpoint
        fsdp_model.train()
        fsdp_model(*inp).sum().backward()
        fsdp_model.eval()
        fsdp_model(*inp)
        with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
            sd = fsdp_model.state_dict()
            fsdp_model.load_state_dict(sd)
        fsdp_model(*inp).sum().backward()


class TestFSDPUseOrigParamsFQNs(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_named_parameters_in_forward(self):
        """
        Tests that calling ``named_parameters()`` during forward returns FQNs
        and ``Tensor`` s corresponding to the original parameters.
        """
        param_shapes = [None, None]
        assert_equal_fn = self.assertEqual

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(5, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                nonlocal param_shapes
                # Allow for FSDP prefixes
                param_names = [
                    clean_tensor_name(tup[0]) for tup in self.named_parameters()
                ]
                params = [tup[1] for tup in self.named_parameters()]
                assert (
                    param_shapes[0] is not None and param_shapes[1] is not None
                ), "`param_shapes` should be set"
                assert_equal_fn(
                    param_names,
                    [
                        "lin.weight",
                        "lin.bias",
                    ],
                )
                assert_equal_fn(params[0].shape, param_shapes[0])
                assert_equal_fn(params[1].shape, param_shapes[1])
                return self.lin(x)

        model = Model().cuda()
        # Save the *unsharded* original parameter shapes and check the shapes
        # match in the forward pass
        param_shapes[0] = model.lin.weight.shape
        param_shapes[1] = model.lin.bias.shape
        fsdp_model = FSDP(model, use_orig_params=True)
        inp = torch.randn((2, 5), device=torch.device("cuda"))
        fsdp_model(inp)


class TestFSDPUseOrigParamsNoSync(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_no_sync_correctness(self):
        """
        Tests a basic ``no_sync()`` setup by comparing ``use_orig_params=True``
        against ``use_orig_params=False``.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
            },
            self._test_no_sync_correctness,
        )

    def _test_no_sync_correctness(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(7, 1, bias=False, device="cuda")
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
        }
        model_use_flat_params = FSDP(
            copy.deepcopy(model), use_orig_params=False, **fsdp_kwargs
        )
        model_use_orig_params = FSDP(model, use_orig_params=True, **fsdp_kwargs)
        optim_use_flat_params = torch.optim.AdamW(
            model_use_flat_params.parameters(), foreach=True
        )
        optim_use_orig_params = torch.optim.AdamW(
            model_use_orig_params.parameters(), foreach=True
        )

        def _check_param_grad_parity(
            _baseline_model: nn.Module,
            _test_model: nn.Module,
        ):
            """
            This assumes that the model is ``nn.Linear(7, 1, bias=False)``
            (i.e. with a single weight parameter of shape ``(1, 7)``) to be
            able to directly compare the baseline and test models. On rank 1,
            the baseline includes 1 element of padding.
            """
            self.assertEqual(len(list(_baseline_model.parameters())), 1)
            self.assertEqual(len(list(_test_model.parameters())), 1)
            for flat_param, orig_param in zip(
                _baseline_model.parameters(), _test_model.parameters()
            ):
                # Baseline is permitted to have padding
                self.assertGreaterEqual(flat_param.numel(), orig_param.numel())
                unpadded_param_numel = orig_param.numel()
                # For `NO_SHARD`, `use_orig_params=True` presents unflattened
                # parameters, while `False` presents flattened ones
                torch.testing.assert_close(
                    flat_param[:unpadded_param_numel], orig_param.flatten()
                )
                # Gradient numel is different if right after `no_sync()` since
                # the gradient is unsharded, while the parameter is sharded
                unpadded_grad_numel = orig_param.grad.numel()
                # For `use_orig_params=False`, the unsharded gradient is
                # flattened, while for `True`, it is unflattened
                torch.testing.assert_close(
                    flat_param.grad[:unpadded_grad_numel].reshape(
                        orig_param.grad.shape
                    ),
                    orig_param.grad,
                )

        inp = torch.randn((2, 7), device="cuda")
        grad = torch.randn((2, 1), device="cuda")

        # Compute some reference gradients using one forward/backward
        out_use_flat_params = model_use_flat_params(inp)
        out_use_orig_params = model_use_orig_params(inp)
        torch.testing.assert_close(out_use_flat_params, out_use_orig_params)
        out_use_flat_params.backward(grad)
        out_use_orig_params.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)
        ref_grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        ref_grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]

        # Run a forward/backward in `no_sync()`
        optim_use_flat_params.zero_grad(set_to_none=True)
        optim_use_orig_params.zero_grad(set_to_none=True)
        for model in (model_use_flat_params, model_use_orig_params):
            with model.no_sync():
                out = model(inp)
                out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Run a forward/backward outside `no_sync()`
        for model in (model_use_flat_params, model_use_orig_params):
            out = model(inp)
            out.backward(grad)
        _check_param_grad_parity(model_use_flat_params, model_use_orig_params)

        # Check that, since we accumulated gradients across 2 iterations, the
        # new gradients are 2x the reference gradients
        grads_use_flat_params = [
            param.grad.detach().clone() for param in model_use_flat_params.parameters()
        ]
        grads_use_orig_params = [
            param.grad.detach().clone()
            for param in model_use_orig_params.parameters()
            if param.grad is not None
        ]
        for grad, ref_grad in zip(grads_use_flat_params, ref_grads_use_flat_params):
            torch.testing.assert_close(grad, 2 * ref_grad)
        for grad, ref_grad in zip(grads_use_orig_params, ref_grads_use_orig_params):
            torch.testing.assert_close(grad, 2 * ref_grad)
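
    # A minimal usage sketch (editorial; the variable names are illustrative)
    # of the gradient accumulation pattern verified above: gradients from
    # `no_sync()` iterations accumulate locally without communication, and the
    # first backward outside the context reduces the accumulated gradients.
    #
    #   for microbatch in microbatches[:-1]:
    #       with fsdp_model.no_sync():
    #           fsdp_model(microbatch).sum().backward()  # no gradient comm
    #   fsdp_model(microbatches[-1]).sum().backward()  # communicates
    #   optim.step()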

    @skip_if_lt_x_gpu(2)
    def test_no_sync_mixed_precision(self):
        """
        Tests that dtypes are as expected when using ``no_sync()`` with
        ``use_orig_params=True`` and parameter mixed precision.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_no_sync_mixed_precision,
        )

    def _test_no_sync_mixed_precision(self, sharding_strategy: ShardingStrategy):
        model = nn.Linear(3, 3, device="cuda")
        mixed_precision = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float32,
        )
        fsdp_kwargs = {
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
            "use_orig_params": True,
        }
        fsdp_model = FSDP(model, **fsdp_kwargs)
        inp = torch.randn((2, 3), device="cuda")
        with fsdp_model.no_sync():
            # For each of these `no_sync()` backward passes, check that the
            # gradients are in the low precision parameter dtype (FP16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
            fsdp_model(inp).sum().backward()
            for param in fsdp_model.parameters():
                if param.grad is not None:
                    self.assertEqual(param.grad.dtype, torch.float16)
        # For the backward pass outside `no_sync()`, check that the gradients
        # are cast to the full precision in preparation for the optimizer step
        fsdp_model(inp).sum().backward()
        for param in fsdp_model.parameters():
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float32)


class TestFSDPUseOrigParamsInit(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_non_uniform_requires_grad(self):
        model = nn.Sequential(
            nn.Linear(3, 3, device="cuda"),
            nn.Linear(3, 3, device="cuda"),
        )
        # Freeze biases only and flatten both weights and biases into the same
        # `FlatParameter` to exercise non-uniform `requires_grad`
        model[0].bias.requires_grad = False
        model[1].bias.requires_grad = False
        fsdp_model = FSDP(model, use_orig_params=True)
        self.assertTrue(fsdp_model[0].weight.requires_grad)
        self.assertFalse(fsdp_model[0].bias.requires_grad)
        self.assertTrue(fsdp_model[1].weight.requires_grad)
        self.assertFalse(fsdp_model[1].bias.requires_grad)


# Define this to be large enough to trigger stack corruption if size-0
# tensors are mishandled in the multi-tensor-apply kernels
NUM_SIZE0_TENSORS = 1000


class TestMultiTensorApply(TestCase):
    def test_multi_tensor_apply_size0_tensors_cpu(self):
        size0_tensors = [torch.empty(0, device="cpu") for _ in range(NUM_SIZE0_TENSORS)]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_multi_tensor_apply_size0_tensors_cuda(self):
        size0_tensors = [
            torch.empty(0, device="cuda") for _ in range(NUM_SIZE0_TENSORS)
        ]
        # Check that this does not segfault
        torch._foreach_mul_(size0_tensors, 0.1)


instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups)
instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard)
instantiate_parametrized_tests(TestFSDPUseOrigParamsParamAccess)
instantiate_parametrized_tests(TestFSDPUseOrigParamsFQNs)
instantiate_parametrized_tests(TestFSDPUseOrigParamsNoSync)

if __name__ == "__main__":
    run_tests()