# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2

    @torch.no_grad()
    def _init_linear_weights(self, m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)

    def check_weights(self, fsdp, expected_tensor_fn, check):
        with FSDP.summon_full_params(fsdp, recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected, f"Got {param} but expected {expected}")

    def _check_apply(self, fsdp):
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on a
        non-FSDP-root nested FSDP-wrapped model."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(nested_wrapped_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on an
        FSDP-wrapped transformer model with shared parameters."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(transformer)

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """Tests that calling ``apply()`` on an FSDP instance inside the
        ``summon_full_params()`` context raises an error."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        with transformer.summon_full_params(transformer):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)

if __name__ == "__main__":
    run_tests()
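# Usage note (a sketch, not part of the original test module): in a PyTorch
# source checkout this file sits with the other FSDP tests (the exact path is
# an assumption, e.g. test/distributed/fsdp/), and it can be launched directly
# with `python <path-to-this-file>`. The FSDPTest base class spawns world_size
# (here 2) processes per test, and @skip_if_lt_x_gpu(2) skips each test on
# machines with fewer than two GPUs.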