# Owner(s): ["oncall: distributed"]

import functools
import itertools
import os
import tempfile
import unittest
from enum import auto, Enum
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.fsdp._wrap_utils import _validate_frozen_params
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    _or_policy,
    _Policy,
    _wrap_module_cls_individually,
    always_wrap_policy,
    CustomPolicy,
    enable_wrap,
    ModuleWrapPolicy,
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
    wrap,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.modules.batchnorm import _BatchNorm
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _maybe_cuda,
    CUDAInitMode,
    DummyProcessGroup,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    FILE_SCHEMA,
    find_free_port,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_CUDA,
    TestCase,
)

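# A small net mixing a linear layer with every batch-norm flavor (1d, 2d, 3d,
# and sync), used below to verify that batch-norm modules get wrapped
# individually by the auto-wrap policies.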
class BatchNormNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10, bias=False)
        self.bn1 = nn.BatchNorm1d(10)
        self.bn2 = nn.BatchNorm2d(10)
        self.bn3 = nn.BatchNorm3d(10)
        self.sync_bn = nn.SyncBatchNorm(10)

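# The Lora* modules below freeze every weight except the lora_A/lora_B adapter
# projections, mimicking LoRA fine-tuning; the frozen-parameter tests rely on
# this mix of requires_grad=True and requires_grad=False parameters.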
class LoraModel(nn.Module):
    """This is a toy LoRA decoder model."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 32)
        self.layers = nn.ModuleList([LoraDecoder() for _ in range(4)])
        self.norm = nn.LayerNorm(32)
        self.embed_tokens.weight.requires_grad_(False)
        self.norm.weight.requires_grad_(False)
        self.norm.bias.requires_grad_(False)


class LoraDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = LoraAttention()
        self.mlp = LoraMLP()
        self.inp_layernorm = nn.LayerNorm(32)
        self.post_attn_layernorm = nn.LayerNorm(32)
        self.inp_layernorm.weight.requires_grad_(False)
        self.inp_layernorm.bias.requires_grad_(False)
        self.post_attn_layernorm.weight.requires_grad_(False)
        self.post_attn_layernorm.bias.requires_grad_(False)


class LoraAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(32, 32, bias=False)
        self.lora_A = nn.Linear(32, 8, bias=False)
        self.lora_B = nn.Linear(8, 32, bias=False)
        self.k_proj = nn.Linear(32, 32, bias=False)
        self.v_proj = nn.Linear(32, 32, bias=False)
        self.o_proj = nn.Linear(32, 32, bias=False)
        self.q_proj.weight.requires_grad_(False)
        self.k_proj.weight.requires_grad_(False)
        self.v_proj.weight.requires_grad_(False)
        self.o_proj.weight.requires_grad_(False)


class LoraMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(32, 128, bias=False)
        self.proj2 = nn.Linear(128, 32, bias=False)
        self.proj1.weight.requires_grad_(False)
        self.proj2.weight.requires_grad_(False)

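# WrapMethod distinguishes the two ways these tests apply FSDP wrapping:
# passing ``auto_wrap_policy`` to the FSDP constructor vs. the
# ``enable_wrap()``/``wrap()`` context-manager API.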
class WrapMethod(Enum):
    FSDP_CTOR = auto()
    # FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss
    # any use cases and fix them to work with FSDP_CTOR over time.
    WRAP_API = auto()

class TestFSDPWrap(FSDPTest):
    """
    Tests the main API for auto wrapping with FSDP, which is to pass an
    ``auto_wrap_policy`` into the FSDP constructor.
    """

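    # The core pattern these tests exercise (illustrative sketch; the threshold
    # below is just one of the values used later in this file):
    #
    #   policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
    #   model = FSDP(model, auto_wrap_policy=policy)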
    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], FSDP))
            cls.assertTrue(isinstance(model.module[1], FSDP))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(
        self, cuda_init_mode=CUDAInitMode.CUDA_BEFORE, nested=False
    ) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
                # If nested=True, the FSDP module will be nested one layer deep,
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
                    )
                self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, cuda_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already wrapped with FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, cuda_init_mode=cuda_init_mode
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        wrapped_module_name = "lin1.1" if nested else "lin1"
        with self.assertRaisesRegex(
            ValueError,
            "FSDP auto wrapping requires modules to not already have FSDP "
            f"applied but found {wrapped_module_name} in",
        ):
            FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_or_policy", [True, False])
    def test_wrap_batchnorm_individually(self, use_or_policy):
        def never_wrap_policy(*args, **kwargs):
            return False

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )
        policy = (
            functools.partial(
                _or_policy, policies=[never_wrap_policy, wrap_batchnorm_individually]
            )
            if use_or_policy
            else wrap_batchnorm_individually
        )
        model = BatchNormNet()
        fsdp = FSDP(model, auto_wrap_policy=policy)
        # Batchnorms should be wrapped
        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
            self.assertTrue(isinstance(layer, FSDP))

        self.assertFalse(isinstance(fsdp.lin, FSDP))

    @skip_if_lt_x_gpu(2)
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that by using _or_policy with _wrap_module_cls_individually, even
        if the other policy results in a module containing a BN unit being
        wrapped, the contained BN unit will still be individually wrapped.
        """

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        wrap_batchnorm_individually = functools.partial(
            _wrap_module_cls_individually,
            module_classes=[
                _BatchNorm,
            ],
        )

        my_policy = functools.partial(
            _or_policy, policies=[wrap_bn_container, wrap_batchnorm_individually]
        )
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))); in the latter, the inner
        # BN is not individually wrapped.

        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertTrue(isinstance(bn, FSDP))

        # If we just wrap the BN container, the individual batch norms are not
        # wrapped.
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=wrap_bn_container)
        self.assertTrue(isinstance(mod.bn_container, FSDP))
        for bn in [
            fsdp.bn_container.bn1,
            fsdp.bn_container.bn2,
            fsdp.bn_container.bn3,
            fsdp.bn_container.sync_bn,
        ]:
            self.assertFalse(isinstance(bn, FSDP))

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE],
    )
    @parametrize("forward_prefetch", [False, True])
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_main_wrap_api(
        self,
        cpu_offload: CPUOffload,
        backward_prefetch: BackwardPrefetch,
        forward_prefetch: bool,
        cuda_init_mode: CUDAInitMode,
    ):
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # These do not work together; this is expected.
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4,
            wrapped_model,
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()

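# TestAutoWrap runs on a single process with a DummyProcessGroup, so these
# tests check the resulting module structure rather than any actual
# communication.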
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                auto_wrap_policy=functools.partial(
                    size_based_auto_wrap_policy, min_num_params=1
                ),
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is
        passed into FSDP, all submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(
            seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
        )
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy(self):
        """Tests the ``ModuleWrapPolicy``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        self._test_transformer_wrapping(auto_wrap_policy)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_module_wrap_policy_callable(self):
        """Tests the ``ModuleWrapPolicy`` as a ``Callable``."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        callable_policy = functools.partial(_or_policy, policies=[auto_wrap_policy])
        self._test_transformer_wrapping(callable_policy)

    def _test_transformer_wrapping(self, auto_wrap_policy: Union[Callable, _Policy]):
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if (
                module is fsdp_model
                or module in encoder_layers
                or module in decoder_layers
            ):
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_custom_policy(self):
        """
        Tests ``CustomPolicy`` with both a lambda function that uses uniform
        kwargs (so only returns ``False`` or ``True``) and a lambda function
        that uses non-uniform kwargs (so returns a dict to override the root
        kwargs).
        """
        for use_uniform_kwargs in [False, True]:
            self._test_custom_policy(use_uniform_kwargs)

    def _test_custom_policy(self, use_uniform_kwargs: bool):
        print(f"use_uniform_kwargs={use_uniform_kwargs}")
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {},
        )

        if use_uniform_kwargs:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return True
                elif isinstance(
                    module, (TransformerEncoderLayer, TransformerDecoderLayer)
                ):
                    return True
                return False

        else:

            def lambda_fn(module: nn.Module):
                if module is model.bn:
                    return {"sharding_strategy": ShardingStrategy.NO_SHARD}
                elif isinstance(module, TransformerEncoderLayer):
                    return True
                elif isinstance(module, TransformerDecoderLayer):
                    return {
                        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
                        "backward_prefetch": BackwardPrefetch.BACKWARD_POST,
                    }
                return False

        policy = CustomPolicy(lambda_fn)
        # Use a size-2 dummy PG to avoid clamping the sharding strategy to
        # `NO_SHARD` as for a size-1 PG
        process_group = DummyProcessGroup(rank=0, size=2)
        fp16_mp = MixedPrecision(param_dtype=torch.float16)
        fp32_mp = MixedPrecision()
        model = FSDP(
            model,
            process_group=process_group,
            auto_wrap_policy=policy,
            mixed_precision=fp16_mp,
        )
        encoder_layers = set(model.module.transformer.encoder.layers)
        decoder_layers = set(model.module.transformer.decoder.layers)
        bn = model.module.bn
        bn_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.NO_SHARD
        )
        bn_prefetch = BackwardPrefetch.BACKWARD_PRE
        encoder_strategy = root_strategy = ShardingStrategy.FULL_SHARD
        encoder_prefetch = root_prefetch = BackwardPrefetch.BACKWARD_PRE
        decoder_strategy = (
            ShardingStrategy.FULL_SHARD
            if use_uniform_kwargs
            else ShardingStrategy.SHARD_GRAD_OP
        )
        decoder_prefetch = (
            BackwardPrefetch.BACKWARD_PRE
            if use_uniform_kwargs
            else BackwardPrefetch.BACKWARD_POST
        )
        for module in model.modules():
            if module is bn:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, bn_strategy)
                self.assertEqual(module.backward_prefetch, bn_prefetch)
                # We currently override batch norm modules to use fp32
                self.assertEqual(module.mixed_precision, fp32_mp)
            elif module in encoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, encoder_strategy)
                self.assertEqual(module.backward_prefetch, encoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, decoder_strategy)
                self.assertEqual(module.backward_prefetch, decoder_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            elif module is model:
                self.assertTrue(isinstance(module, FSDP))
                self.assertEqual(module.sharding_strategy, root_strategy)
                self.assertEqual(module.backward_prefetch, root_prefetch)
                self.assertEqual(module.mixed_precision, fp16_mp)
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        A single ``nn.Linear(5, 5)`` does not exceed the bucket size, but multiple of them combined do.
        """
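        # Each nn.Linear(5, 5) has 5 * 5 + 5 = 30 parameters, below the
        # min_num_params=40 threshold, so it is not wrapped on its own, while
        # the inner nn.Sequential holds two of them (60 parameters) and does
        # get wrapped.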
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater than
        min_num_params. The size_based_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )

        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are wrapped if the param size exceeds
        min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and their children are not wrapped. The
        size_based_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not TEST_CUDA, "Test Requires CUDA")
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
        # CPU offload and CUDA after don't work together as expected.
        if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (
            torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
        )

        # Use a random port in case the next test runs quickly; reusing the same
        # port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after initializing with FSDP to simulate
        # real use cases where the full model cannot be loaded onto the GPU, but
        # its shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                auto_wrap_policy=my_auto_wrap_policy,
                device_id=device_id,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy does not exceed the parameter threshold before the inner
        # sequential (`sequential[2]`) anymore; hence, it flattens
        # `sequential[0]` and `sequential[2][1]` into `model` and leaves
        # `sequential[1]` and `sequential[2][0]` as-is since they are ignored
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))

    @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs")
    def test_frozen_params(self):
        """
        Tests that mixing frozen/non-frozen parameters in an FSDP instance
        raises for ``use_orig_params=False`` and warns for ``True``.
        """
        module_classes = (LoraAttention, LoraMLP, LoraDecoder)
        module_wrap_policy = ModuleWrapPolicy(module_classes)

        def lambda_fn_uniform(module: nn.Module):
            return isinstance(module, module_classes)

        def lambda_fn_nonuniform(module: nn.Module):
            if isinstance(module, LoraAttention):
                return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
            elif isinstance(module, module_classes):
                return True
            return False

        lambda_wrap_policy_uniform = CustomPolicy(lambda_fn_uniform)
        lambda_wrap_policy_nonuniform = CustomPolicy(lambda_fn_nonuniform)

        for use_orig_params, policy in itertools.product(
            [True, False],
            [
                module_wrap_policy,
                lambda_wrap_policy_uniform,
                lambda_wrap_policy_nonuniform,
            ],
        ):
            self._test_frozen_params(use_orig_params, policy)

    def _test_frozen_params(self, use_orig_params: bool, policy: _Policy):
        model = LoraModel().cuda()
        msg = "layers.0.attn has both parameters with requires_grad=True and False. "
        if use_orig_params:
            msg += "We do not recommend wrapping such modules"
            ctx = self.assertWarnsRegex(UserWarning, msg)
        else:
            msg += "FSDP does not support wrapping such modules when use_orig_params=False."
            ctx = self.assertRaisesRegex(ValueError, msg)
        with ctx:
            FSDP(
                model,
                process_group=self.process_group,
                auto_wrap_policy=policy,
                use_orig_params=use_orig_params,
            )

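# TestWrapUtils calls the _validate_frozen_params() helper directly, without
# constructing any FSDP instances, to check the frozen-parameter error/warning
# messages.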
class TestWrapUtils(TestCase):
    def test_validate_frozen_params(self):
        """Tests the method ``_validate_frozen_params()``."""
        for use_orig_params in [True, False]:
            self._test_validate_frozen_params(use_orig_params)

    def _test_validate_frozen_params(self, use_orig_params: bool):
        model = LoraModel()
        # Wrap only LoRA modules
        modules_to_wrap = {
            module
            for module_name, module in model.named_modules()
            if "lora_A" in module_name or "lora_B" in module_name
        }
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Additionally wrap attention
        for module in model.modules():
            if isinstance(module, LoraAttention):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Additionally wrap decoders
        for module in model.modules():
            if isinstance(module, LoraDecoder):
                modules_to_wrap.add(module)
        _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Do not wrap the LoRA-A modules (meaning mixed frozen/non-frozen)
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                modules_to_wrap.remove(module)
        regex = "layers.0.attn has both parameters with requires_grad=True and False."
        if use_orig_params:
            # Wrapping the attention manages all parameters except those from
            # the LoRA-B module, which is separately wrapped and all nonfrozen
            lorab_numel = sum(
                p.numel() for p in model.layers[0].attn.lora_B.parameters()
            )
            attn_frozen_param_numel = sum(
                p.numel()
                for p in model.layers[0].attn.parameters()
                if not p.requires_grad
            )
            attn_nonfrozen_param_numel = (
                sum(
                    p.numel()
                    for p in model.layers[0].attn.parameters()
                    if p.requires_grad
                )
                - lorab_numel
            )
            attn_total_param_numel = (
                attn_frozen_param_numel + attn_nonfrozen_param_numel
            )
            regex += (
                " We do not recommend wrapping such modules since the "
                r"gradient memory usage will be higher than expected \("
                f"{attn_total_param_numel} numel instead of {attn_nonfrozen_param_numel} numel "
                r"before sharding via reduce-scatter\). "
            )
        else:
            regex += " FSDP does not support wrapping such modules when use_orig_params=False. "
        regex += "If possible, wrap the frozen parameters with FSDP separately.\n"
        regex += (
            "The following parameters have requires_grad=True:\n"
            r"\['layers.0.attn.lora_A.weight'\]\n"
            "The following parameters have requires_grad=False:\n"
            r"\['layers.0.attn.q_proj.weight', 'layers.0.attn.k_proj.weight', "
            r"'layers.0.attn.v_proj.weight', 'layers.0.attn.o_proj.weight'\]"
        )
        if use_orig_params:
            ctx = self.assertWarnsRegex(UserWarning, regex)
        else:
            ctx = self.assertRaisesRegex(ValueError, regex)
        with ctx:
            _validate_frozen_params(model, modules_to_wrap, set(), use_orig_params)
        # Now ignore those LoRA-A modules' parameters
        ignored_params = set()
        for module_name, module in model.named_modules():
            if "lora_A" in module_name:
                for param in module.parameters():
                    ignored_params.add(param)
        _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)


instantiate_parametrized_tests(TestFSDPWrap)
instantiate_parametrized_tests(TestAutoWrap)

if __name__ == "__main__":
    run_tests()