# pytorch
# 1317 lines · 51.2 KB
# Owner(s): ["oncall: distributed"]
# Tests for saving/loading FSDP model state dicts across the supported
# state-dict types (full / local / sharded), with CPU offload, mixed
# precision, and activation checkpointing variations.

import io
import itertools
import sys
from contextlib import nullcontext
from copy import deepcopy
from functools import partial
from typing import Any, Dict

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed._state_dict_utils import (
    _all_gather_sharded_tensor,
    _gather_state_dict,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import (
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    MixedPrecision,
    ShardedStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import FSDP_PREFIX
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    _broadcast_state_dict,
    _get_state_dict,
    _zero_model,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    get_full_params,
    SkipModel,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

# Shapes for the small `Model` used throughout: inner Linear, outer Linear,
# and the registered test buffers.
INNER_SHAPE = [4, 4]
OUTER_SHAPE = [4, 5]
BUFFER_SHAPE = [5, 5]

# Key prefix identifying the non-FSDP portion of a non-FSDP-root model.
NON_ROOT_FSDP_PREFIX = "non_fsdp_lin"

_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"]
_FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"]
_SUPPORTED_STATE_DICT_IMPLS = (
    _UNFLATTENED_STATE_DICT_IMPLS + _FLATTENED_STATE_DICT_IMPLS
)

# Maps the string parametrization value to the FSDP StateDictType enum.
STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}
class Model(Module):
    """Two-layer test model whose ``inner`` Linear may be FSDP-wrapped.

    ``inner`` runs twice per forward so that gradient reduction across the
    repeated use is exercised.  Buffers (one persistent, one non-persistent)
    can optionally be registered on both submodules.
    """

    def __init__(
        self,
        wrap_fsdp,
        register_buffers=False,
        ignore_inner=False,
        mixed_precision=False,
        process_group=None,
    ):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if register_buffers:
            self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
            self.inner.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )
        if wrap_fsdp:
            # `ignore_inner` makes the wrapped module ignore itself, which
            # leaves its parameters unmanaged by FSDP.
            self.inner = FSDP(
                self.inner,
                ignored_modules=([self.inner] if ignore_inner else []),
                mixed_precision=MixedPrecision(
                    param_dtype=torch.float16,
                    reduce_dtype=torch.float16,
                    buffer_dtype=torch.float16,
                )
                if mixed_precision
                else None,
                process_group=process_group,
            )
        self.outer = Linear(*OUTER_SHAPE)
        if register_buffers:
            self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
            self.outer.register_buffer(
                "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
            )

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)
class TestDummyModel(torch.nn.Module):
    """Small model with aliased state: ``net3`` is ``net2`` and
    ``shared_parameter`` is ``random_parameter``, so duplicate state-dict
    entries are exercised."""

    def __init__(self):
        super().__init__()
        # Deterministic weights so every rank constructs an identical model.
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        # Aliases share the same underlying objects on purpose.
        self.net3 = self.net2
        self.random_parameter = nn.Parameter(torch.Tensor(10))
        self.shared_parameter = self.random_parameter

    def forward(self, x):
        hidden = self.net1(x)
        hidden = self.net2(hidden)
        return self.net3(hidden)

    def get_input(self):
        # NOTE(review): assumes CUDA is available on the calling rank.
        return torch.rand(8, 8, device="cuda")
class TestFSDPStateDict(FSDPTest):
    """Save/load tests for FSDP across all supported state-dict types."""

    @property
    def world_size(self):
        # Cap at 2 ranks so the suite also runs on 2-GPU machines.
        return min(torch.cuda.device_count(), 2)
    def _broadcast_state_dict(self, model, state_dict):
        """Broadcast ``state_dict`` from rank 0 to all ranks."""
        # TODO (rohan-varma): remove model
        return _broadcast_state_dict(self.rank, state_dict)
162def _state_compare(self, model, model_new, assert_fn, state_generator="parameters"):163state_base = list(getattr(model, state_generator)())164state_new = list(getattr(model_new, state_generator)())165# Regardless of `assert_fn`, the number of parameters should be the same166self.assertEqual(len(state_base), len(state_new))167assert_fn(state_base, state_new)168
169def _compare_models(170self, model, model_new, assert_fn, check_fp16=False, check_buffers=True171):172assert assert_fn in (self.assertEqual, self.assertNotEqual)173with FSDP.summon_full_params(model):174with FSDP.summon_full_params(model_new):175self._state_compare(model, model_new, assert_fn)176if check_buffers:177has_buffers = any(178len(list(m.buffers())) for m in (model, model_new)179)180if has_buffers:181self._state_compare(182model, model_new, assert_fn, state_generator="buffers"183)184if check_fp16:185for tensor in model_new.parameters():186self.assertEqual(tensor.dtype, torch.float16)187
188def _get_simple_nested_model(189self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs190):191if wrap:192lin1 = nn.Linear(10, 10, bias=False).cuda()193lin2 = nn.Linear(10, 10, bias=False).cuda()194if checkpoint_wrap:195lin1 = checkpoint_wrapper(lin1)196lin2 = checkpoint_wrapper(lin2)197seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)198if checkpoint_wrap:199seq = checkpoint_wrapper(seq)200model = FSDP(seq, *fsdp_args, **fsdp_kwargs)201else:202model = nn.Sequential(203nn.Linear(10, 10, bias=False).cuda(),204nn.Linear(10, 10, bias=False).cuda(),205)206return model207
208def _get_simple_model(self, *fsdp_args, checkpoint_wrap=False, **fsdp_kwargs):209lin = nn.Linear(10, 10, bias=False).cuda()210if checkpoint_wrap:211lin = checkpoint_wrapper(lin)212model = FSDP(lin, *fsdp_args, **fsdp_kwargs)213return model214
215def _get_multibuffer_nested_model(216self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs217):218full_p = torch.float32219lin_mp = fsdp_kwargs.pop("mixed_precision", None)220bn_mp = (221MixedPrecision(param_dtype=full_p, reduce_dtype=full_p, buffer_dtype=full_p)222if lin_mp223else None224)225if wrap:226lin1 = nn.Linear(10, 10, bias=False).cuda()227bn1 = nn.BatchNorm1d(10).cuda()228lin2 = nn.Linear(10, 10, bias=False).cuda()229if checkpoint_wrap:230lin1 = checkpoint_wrapper(lin1)231bn1 = checkpoint_wrapper(bn1)232lin2 = checkpoint_wrapper(lin2)233seq = nn.Sequential(234FSDP(lin1, *fsdp_args, mixed_precision=lin_mp, **fsdp_kwargs),235FSDP(bn1, *fsdp_args, mixed_precision=bn_mp, **fsdp_kwargs),236lin2,237)238if checkpoint_wrap:239seq = checkpoint_wrapper(seq)240model = FSDP(seq, *fsdp_args, **fsdp_kwargs)241else:242model = nn.Sequential(243nn.Linear(10, 10, bias=False).cuda(),244nn.BatchNorm1d(10).cuda(),245nn.Linear(10, 10, bias=False).cuda(),246)247return model248
249def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):250class FSDPContainer(nn.Module):251def __init__(self, fsdp_1, fsdp_2):252super().__init__()253self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()254self.fsdp_1 = fsdp_1255self.fsdp_2 = fsdp_2256
257def forward(self, x):258x = self.non_fsdp_lin(x)259x = self.fsdp_1(x)260x = self.fsdp_2(x)261return x262
263return FSDPContainer(264self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),265self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),266)267
268def _get_state_dict_mgr(269self,270model: nn.Module,271state_dict_type: str,272state_dict_rank0_and_offload: bool,273):274_state_dict_type = STATE_DICT_MAPPING[state_dict_type]275if state_dict_type == "state_dict":276config = FullStateDictConfig(277rank0_only=state_dict_rank0_and_offload,278offload_to_cpu=state_dict_rank0_and_offload,279)280elif state_dict_type == "local_state_dict":281config = LocalStateDictConfig(282offload_to_cpu=state_dict_rank0_and_offload,283)284elif state_dict_type == "sharded_state_dict":285config = ShardedStateDictConfig(286offload_to_cpu=state_dict_rank0_and_offload,287)288else:289raise ValueError("Unsupported state_dict_type")290return FSDP.state_dict_type(model, _state_dict_type, config)291
    def _validate_state_dict_contents(
        self, model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None
    ):
        """Check rank0-only/offload invariants: on rank 0 the dict is
        non-empty and fully on CPU (minus ``ignore_keys``); on other ranks an
        FSDP-rooted model must produce an empty dict."""
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non FSDP portion can still have parameters on rank 0,
                # so bypass the check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(
                        fsdp_state_dict,
                        {},
                        f"Expected empty state_dict but got {fsdp_state_dict} on rank {dist.get_rank()}",
                    )
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize(
        "checkpoint_wrap",
        ["source", "dest", "both", "source_after_wrap", "both_after_wrap"],
    )
    @parametrize("rank0_only_and_offload", [False, True])
    def test_fsdp_state_dict_with_activation_checkpoint(
        self, state_dict_type, checkpoint_wrap, rank0_only_and_offload
    ):
        """Tests saving the state dict, zeroing a target model's parameters, and
        loading the state dict, where the source and target models may have a
        checkpoint wrapper."""

        def apply_ac_to_linears(model) -> None:
            # Wrap every nn.Linear submodule with non-reentrant activation
            # checkpointing after FSDP construction.
            non_reentrant_wrapper = partial(
                checkpoint_wrapper,
                offload_to_cpu=False,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            apply_activation_checkpointing(
                model,
                checkpoint_wrapper_fn=non_reentrant_wrapper,
                check_fn=lambda submodule: isinstance(submodule, nn.Linear),
            )

        # (The no-arg `partial` wrappers are no-ops but keep call sites uniform.)
        for model_call in [
            partial(self._get_simple_model),
            partial(self._get_simple_nested_model),
        ]:
            model = model_call(checkpoint_wrap=(checkpoint_wrap in ("source", "both")))
            if checkpoint_wrap in ("source_after_wrap", "both_after_wrap"):
                apply_ac_to_linears(model)
            with self._get_state_dict_mgr(
                model, state_dict_type, rank0_only_and_offload
            ):
                state_dict = _gather_state_dict(_get_state_dict(model, False, False))
            # Possibly wrap new model in activation checkpoint wrapper to test save/
            # load with this wrapper
            model_new = model_call(
                checkpoint_wrap=(checkpoint_wrap in ("dest", "both"))
            )
            if checkpoint_wrap == "both_after_wrap":
                apply_ac_to_linears(model_new)
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)
            if rank0_only_and_offload:
                state_dict = self._broadcast_state_dict(model, state_dict)
            # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks
            model_new.load_state_dict(state_dict, strict=True)
            self._compare_models(model, model_new, self.assertEqual)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("rank0_only_and_offload", [False, True])
    def test_state_dict_with_manual_ac_wrapper(
        self,
        state_dict_type: str,
        rank0_only_and_offload: bool,
    ):
        """
        Tests saving and loading a state dict for a model manually wrapped with
        ``FSDP(CheckpointWrapper(module))``, where the ``CheckpointWrapper`` is
        wrapped before FSDP.

        TODO: Investigate why the test above does not cover everything in this
        test and de-duplicate afterwards.
        """
        if state_dict_type == "sharded_state_dict" and rank0_only_and_offload:
            return  # not supported
        model_ac = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        # Manually wrap FSDP without AC
        model_no_ac = deepcopy(model_ac)
        for i, layer in enumerate(model_no_ac.transformer.encoder.layers):
            model_no_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_no_ac.transformer.decoder.layers):
            model_no_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_no_ac.transformer = FSDP(model_no_ac.transformer)

        # Manually wrap FSDP with AC as `FSDP(CheckpointWrapper(module))`
        for i, layer in enumerate(model_ac.transformer.encoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.encoder.layers[i] = FSDP(layer)
        for i, layer in enumerate(model_ac.transformer.decoder.layers):
            layer = checkpoint_wrapper(layer)
            model_ac.transformer.decoder.layers[i] = FSDP(layer)
        model_ac.transformer = FSDP(model_ac.transformer)

        # Save, load, and compare the two models
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_no_ac = model_no_ac.state_dict()
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            state_dict_ac = model_ac.state_dict()
        # The AC wrapper's state-dict hooks should strip its prefix, so both
        # models must produce identical key sets.
        self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys())
        if rank0_only_and_offload:
            state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac)
            state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac)
        with self._get_state_dict_mgr(
            model_no_ac, state_dict_type, rank0_only_and_offload
        ):
            model_no_ac.load_state_dict(state_dict_no_ac)
        with self._get_state_dict_mgr(
            model_ac, state_dict_type, rank0_only_and_offload
        ):
            model_ac.load_state_dict(state_dict_ac)
        self._compare_models(model_ac, model_no_ac, self.assertEqual)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_with_shared_parameters(self, state_dict_type):
        """Save/load round-trips a transformer with shared parameters."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        model_creator = partial(
            TransformerWithSharedParams.init,
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            {"auto_wrap_policy": auto_wrap_policy},
        )

        fsdp_model = model_creator()
        with self._get_state_dict_mgr(fsdp_model, state_dict_type, False):
            state_dict = fsdp_model.state_dict()

        new_model = model_creator()
        _zero_model(new_model, zero_buffers=True)
        with self._get_state_dict_mgr(new_model, state_dict_type, False):
            new_model.load_state_dict(state_dict)
    @skip_if_lt_x_gpu(2)
    @parametrize("use_orig_params", [False, True])
    def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool):
        """Tests saving a model checkpoint only on rank 0 and loading it only
        on rank 0 with ``sync_module_states=True`` to emulate the workflow to
        avoid redundant CPU memory usage."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": use_orig_params,
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(
                fsdp_model.parameters(), fsdp_model.buffers()
            ):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(torch.ones_like(tensor))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        # Ranks should now differ (only rank 0 loaded the checkpoint).
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FSDP.summon_full_params(fsdp_model):
            with FSDP.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_basic_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: bool,
        fp16: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        device = torch.device(self.rank)
        for model_call in [
            partial(
                self._get_non_fsdp_root_module,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_nested_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
            partial(
                self._get_simple_model,
                cpu_offload=cpu_offload,
                use_orig_params=use_orig_params,
            ),
        ]:
            model = model_call()
            if fp16:
                model.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model(inp).sum().backward()

            ctx = self._get_state_dict_mgr(
                model, state_dict_type, state_dict_rank0_and_offload
            )
            with ctx:
                fsdp_state_dict = _get_state_dict(
                    model, cpu_offload.offload_params, fp16
                )

            # Non-FSDP-root keys may legitimately stay on GPU / on non-zero
            # ranks, so exclude them from the rank0/offload validation.
            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]

            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify fp16 is the type
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()
            # Run a forward/backward to compute gradients to test the case
            # where there are gradients populated
            inp = torch.randn((3, 10), device=device)
            if fp16:
                inp = inp.half()
            model_new(inp).sum().backward()

            # zero the model to ensure parameters are different.
            _zero_model(model_new, zero_buffers=True)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
            with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)

            self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("use_orig_params", [True, False])
    def test_buffers_save_and_load_state_dict(
        self,
        state_dict_type: str,
        cpu_offload: bool,
        mixed_precision: bool,
        state_dict_rank0_and_offload: bool,
        use_orig_params: bool,
    ):
        """
        Tests that we can save a state_dict and load it for modules with persistent buffers, including
        in the context of non-default mixed precision, different ``state_dict_type`` s and CPU offloading.
        """
        if (state_dict_rank0_and_offload and state_dict_type != "state_dict") or (
            use_orig_params and state_dict_type not in _UNFLATTENED_STATE_DICT_IMPLS
        ):
            return  # not supported
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model_call = partial(
            self._get_multibuffer_nested_model,
            cpu_offload=cpu_offload,
            use_orig_params=use_orig_params,
            mixed_precision=mixed_precision,
        )
        model = model_call()
        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, False)

        self._validate_state_dict_contents(
            model, fsdp_state_dict, state_dict_rank0_and_offload
        )

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()

        # zero the model to ensure parameters are different.
        _zero_model(model_new, zero_buffers=True)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)
        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)

        self._compare_models(model, model_new, self.assertEqual)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
        self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
    ):
        """
        Test that saving after some training results in params being updated as
        expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            # With mixed precision the loss comes out in fp16.
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
                # Clone each local shard so zeroing the model below does not
                # also zero the saved checkpoint.
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(
            model, state_dict, state_dict_rank0_and_offload
        )
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            state_dict = self._broadcast_state_dict(model, state_dict)

        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)
    def _initialize_model(
        self,
        wrap_fsdp: bool,
        wrap_ddp: bool = True,
        register_buffers: bool = False,
    ):
        """Build ``Model`` on CUDA, wrapped in FSDP, in DDP, or left plain."""
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, register_buffers=register_buffers).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model
    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        """Return ``model.state_dict()`` under the given state-dict type."""
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict type for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()
    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        """Load ``state_dict`` into ``model`` under the given state-dict type."""
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError as e:
            raise ValueError(f"No state_dict for {state_dict_type}") from e

        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)
793def _dist_train(794self, wrap_fsdp: bool, state_dict_type: str = "", move_to_cpu: bool = False795):796# TODO: Move this test to common_fsdp.797model = self._initialize_model(wrap_fsdp)798optim = SGD(model.parameters(), lr=0.1)799
800in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))801for _ in range(3):802out = model(in_data)803out.sum().backward()804optim.step()805optim.zero_grad()806
807if wrap_fsdp:808blank_model = FSDP(Model(True).cuda())809_zero_model(blank_model)810state_dict = self._state_dict(model, state_dict_type)811if move_to_cpu:812for key in list(state_dict.keys()):813tensor = state_dict[key]814if isinstance(tensor, torch.Tensor):815state_dict[key] = tensor.cpu()816else:817shards = tensor.local_shards()818if shards:819shards[0].tensor = shards[0].tensor.cpu()820
821self._load_state_dict(blank_model, state_dict_type, state_dict)822return get_full_params(blank_model)823else:824return list(model.parameters())825
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        """Run the FSDP-vs-DDP save/load subtest for both CPU-move options."""
        self.run_subtests(
            {"move_to_cpu": [True, False]},
            self._test_state_dict_save_load_flow,
            state_dict_type=state_dict_type,
        )
    def _test_state_dict_save_load_flow(self, state_dict_type, move_to_cpu):
        """FSDP params after a checkpoint round-trip must match DDP training."""
        fsdp_params = self._dist_train(
            wrap_fsdp=True,
            state_dict_type=state_dict_type,
            move_to_cpu=move_to_cpu,
        )
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        """Local state dicts expose flat params; the others match a local model."""
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual({FLAT_PARAM, f"inner.{FLAT_PARAM}"}, state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        # The two model variants take differently shaped inputs.
        if not fsdp_root:
            in_data = torch.randn(
                1, 10, requires_grad=True, device=torch.device("cuda")
            )
        else:
            in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FSDP.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(
                wrap_fsdp=False, wrap_ddp=False, register_buffers=True
            )

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict)

        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        """A submodule detached before FSDP wrapping (and reattached after)
        still round-trips correctly through the state dict."""
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else nullcontext()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would be
                # stored in checkpoint.
                linear_skip_tensor_names = [
                    k
                    for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be handled
            # by FSDP.state_dict(). Have a check once this is implemented in
            # FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for p1, p2 in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)
1010@skip_if_lt_x_gpu(2)1011def test_wrong_state_dict_config(self):1012model = FSDP(Model(wrap_fsdp=True).cuda())1013with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):1014with model.state_dict_type(1015model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()1016):1017pass1018
    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_state_dict_with_ignored_modules(
        self, state_dict_type, prefix, ignore_inner, mixed_precision
    ):
        """Checks that tensors of ignored modules appear in the state dict
        without being cloned (same ``data_ptr``), that ignored buffers keep
        fp32 even under mixed precision, and that the state dict loads into
        a non-wrapped model."""
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(
            wrap_fsdp=True,
            register_buffers=True,
            ignore_inner=ignore_inner,
            mixed_precision=mixed_precision,
        ).cuda()
        ignored_modules = [model.outer]
        # Map each ignored tensor to its expected (un-prefixed) key.
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        # expect fp16 model.inner.buffer with mixed_precisions
        # expect fp32 sd.inner.buffer after restoring to original precision
        # so skip AssertEqual
        if mixed_precision and not ignore_inner:
            buffer_to_buffer_name.pop(model.inner.buffer)

        fsdp_model = FSDP(
            model,
            ignored_modules=ignored_modules,
            mixed_precision=MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None,
        )
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd1 = _gather_state_dict(fsdp_model.state_dict(prefix=prefix_str))
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            # Same storage pointer == the state dict holds the original tensor.
            self.assertEqual(
                tensor.data_ptr(),
                sd1[prefixed_tensor_name].data_ptr(),
                f"{prefixed_tensor_name}",
            )
        # should not apply mixed_precision to ignored buffers
        for buffer_name in buffer_to_buffer_name.values():
            prefixed_buffer_name = f"{prefix_str}{buffer_name}"
            self.assertTrue(prefixed_buffer_name in sd1)
            self.assertEqual(sd1[prefixed_buffer_name].dtype, torch.float32)
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        # Strip the test prefix before loading into the plain model.
        to_load = {k[len(prefix_str) :]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, STATE_DICT_MAPPING[state_dict_type]):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )
1119@skip_if_lt_x_gpu(2)1120def test_state_dict_type(self):1121module = SkipModel(double_nest=True)1122with enable_wrap(wrapper_cls=FSDP):1123fsdp = wrap(module)1124with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):1125pass1126for module in FSDP.fsdp_modules(fsdp):1127self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)1128
1129@skip_if_lt_x_gpu(2)1130def test_local_state_dict_with_empty_ranks(self):1131class Model(Module):1132def __init__(self):1133super().__init__()1134self.my_tensor = torch.full((1,), 3.1415926)1135self.my_parameter = nn.Parameter(self.my_tensor)1136
1137def forward(self, x):1138return self.my_parameter1139
1140model = FSDP(Model().cuda())1141with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):1142out = model(None)1143out.backward()1144
1145state_dict = deepcopy(model.state_dict())1146with torch.no_grad():1147with FSDP.summon_full_params(model):1148self.assertEqual(model.my_parameter.item(), 3.1415926)1149model.my_parameter.copy_(torch.full((1,), 1.75).cuda())1150self.assertEqual(model.my_parameter.item(), 1.75)1151model.load_state_dict(state_dict)1152with FSDP.summon_full_params(model):1153self.assertEqual(model.my_parameter.item(), 3.1415926)1154
1155@skip_if_lt_x_gpu(2)1156def test_torch_save_load(self):1157model = Model(wrap_fsdp=True).cuda()1158with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):1159state_dict = model.state_dict()1160checkpoint = io.BytesIO()1161torch.save(state_dict, checkpoint)1162checkpoint.seek(0)1163state_dict_saved = torch.load(checkpoint)1164for k, v in state_dict_saved.items():1165if isinstance(v, ShardedTensor):1166self.assertEqual(1167v._local_shards[0].tensor, state_dict[k]._local_shards[0].tensor1168)1169else:1170self.assertEqual(v, state_dict[k])1171
1172@skip_if_lt_x_gpu(2)1173def test_shared_module_and_shared_parameter(self):1174model = FSDP(TestDummyModel().cuda())1175with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):1176state_dict = model.state_dict()1177self.assertEqual(1178state_dict["random_parameter"], state_dict["shared_parameter"]1179)1180self.assertEqual(state_dict["net2.0.bias"], state_dict["net3.0.bias"])1181self.assertEqual(state_dict["net2.0.weight"], state_dict["net3.0.weight"])1182
1183@skip_if_lt_x_gpu(2)1184def test_full_state_dict_missing_unexpected_keys_cleaned(self):1185model = self._get_simple_nested_model()1186sd = model.state_dict()1187# Create a missing key1188sd.pop(next(iter(sd.keys())))1189# Create an unexpected key1190sd["unexpected"] = torch.ones(1)1191missing, unexpected = model.load_state_dict(sd, strict=False)1192assert len(missing) == 11193assert len(unexpected) == 11194self.assertTrue(FSDP_PREFIX not in missing[0])1195self.assertTrue(FSDP_PREFIX not in unexpected[0])1196
    @skip_if_lt_x_gpu(2)
    def test_sharded_load_multi_backend_pg(self):
        """Sharded state dict save/load over a multi-backend
        ("cpu:gloo,cuda:nccl") process group, both with the shards left on
        GPU and with them offloaded to CPU before loading."""
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer}
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "use_orig_params": True,
        }
        for load_cpu in [True, False]:
            with self.subTest(load_cpu=load_cpu):
                pg = dist.new_group(backend="cpu:gloo,cuda:nccl")
                fsdp_model = TransformerWithSharedParams.init(
                    pg,
                    FSDPInitMode.RECURSIVE,
                    CUDAInitMode.CUDA_BEFORE,
                    fsdp_kwargs,
                )
                FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT)
                sharded = fsdp_model.state_dict()
                # Snapshot the parameters, then zero the live model so that a
                # successful load is observable.
                param_copy = [t.clone().detach_() for t in fsdp_model.parameters()]
                with torch.no_grad():
                    for p in fsdp_model.parameters():
                        p.zero_()

                if load_cpu:
                    # Offload to CPU to simulate CPU state_dict load
                    for k, v in sharded.items():
                        sharded[k] = v.cpu()

                fsdp_model.load_state_dict(sharded)
                for p1, p2 in zip(param_copy, fsdp_model.parameters()):
                    self.assertEqual(p1, p2, f"not equal: {p1.sum()} vs {p2.sum()}")
1231@skip_if_lt_x_gpu(2)1232def test_world_size_one(self):1233my_pg = None1234for i in range(self.world_size):1235pg = dist.new_group(ranks=[i])1236if i == self.rank:1237my_pg = pg1238
1239model = TransformerWithSharedParams.init(1240my_pg,1241FSDPInitMode.RECURSIVE,1242CUDAInitMode.CUDA_BEFORE,1243)1244with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):1245state_dict = model.state_dict()1246model.load_state_dict(state_dict)1247
1248dist.barrier()1249
1250
class TestFSDPStateDict4GPUs(FSDPTest):
    """FSDP state-dict tests that need (at least) 4 GPUs."""

    @property
    def world_size(self):
        # Use every visible GPU so the reshard test below runs with 4 ranks.
        return torch.cuda.device_count()

    @skip_if_lt_x_gpu(4)
    def test_local_state_dict_reshard(self):
        """
        This test demonstrates the ability to do resharding when using
        local_state_dict. Although we do not recommend users to use
        local_state_dict, there are still some corner cases that
        using local_state_dict is a better solution.
        """
        model = FSDP(Model(wrap_fsdp=True)).cuda()
        optim = torch.optim.SGD(model.parameters(), lr=0.1)

        # One training step so parameters are non-trivial before saving.
        batch = torch.randn(4, 4, device=torch.cuda.current_device())
        output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            state_dict = model.state_dict()

        rank = dist.get_rank()
        new_pg = dist.new_group(ranks=[0, 1])
        resharded_state_dict = {}
        # Mimic resharding from 4 GPUs to 2 GPUs
        for key, value in state_dict.items():
            if isinstance(value, ShardedTensor):
                # All 4 ranks participate in the all-gather; only ranks 0-1
                # re-chunk the flat parameter into 2 shards for the new group.
                full_flat_param = _all_gather_sharded_tensor(value)
                if rank < 2:
                    full_numel = full_flat_param.size()
                    chunks = full_flat_param.chunk(2)
                    flat_param = chunks[rank]
                    # Rank 1's shard starts right after rank 0's chunk.
                    shard_offset = 0 if rank == 0 else chunks[0].numel()
                    local_shards = [
                        Shard.from_tensor_and_offsets(flat_param, [shard_offset], rank)
                    ]
                    sharded_tensor = init_from_local_shards(
                        local_shards, full_numel, process_group=new_pg
                    )
                    resharded_state_dict[key] = sharded_tensor
            else:
                # Non-sharded entries (e.g. buffers) are copied as-is.
                if rank < 2:
                    resharded_state_dict[key] = value

        if rank < 2:
            model2 = FSDP(
                Model(wrap_fsdp=True, process_group=new_pg), process_group=new_pg
            ).cuda()
            with FSDP.state_dict_type(model2, StateDictType.LOCAL_STATE_DICT):
                model2.load_state_dict(resharded_state_dict)

        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
            full_state_dict1 = model.state_dict()

        # The 2-GPU model must reproduce the 4-GPU model's full state dict.
        if rank < 2:
            with FSDP.state_dict_type(model2, StateDictType.FULL_STATE_DICT):
                full_state_dict2 = model2.state_dict()
            self.assertEqual(full_state_dict1, full_state_dict2)
1313
# Generate the parametrized variants of the state-dict tests.
instantiate_parametrized_tests(TestFSDPStateDict)

if __name__ == "__main__":
    run_tests()