# Owner(s): ["oncall: distributed"]

import bisect
import sys
from copy import deepcopy
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _gather_state_dict
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    OptimStateKeyType,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictSettings,
    StateDictType,
)
from torch.distributed.optim import _NamedOptimizer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

STATE_DICT_TYPES = [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT]

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class _OSDCommMethod(Enum):
    """Method for communicating the optimizer state dict for internal tests."""

    BROADCAST_OBJECT_LIST = auto()
    SCATTER_FULL_OSD = auto()
    FLATTEN_SHARDED_OSD = auto()
    OPTIM_STATE_DICT = auto()


class _ModelClass(Enum):
    """Different model type to test."""

    NESTED = auto()
    TRANSFORMER = auto()


class Bias(torch.nn.Module):
    """This module applies a 1D additive bias with dimension ``dim``."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        assert dim > 0
        torch.manual_seed(0)
        self.bias = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, x):
        return x + self.bias


class BlockA(torch.nn.Module):
    """
    Used to define interesting nested structure for FSDP wrapping.
    BlockA
        Bias0
            bias
        weight
        Bias1
            bias
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        assert all(v > 0 for v in (in_dim, out_dim))
        torch.manual_seed(0)
        self.bias_module0 = Bias(out_dim)
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
        self.bias_module1 = Bias(out_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x @ self.weight
        x = self.bias_module0(x)
        x = self.relu(x)  # ensure biases have different gradients
        x = self.bias_module1(x)
        return x


class BlockB(torch.nn.Module):
    """
    Used to define interesting nested structure for FSDP wrapping.
    BlockB
        weight
        Bias
            bias
        Bias
            bias
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        assert all(v > 0 for v in (in_dim, out_dim))
        torch.manual_seed(0)
        self.weight = torch.nn.Parameter(torch.randn((in_dim, out_dim)))
        self.bias_module0 = Bias(out_dim)
        self.bias_module1 = Bias(out_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x @ self.weight
        x = self.bias_module0(x)
        x = self.relu(x)  # ensure biases have different gradients
        x = self.bias_module1(x)
        return x


class NestedModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block0 = BlockB(5, 3)
        self.block1 = BlockB(3, 7)
        self.bias = torch.nn.Parameter(torch.randn((5,)))
        self.block2 = torch.nn.Sequential(
            BlockA(7, 9),
            BlockA(9, 9),
            BlockB(9, 5),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x) -> torch.Tensor:
        x = self.relu(self.block0(x))
        x = self.relu(self.block1(x))
        x = self.relu(self.block2(x))
        x = x + self.bias
        return x

    def get_input(self, device):
        BATCH_SIZE = 8
        return (torch.randn((BATCH_SIZE, 5)).to(device),)

    def get_loss(self, inp, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()

    @staticmethod
    def wrap(
        model: torch.nn.Module,
        group: Optional[dist.ProcessGroup] = None,
        ignore_modules: bool = False,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        # Flatten Bias0; then flatten weight and Bias1 together into `block1`
        model.block1.bias_module0 = FSDP(
            model.block1.bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block1 = FSDP(model.block1, process_group=group, **fsdp_kwargs)
        # Flatten Bias0; flatten Bias1; then flatten weight into `block2[1]`
        model.block2[1].bias_module0 = FSDP(
            model.block2[1].bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block2[1].bias_module1 = FSDP(
            model.block2[1].bias_module1,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block2[1] = FSDP(model.block2[1], process_group=group, **fsdp_kwargs)
        # Flatten weight, Bias, bias into `block2[2]`
        ignored_modules = [model.block2[2].bias_module0] if ignore_modules else None
        model.block2[2] = FSDP(
            model.block2[2],
            process_group=group,
            ignored_modules=ignored_modules,
            **fsdp_kwargs,
        )
        return model

    @staticmethod
    def wrap_alt(
        model: torch.nn.Module,
        group: Optional[dist.ProcessGroup] = None,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        if fsdp_kwargs is None:
            fsdp_kwargs = {}
        model.block0.bias_module0 = FSDP(
            model.block0.bias_module0,
            process_group=group,
            **fsdp_kwargs,
        )
        model.block0 = FSDP(model.block0, process_group=group, **fsdp_kwargs)
        return model

    @staticmethod
    def wrap_with_unmanaged_params(
        model,
        add_to_fsdp_module: bool,
        group=None,
    ) -> Tuple[torch.nn.Module, List[torch.nn.Parameter]]:
        """Registers unmanaged parameters before wrapping with :meth:`wrap`."""
        device = next(model.parameters()).device
        unmanaged_param = torch.nn.Parameter(torch.randn(5, 5, device=device))
        # Either register the parameter to a module to be wrapped with FSDP
        # (`model.block2[2]`) or a module not to be wrapped with FSDP (`model`)
        register_module = model.block2[2] if add_to_fsdp_module else model
        register_module.register_parameter(
            "unmanaged_param",
            unmanaged_param,
        )
        # For simplicity, we only add a single unmanaged parameter, but should
        # be easy to generalize if needed
        return NestedModel.wrap(model, group), [unmanaged_param]

    @staticmethod
    def add_unmanaged_param_entry(osd, unmanaged_param, step) -> None:
        """Adds an entry for the unmanaged parameter ``unmanaged_param``
        assuming Adam optimizer and a single parameter group."""
        # The unmanaged parameters should be passed to this method in
        # `model.parameters()` order since their parameter IDs will be assigned
        # in order of the skipped IDs
        # Assign a parameter ID to the unmanaged parameter
        unmanaged_param_id = -1
        param_ids = osd["param_groups"][0]["params"]
        for i in range(1, len(param_ids)):
            diff = param_ids[i] - param_ids[i - 1]
            if diff != 1:
                assert diff > 1, f"Invalid IDs: {param_ids[i - 1]} {param_ids[i]}"
                unmanaged_param_id = param_ids[i - 1] + 1
                break
        if unmanaged_param_id == -1:
            unmanaged_param_id = len(param_ids)  # last ID skipped
        assert unmanaged_param_id >= 0, "One parameter ID should be skipped"
        # Add a state entry for the unmanaged parameter
        state_device = next(iter(next(iter(osd["state"].values())).values())).device
        osd["state"][unmanaged_param_id] = {
            "step": torch.tensor(float(step), device=state_device),
            "exp_avg": torch.randn(unmanaged_param.shape, device=state_device),
            "exp_avg_sq": torch.randn(unmanaged_param.shape, device=state_device),
        }
        # Insert the ID into the parameter group in order
        bisect.insort(osd["param_groups"][0]["params"], unmanaged_param_id)

    # NOTE: We exclude `self.bias` from either parameter group to test the
    # case where the optimizer input does not include all model parameters
    def param_group0(self) -> List[torch.nn.Parameter]:
        # Use `block1`'s parameters for the first parameter group to deviate
        # from the `model.parameters()` order
        return list(self.block1.parameters())

    def param_group1(self) -> List[torch.nn.Parameter]:
        # Deviate from the `model.parameters()` order further by rearranging
        # `block2`'s parameters to be before `block0`'s parameters
        return list(self.block2.parameters()) + list(self.block0.parameters())


# Simple and boring model to test the interface and some corner cases that do
# not require a complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
    def __init__(self, no_grad: bool = False):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net1[0].weight.requires_grad = not no_grad
        self.net1[0].bias.requires_grad = not no_grad
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")


class TestFSDPOptimState(FSDPTest):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._model_class = {
            _ModelClass.NESTED: self._init_nested_model,
            _ModelClass.TRANSFORMER: self._init_transformer_model,
        }

    def _init_nested_model(
        self,
        wrap: bool,
        wrap_alt: bool = False,  # ignored if `wrap=False`
        device: torch.device = torch.device("cuda"),
        group=None,
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        use_multiple_param_groups: bool = False,
        use_diff_optim_inputs: bool = False,
        fsdp_kwargs: Optional[Dict[str, Any]] = None,
    ):
        model = NestedModel().to(device)
        if wrap:
            model = (
                NestedModel.wrap_alt(model, group, fsdp_kwargs)
                if wrap_alt
                else NestedModel.wrap(model, group, fsdp_kwargs=fsdp_kwargs)
            )
        if not use_multiple_param_groups:
            optim_input = list(model.parameters())
        else:
            optim_input = [
                {"params": model.param_group0()},
                {"params": model.param_group1(), "weight_decay": 0.9},
            ]
        # Use a reversed parameter order for the optimizer input on odd ranks
        if use_diff_optim_inputs and self.rank % 2 == 1:
            if isinstance(optim_input[0], dict):
                for param_group in optim_input:
                    param_group["params"] = list(reversed(param_group["params"]))
            else:
                optim_input = list(reversed(optim_input))
        optim = optim_class(optim_input, lr=0.01)
        return model, optim, optim_input

    def _init_transformer_model(
        self,
        wrap: bool,
        device: torch.device = torch.device("cuda"),
        group=None,
        optim_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        use_multiple_param_groups: bool = False,
        use_diff_optim_inputs: bool = False,
    ):
        if use_multiple_param_groups or use_diff_optim_inputs:
            # Keep these as arguments for parity with `_init_nested_model()`;
            # these settings are not implemented since the transformer is
            # wrapped with FSDP at the top-level, which means that there is
            # only a single flat parameter, making these booleans vacuous
            raise NotImplementedError()
        if group is None:
            group = dist.distributed_c10d._get_default_group()
        model = TransformerWithSharedParams.init(
            group,
            FSDPInitMode.RECURSIVE if wrap else FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        optim = optim_class(model.parameters(), lr=0.01)
        return model, optim, None

    def _step_model(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        device: torch.device = torch.device("cuda"),
        num_iters: int = 1,
    ) -> List[float]:
        """Performs a forward pass, backward pass, and optimizer step
        ``num_iters``-many times, and returns the per-iteration losses."""
        torch.manual_seed(0)  # set seed for determinism
        losses = []
        module = getattr(model, "module", model)
        for _ in range(num_iters):
            optim.zero_grad()
            inp = module.get_input(device)
            output = model(*inp)
            loss = module.get_loss(inp, output).to(device)
            losses.append(loss.item())
            module.run_backward(loss)
            optim.step()
        return losses

    def _broadcast_full_osd(self, full_osd: Dict[str, Any], group=None):
        """Broadcasts the full optimizer state dict in place of using
        ``torch.save()`` and ``torch.load()`` so that all ranks can have it."""
        obj_list = [full_osd]
        dist.broadcast_object_list(
            obj_list,
            src=0,
            group=group,
        )
        full_osd = obj_list[0]
        return full_osd

    def _are_equal_states(
        self,
        state1: Dict[str, Any],
        state2: Dict[str, Any],
    ) -> bool:
        """Checks if ``state1`` and ``state2`` contain the same mappings."""
        if set(state1.keys()) != set(state2.keys()):
            return False
        for state_name, value1 in state1.items():
            value2 = state2[state_name]
            if type(value1) != type(value2):
                return False
            if torch.is_tensor(value1):  # tensor state
                assert torch.is_tensor(value2)
                # Check the values on CPU to be device-agnostic
                value1 = value1.cpu()
                value2 = value2.cpu()
                if value1.shape != value2.shape or not torch.all(
                    torch.isclose(value1, value2)
                ):
                    return False
            else:  # non-tensor state
                if value1 != value2:
                    return False
        return True

    def _check_same_state(
        self,
        fsdp_osd,
        ref_osd,
        check_same_param_keys: bool,
    ):
        """Checks that ``fsdp_osd`` and ``ref_osd`` have the same "state" part.
        If ``check_same_param_keys=True``, then checks that the parameter keys
        match (e.g. when both should be parameter names), and does not check
        the parameter keys otherwise."""
        assert "state" in ref_osd
        self.assertTrue("state" in fsdp_osd)
        ref_osd_state = ref_osd["state"]
        fsdp_osd_state = {
            k: _gather_state_dict(v) for k, v in fsdp_osd["state"].items()
        }

        if check_same_param_keys:
            # Check parameter keys are the same first for earlier erroring
            ref_osd_param_ids = set(ref_osd_state.keys())
            fsdp_osd_param_ids = set(fsdp_osd_state.keys())
            self.assertTrue(
                ref_osd_param_ids == fsdp_osd_param_ids,
                f"Rank {self.rank}: {(ref_osd_param_ids, fsdp_osd_param_ids)}",
            )
            # Check state values are the same
            for param_id, param_state in fsdp_osd_state.items():
                for state_name, value in param_state.items():
                    ref_value = ref_osd_state[param_id][state_name]
                    self.assertEqual(value, ref_value)
            return
        # Otherwise, only require the parameter keys to be isomorphic (e.g.
        # between IDs and names)
        ref_osd_states = list(ref_osd_state.values())
        fsdp_osd_states = list(fsdp_osd_state.values())
        self.assertEqual(len(ref_osd_states), len(fsdp_osd_states))
        # Use brute-force quadratic-time comparison since it is hard to
        # hash a tensor by value instead of by object
        for fsdp_osd_state in fsdp_osd_states:
            # Check for at least one match (may be > 1 in toy edge cases, e.g.
            # multiple biases); nonetheless, each having >= 1 match and the two
            # lists having equal length imply that the list contents are equal
            self.assertTrue(
                any(
                    self._are_equal_states(fsdp_osd_state, ref_osd_state)
                    for ref_osd_state in ref_osd_states
                )
            )

    def _check_same_param_groups(
        self,
        full_osd,
        ref_osd,
        check_same_param_keys: bool,
    ):
        """Checks that ``full_osd`` and ``ref_osd`` have the same
        "param_groups" part. If ``check_same_param_keys=True``, then checks
        that the parameter keys match (e.g. when both should be parameter
        names), and does not check the parameter keys otherwise."""
        assert "param_groups" in ref_osd
        self.assertTrue("param_groups" in full_osd)
        ref_osd_param_groups = ref_osd["param_groups"]
        full_osd_param_groups = full_osd["param_groups"]
        self.assertTrue(len(full_osd_param_groups), len(ref_osd_param_groups))
        for full_osd_pg, ref_osd_pg in zip(
            full_osd_param_groups,
            ref_osd_param_groups,
        ):
            self.assertEqual(
                set(full_osd_pg.keys()),
                set(ref_osd_pg.keys()),
            )
            for name, full_osd_value in full_osd_pg.items():
                if name == "params" and not check_same_param_keys:
                    continue
                self.assertEqual(full_osd_value, ref_osd_pg[name])

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("rank0_only", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` and :meth:`sharded_optim_state_dict`
        by comparing the returned dict for an FSDP-wrapped model with that of
        an equivalent non-wrapped model.

        The test checks the equivalence excluding the parameter keys since the
        FSDP and normal optimizer state dicts key by names and IDs,
        respectively. This means that the test can pass even if parameter keys
        are incorrectly mapped to values. Their correct mapping is tested in
        other tests that exercise the save/load workflow.
        """
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_optim_state_dict_nested,
            state_dict_type=state_dict_type,
            use_multiple_param_groups=use_multiple_param_groups,
            rank0_only=rank0_only,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )

    def _test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
        use_optim_input: bool,
    ) -> None:
        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
            return  # not supported
        NUM_ITERS = 3
        model1, optim1, optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            if use_optim_input:
                fsdp_osd = FSDP.full_optim_state_dict(
                    model1,
                    optim1,
                    optim_input,
                    rank0_only=rank0_only,
                )
            else:
                fsdp_osd = FSDP.full_optim_state_dict(
                    model1,
                    optim1,
                    rank0_only=rank0_only,
                )
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(fsdp_osd), 0)
            return
        model2, optim2, _ = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
        # Do not check the parameter keys since the full/sharded optimizer state
        # dict uses parameter names, while the non-wrapped equivalent uses
        # parameter IDs
        check_same_param_keys = False
        self._check_same_param_groups(
            fsdp_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            fsdp_osd,
            ref_osd,
            check_same_param_keys=check_same_param_keys,
        )

    @skip_if_lt_x_gpu(2)
    def test_full_optim_state_dict_keys(self):
        """Tests that the parameter keys returned by
        :meth:`full_optim_state_dict` match those of :meth:`state_dict` with
        full ``state_dict_type`` for a non-FSDP-root model with nested FSDP
        instances and ignored modules."""
        device = torch.device("cuda")
        model = NestedModel().to(device)
        wrapped_model = NestedModel.wrap(model, ignore_modules=True)
        # Add checkpointing to ensure optim_state_dict and state_dict strip out
        # checkpointing prefixes.
        apply_activation_checkpointing(
            model, check_fn=lambda module: isinstance(module, torch.nn.Sequential)
        )
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._step_model(model, optim, device)
        optim_state_dict = FSDP.full_optim_state_dict(
            wrapped_model, optim, rank0_only=False
        )
        with FSDP.state_dict_type(wrapped_model, StateDictType.FULL_STATE_DICT):
            state_dict = wrapped_model.state_dict()
        self.assertEqual(optim_state_dict["state"].keys(), state_dict.keys())
        # Check that checkpointing prefix was indeed stripped.
        for key in optim_state_dict["state"]:
            self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, key)

    @skip_if_lt_x_gpu(2)
    def test_full_optim_state_dict_nested_invalid(self):
        """Tests that :meth:`full_optim_state_dict` raises an error when
        nonzero ranks are missing the optimizer state for parameters on rank
        0."""
        device = torch.device("cuda")
        model = NestedModel.wrap(NestedModel().to(device), None)
        optim_input = list(model.parameters())
        if self.rank != 0:
            # Exclude a parameter so that nonzero ranks are missing state
            optim_input = optim_input[:-1]
        optim = torch.optim.Adam(optim_input, lr=1e-3)
        self._step_model(model, optim, num_iters=3)
        error_regex = (
            "FSDP currently requires each rank to have at least the "
            "optimizer states needed by rank 0's optimizer but some ranks "
            "are missing some of those states"
        )
        with self.assertRaisesRegex(RuntimeError, error_regex):
            FSDP.full_optim_state_dict(model, optim)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("wrap_alt", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_shard_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        wrap_alt: bool,
        use_diff_optim_inputs: bool,
    ):
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
        with nested FSDP instances."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=False,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            halve_world_size=False,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_shard_full_optim_state_dict_nested_halve_world_size(self):
        """Tests :meth:`shard_full_optim_state_dict` for a non-FSDP-root model
        with nested FSDP instances when loading into a new process group with
        halved world size."""
        # To save CI costs, we test with the "harder" settings:
        use_multiple_param_groups = True
        use_diff_optim_inputs = True
        wrap_alt = True
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=True,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=True,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_shard_full_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root
        transformer model with shared parameters."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.TRANSFORMER,
            use_multiple_param_groups=False,
            halve_world_size=True,
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
            use_diff_optim_inputs=False,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            halve_world_size=True,
            use_diff_optim_inputs=False,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("use_multiple_param_groups", [False, True])
    @parametrize("wrap_alt", [False, True])
    @parametrize("use_diff_optim_inputs", [False, True])
    def test_scatter_full_optim_state_dict_nested(
        self,
        use_multiple_param_groups: bool,
        wrap_alt: bool,
        use_diff_optim_inputs: bool,
    ):
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
        model with nested FSDP instances."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=False,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(rank0_only=True),
            ),
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=False,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_scatter_full_optim_state_dict_nested_halve_world_size(self):
        """Tests :meth:`scatter_full_optim_state_dict` for a non-FSDP-root
        model with nested FSDP instances when loading into a new process group
        with halved world size."""
        # To save CI costs, we test with the "harder" settings:
        use_multiple_param_groups = True
        use_diff_optim_inputs = True
        wrap_alt = True
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.NESTED,
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=True,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(rank0_only=True),
            ),
            use_multiple_param_groups=use_multiple_param_groups,
            halve_world_size=True,
            use_diff_optim_inputs=use_diff_optim_inputs,
            wrap_alt=wrap_alt,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_scatter_full_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root
        transformer model with shared parameters."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_load_optim_state,
            model_class=_ModelClass.TRANSFORMER,
            use_multiple_param_groups=False,
            halve_world_size=True,
            osd_comm_method=_OSDCommMethod.SCATTER_FULL_OSD,
            use_diff_optim_inputs=False,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(rank0_only=True),
            ),
            use_multiple_param_groups=False,
            halve_world_size=True,
            use_diff_optim_inputs=False,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_flatten_sharded_optim_state_dict_nested(self) -> None:
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
        nested model."""
        self._test_load_optim_state(
            _ModelClass.NESTED,
            use_multiple_param_groups=False,
            halve_world_size=False,
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
            use_diff_optim_inputs=False,
            use_optim_input=False,
            wrap_alt=True,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.SHARDED_STATE_DICT,
                ShardedStateDictConfig(),
                ShardedOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            halve_world_size=False,
            use_diff_optim_inputs=False,
            wrap_alt=True,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
        """Tests :meth:`flatten_sharded_optim_state_dict` for an FSDP-root
        transformer model."""
        self._test_load_optim_state(
            _ModelClass.TRANSFORMER,
            use_multiple_param_groups=False,
            halve_world_size=False,
            osd_comm_method=_OSDCommMethod.FLATTEN_SHARDED_OSD,
            use_diff_optim_inputs=False,
            use_optim_input=False,
            num_iters=3,
        )

        self._test_load_optim_state_with_optim_state_dict(
            _ModelClass.TRANSFORMER,
            state_dict_settings=StateDictSettings(
                StateDictType.SHARDED_STATE_DICT,
                ShardedStateDictConfig(),
                ShardedOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            halve_world_size=False,
            use_diff_optim_inputs=False,
            num_iters=3,
        )

    @skip_if_lt_x_gpu(2)
    def test_use_orig_params(self) -> None:
        """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
        self.run_subtests(
            {
                "halve_world_size": [True, False],
                "wrap_alt": [True, False],
            },
            self._test_load_optim_state_with_optim_state_dict,
            model_class=_ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            use_diff_optim_inputs=False,
            num_iters=3,
            fsdp_kwargs={"use_orig_params": True},
        )

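        # Repeat with `rank0_only=True` so that only rank 0 materializes the
        # full optimizer state dict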
        self.run_subtests(
            {
                "halve_world_size": [True, False],
                "wrap_alt": [True, False],
            },
            self._test_load_optim_state_with_optim_state_dict,
            model_class=_ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(),
                FullOptimStateDictConfig(rank0_only=True),
            ),
            use_multiple_param_groups=False,
            use_diff_optim_inputs=False,
            num_iters=3,
            fsdp_kwargs={"use_orig_params": True},
        )

        self.run_subtests(
            {
                "wrap_alt": [True, False],
            },
            self._test_load_optim_state_with_optim_state_dict,
            model_class=_ModelClass.NESTED,
            state_dict_settings=StateDictSettings(
                StateDictType.SHARDED_STATE_DICT,
                ShardedStateDictConfig(),
                ShardedOptimStateDictConfig(),
            ),
            use_multiple_param_groups=False,
            # We cannot test halve_world_size with SHARDED_STATE_DICT.
            halve_world_size=False,
            use_diff_optim_inputs=False,
            num_iters=3,
            fsdp_kwargs={"use_orig_params": True},
        )

    def _test_load_optim_state(
        self,
        model_class: _ModelClass,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        osd_comm_method: _OSDCommMethod,
        use_diff_optim_inputs: bool,
        use_optim_input: bool,
        num_iters: int,
        **new_model_kwargs,
    ):
        """
        (1) Runs a model with full world size for K iterations to generate a
        full/sharded optimizer state dict;
        (2) initializes a model with halved world size and possibly different
        FSDP wrapping scheme (based on ``new_model_kwargs``);
        (3) loads the full/sharded optimizer state dict from (1) according to the
        halved-world-size model;
        (4) runs the halved-world-size model for K iterations; and
        (5) checks that the sharded optimizer state dict from (3) matches the
        halved-world-size model's local optimizer state dict, meaning that the
        former could have equivalently been loaded into the local optimizer.
        """
        initializer = self._model_class[model_class]
        if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
            osd_method = FSDP.optim_state_dict
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
            osd_method = FSDP.sharded_optim_state_dict
        else:
            osd_method = FSDP.full_optim_state_dict

        # First, run a wrapped model with full world size for a few iterations
        model1, optim1, optim_input1 = initializer(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=num_iters)
        fsdp_osd1 = (
            osd_method(model1, optim1, optim_input1)
            if use_optim_input
            else osd_method(model1, optim1)
        )
        if halve_world_size:
            # Create a new process group with halved world size
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            new_group = dist.new_group(ranks=new_group_ranks)
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True,
            group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        self._step_model(model2, optim2, num_iters=num_iters)
        fsdp_osd2 = (
            osd_method(model2, optim2, optim_input2, group=new_group)
            if use_optim_input
            else osd_method(model2, optim2, group=new_group)
        )
        # Compute two sharded optim state dicts: (1) for the first model
        # according to the second model and (2) for the second model according
        # to the second model
        if osd_comm_method == _OSDCommMethod.BROADCAST_OBJECT_LIST:
            fsdp_osd1 = self._broadcast_full_osd(fsdp_osd1, group=new_group)
            sharded_osd1 = (
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd1, model2, optim_input=optim_input2
                )
                if use_optim_input
                else FSDP.shard_full_optim_state_dict(fsdp_osd1, model2, optim=optim2)
            )
            fsdp_osd2 = self._broadcast_full_osd(fsdp_osd2, group=new_group)
            sharded_osd2 = (
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd2, model2, optim_input=optim_input2
                )
                if use_optim_input
                else FSDP.shard_full_optim_state_dict(fsdp_osd2, model2, optim=optim2)
            )
        elif osd_comm_method == _OSDCommMethod.SCATTER_FULL_OSD:
            sharded_osd1 = (
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd1 if self.rank == 0 else None,
                    model2,
                    optim_input=optim_input2,
                    group=new_group,
                )
                if use_optim_input
                else FSDP.scatter_full_optim_state_dict(
                    fsdp_osd1 if self.rank == 0 else None,
                    model2,
                    optim=optim2,
                    group=new_group,
                )
            )
            sharded_osd2 = (
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd2 if self.rank == 0 else None,
                    model2,
                    optim_input=optim_input2,
                    group=new_group,
                )
                if use_optim_input
                else FSDP.scatter_full_optim_state_dict(
                    fsdp_osd2 if self.rank == 0 else None,
                    model2,
                    optim=optim2,
                    group=new_group,
                )
            )
        elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
            sharded_osd1 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd1,
                model2,
                optim=optim2,
            )
            sharded_osd2 = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd2,
                model2,
                optim=optim2,
            )
        elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
            sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1)
            sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2)

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        check_same_param_keys = True  # should all have matching parameter IDs
        self._check_same_param_groups(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        self._check_same_param_groups(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim2.load_state_dict(sharded_osd2)
        self._step_model(model2, optim2, num_iters=num_iters)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("add_to_fsdp_module", [False, True])
    def test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
    ):
        """
        Tests :meth:`shard_full_optim_state_dict` when there are unmanaged
        parameters.
        - If ``add_to_fsdp_module=True``, then the unmanaged parameters are
        added to a module to be wrapped with FSDP, in which case there should
        be an error since we require that all unflattened parameters
        comprising a flat parameter have the same scalar state (e.g. Adam
        "step") but the added parameter is missing its entry.
        - If ``add_to_fsdp_module=False``, then the unmanaged parameters are
        added to a module not to be wrapped with FSDP, in which case there
        should be no error (emulating model parallel use cases where some
        parameters may be managed externally to FSDP).
        We do not separately test unmanaged parameters for
        :meth:`scatter_full_optim_state_dict` and
        :meth:`flatten_sharded_optim_state_dict` to save CI cost since they
        call into the same subroutine :meth:`_flatten_optim_state_dict`.
        """
        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
            use_optim_input = [False]
        else:
            use_optim_input = [False, True]
        self.run_subtests(
            {"use_optim_input": use_optim_input},
            self._test_shard_full_optim_state_dict_unmanaged_params,
            state_dict_type=state_dict_type,
            add_to_fsdp_module=add_to_fsdp_module,
        )

    def _test_shard_full_optim_state_dict_unmanaged_params(
        self,
        state_dict_type: StateDictType,
        add_to_fsdp_module: bool,
        use_optim_input: bool,
    ):
        NUM_ITERS = 1
        # Create a normal wrapped model
        model, optim, optim_input = self._init_nested_model(wrap=True)
        self._step_model(model, optim, num_iters=NUM_ITERS)

        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = (
                FSDP.full_optim_state_dict(model, optim, optim_input, rank0_only=False)
                if use_optim_input
                else FSDP.full_optim_state_dict(model, optim, rank0_only=False)
            )  # save on all ranks to avoid having to broadcast from rank 0
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model, optim)
        # Create a new model with the same structure but additional unmanaged
        # parameters, representing the model for which we want to load
        device = torch.device("cuda")
        model = NestedModel().to(device)
        model, unmanaged_params = NestedModel.wrap_with_unmanaged_params(
            model,
            add_to_fsdp_module,
        )
        optim_input = list(model.parameters())
        optim = torch.optim.Adam(optim_input, lr=1e-3)
        if add_to_fsdp_module:
            # If we add the unmanaged parameters to a module wrapped with FSDP,
            # then the flat parameter will be comprised of some unflattened
            # parameters with zero-dimensional tensor state (i.e. Adam "step")
            # and others without (i.e. the unmanaged parameters), which
            # triggers an error that we have to ensure correctness
            error_prefix = (
                "^(All unflattened parameters comprising a "
                "single flat parameter must have scalar state with the "
                "same value and dtype)"
            )
            with self.assertRaisesRegex(ValueError, error_prefix):
                if state_dict_type == StateDictType.FULL_STATE_DICT:
                    (
                        FSDP.shard_full_optim_state_dict(
                            fsdp_osd, model, optim_input=optim_input
                        )
                        if use_optim_input
                        else FSDP.shard_full_optim_state_dict(
                            fsdp_osd, model, optim=optim
                        )
                    )
                else:
                    FSDP.flatten_sharded_optim_state_dict(fsdp_osd, model, optim=optim)
        else:
            # If we add the unmanaged parameters to a module not wrapped with
            # FSDP, then we simply ignore them without erroring to enable
            # model parallelism use cases, where some parameters are managed
            # externally to FSDP
            if state_dict_type == StateDictType.FULL_STATE_DICT:
                flattened_osd = (
                    FSDP.shard_full_optim_state_dict(
                        fsdp_osd, model, optim_input=optim_input
                    )
                    if use_optim_input
                    else FSDP.shard_full_optim_state_dict(fsdp_osd, model, optim=optim)
                )
            else:
                flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd, model, optim=optim
                )
            # Add entries for the unmanaged parameters to be able to load
            for unmanaged_param in unmanaged_params:
                NestedModel.add_unmanaged_param_entry(
                    flattened_osd,
                    unmanaged_param,
                    NUM_ITERS,
                )
            # Check that we can load the optimizer state dict
            optim.load_state_dict(flattened_osd)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    @parametrize("use_multiple_param_groups", [False, True])
    def test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
    ):
        """Tests :meth:`rekey_optim_state_dict` with the new keys being
        parameter IDs by checking that a wrapped model (i.e. with FSDP modules)
        can rekey its optimizer state dict to match that of an equivalent
        non-wrapped model (i.e. without FSDP modules)."""
        if state_dict_type == StateDictType.SHARDED_STATE_DICT:
            use_optim_input = [False]
        else:
            use_optim_input = [False, True]
        self.run_subtests(
            {"use_optim_input": use_optim_input},
            self._test_rekey_optim_state_dict_to_ids,
            state_dict_type=state_dict_type,
            use_multiple_param_groups=use_multiple_param_groups,
        )

    @skip_if_lt_x_gpu(2)
    def _test_rekey_optim_state_dict_to_ids(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = (
                FSDP.full_optim_state_dict(model1, optim1, optim_input1)
                if use_optim_input
                else FSDP.full_optim_state_dict(model1, optim1)
            )
            # Broadcast instead of `torch.save()`/`torch.load()` so that all ranks
            # have the full state dict
            fsdp_osd = self._broadcast_full_osd(fsdp_osd)
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the wrapped model's optimizer state dict using parameter IDs
        # according to the non-wrapped model
        rekeyed_osd = (
            FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim_input=optim_input2,
            )
            if use_optim_input
            else FSDP.rekey_optim_state_dict(
                fsdp_osd,
                OptimStateKeyType.PARAM_ID,
                model2,
                optim=optim2,
            )
        )
        # Check that the re-keyed dict and actual dict are the same
        osd = optim2.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            rekeyed_osd,
            osd,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        if state_dict_type != StateDictType.SHARDED_STATE_DICT:
            optim2.load_state_dict(rekeyed_osd)
            self._step_model(model2, optim2, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_rekey_optim_state_dict_to_names(self):
        """Tests :meth:`rekey_optim_state_dict` with the new keys being
        parameter names by checking that a non-wrapped model (i.e. without FSDP
        modules) can rekey its optimizer state dict to match the expected
        output of :meth:`full_optim_state_dict`, hence be sharded using
        :meth:`shard_full_optim_state_dict`, and finally match the per-rank
        optimizer state dict of a wrapped model (i.e. with FSDP modules)."""
        self.run_subtests(
            {"use_optim_input": [False, True]},
            self._test_rekey_optim_state_dict_to_names,
            use_multiple_param_groups=False,
        )

    def _test_rekey_optim_state_dict_to_names(
        self,
        use_multiple_param_groups: bool,
        use_optim_input: bool,
    ):
        NUM_ITERS = 3
        # Run a wrapped model for a few iterations
        model1, optim1, optim_input1 = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model1, optim1, num_iters=NUM_ITERS)
        # Run a non-wrapped model for a few iterations
        model2, optim2, optim_input2 = self._init_nested_model(
            wrap=False,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model2, optim2, num_iters=NUM_ITERS)
        # Re-key the non-wrapped model's optimizer state dict using parameter
        # names (still according to itself)
        osd2 = optim2.state_dict()
        rekeyed_osd = (
            FSDP.rekey_optim_state_dict(
                osd2,
                OptimStateKeyType.PARAM_NAME,
                model2,
                optim_input=optim_input2,
            )
            if use_optim_input
            else FSDP.rekey_optim_state_dict(
                osd2,
                OptimStateKeyType.PARAM_NAME,
                model2,
                optim=optim2,
            )
        )
        # Shard the non-wrapped model's re-keyed optimizer state dict, which
        # maps back to (flattened) parameter IDs
        sharded_osd = (
            FSDP.shard_full_optim_state_dict(
                rekeyed_osd,
                model1,
                optim_input=optim_input1,
            )
            if use_optim_input
            else FSDP.shard_full_optim_state_dict(
                rekeyed_osd,
                model1,
                optim=optim1,
            )
        )
        # Check that this sharded optimizer state dict matches the wrapped
        # model's per-rank optimizer state dict
        osd1 = optim1.state_dict()
        check_same_param_keys = True
        self._check_same_param_groups(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            sharded_osd,
            osd1,
            check_same_param_keys=check_same_param_keys,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim1.load_state_dict(sharded_osd)
        self._step_model(model1, optim1, num_iters=NUM_ITERS)

    @skip_if_lt_x_gpu(2)
    def test_optim_input_warning(self):
        """Tests that passing the ``optim_input`` argument into optimizer state
        checkpointing APIs issues a warning."""

        def should_check_method(method_name: str):
            # Check every method that accepts `optim_input` (the sharded APIs
            # do not)
            return method_name not in (
                "sharded_optim_state_dict",
                "flatten_sharded_optim_state_dict",
            )

        def get_warning_context():
            warning_regex = "`optim_input` argument is deprecated"
            return self.assertWarnsRegex(
                expected_warning=UserWarning, expected_regex=warning_regex
            )

        self._run_on_all_optim_state_apis(
            should_check_method, get_warning_context, fsdp_kwargs=None
        )

    def _run_on_all_optim_state_apis(
        self,
        should_check_method_fn: Callable[[str], bool],
        context_fn: Callable,
        fsdp_kwargs: Optional[Dict[str, Any]],
    ):
        """
        Runs through all optimizer state checkpointing APIs with a context
        manager instantiated by ``context_fn``. Certain APIs can be skipped
        via ``should_check_method_fn``, which gets passed the string name of
        the method.
        """
        wrapped_model, wrapped_optim, wrapped_optim_input = self._init_nested_model(
            wrap=True,
            use_multiple_param_groups=False,
            fsdp_kwargs=fsdp_kwargs,
        )
        self._step_model(wrapped_model, wrapped_optim, num_iters=2)

        # Sharded optim state dict
        if should_check_method_fn("sharded_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.sharded_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                )
        if "fsdp_osd" not in locals():
            fsdp_osd = {}  # may not be defined due to previous method erroring
        if should_check_method_fn("flatten_sharded_optim_state_dict"):
            with context_fn():
                FSDP.flatten_sharded_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    wrapped_optim,
                )
        # Full optim state dict
        if should_check_method_fn("full_optim_state_dict"):
            with context_fn():
                fsdp_osd = FSDP.full_optim_state_dict(
                    wrapped_model,
                    wrapped_optim,
                    optim_input=wrapped_optim_input,
                    rank0_only=False,
                )
        if should_check_method_fn("shard_full_optim_state_dict"):
            with context_fn():
                FSDP.shard_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        if should_check_method_fn("scatter_full_optim_state_dict"):
            with context_fn():
                FSDP.scatter_full_optim_state_dict(
                    fsdp_osd,
                    wrapped_model,
                    optim_input=wrapped_optim_input,
                )
        # Rekey optim state dict
        (
            nonwrapped_model,
            nonwrapped_optim,
            nonwrapped_optim_input,
        ) = self._init_nested_model(wrap=False, use_multiple_param_groups=False)
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                rekeyed_osd = FSDP.rekey_optim_state_dict(
                    fsdp_osd,  # from `full_optim_state_dict()`
                    OptimStateKeyType.PARAM_ID,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )
        self._step_model(nonwrapped_model, nonwrapped_optim, num_iters=2)
        osd = nonwrapped_optim.state_dict()
        if should_check_method_fn("rekey_optim_state_dict"):
            with context_fn():
                FSDP.rekey_optim_state_dict(
                    osd,
                    OptimStateKeyType.PARAM_NAME,
                    nonwrapped_model,
                    optim_input=nonwrapped_optim_input,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", STATE_DICT_TYPES)
    def test_save_load_without_0th_param_state(self, state_dict_type: StateDictType):
        """
        Tests saving and loading an optim state dict for Adam optimizer (i.e.
        any optimizer with a "step" key in its state) when the first parameter
        does not have optimizer state (e.g. unused or frozen).
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5)
                self.lin2 = nn.Linear(5, 5)
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Do not use `lin1`, which is the parameter passed to the
                # optimizer and the one checked for "step" state to see if it
                # is tensor or float
                return self.relu(self.lin2(x))

        model = Model().cuda()
        model.lin1 = FSDP(model.lin1)
        model.lin2 = FSDP(model.lin2)
        fsdp_model = FSDP(model)
        optim = torch.optim.Adam(
            fsdp_model.parameters(), lr=1e-2
        )  # or any optimizer with "step"

        # Run an iteration to construct optimizer state
        device = torch.device("cuda")
        inp = torch.randn((2, 5), device=device)
        loss = fsdp_model(inp).sum()
        loss.backward()
        optim.step()

        # Check that save and load does not error
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(fsdp_model, optim, rank0_only=False)
            flattened_osd = FSDP.shard_full_optim_state_dict(fsdp_osd, fsdp_model)
        elif state_dict_type == StateDictType.SHARDED_STATE_DICT:
            fsdp_osd = FSDP.sharded_optim_state_dict(fsdp_model, optim)
            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
                fsdp_osd, fsdp_model, optim
            )
        optim.load_state_dict(flattened_osd)
        # `__setstate__()` will check the 0th parameter to see if "step" is
        # represented as a tensor or float, so it is imperative that its state
        # is non-empty.

        # Run an iteration as a sanity check
        inp = torch.randn((2, 5), device=device)
        loss = fsdp_model(inp).sum()
        loss.backward()
        optim.step()

    @skip_if_lt_x_gpu(2)
    def test_compatible_with_trec(self):
        class DenseModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
                self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
                self.net3 = nn.Linear(32, 64)
                self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

            def forward(self, x):
                return self.net4(self.net3(self.net2(self.net1(x))))

        class FakeMPModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(0)
                self.dense = FSDP(DenseModel().cuda(), use_orig_params=True)
                if dist.get_rank() == 0:
                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
                else:
                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

            def forward(self, x):
                if dist.get_rank() == 0:
                    sparse = self.sparse0(x)
                else:
                    sparse = self.sparse1(x)
                dist.all_reduce(sparse)
                return self.dense(sparse)

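        # `FakeMPModel` emulates a model-parallel setup (e.g. TorchRec-style):
        # the dense part is FSDP-wrapped, while each rank owns its own "sparse"
        # module that lives outside of FSDP.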
        models = [FakeMPModel().cuda(), FakeMPModel().cuda()]
        optims = [
            torch.optim.Adam(models[0].parameters(), lr=1e-2),
            _NamedOptimizer(
                models[1].named_parameters(),
                torch.optim.Adam,
                [{"params": models[1].parameters()}],
                models[1],
                lr=1e-2,
            ),
        ]
        state_dicts = []

        # Train on one batch and check that the optim state dicts are the same.
        batch = torch.rand(5, 8, device=torch.device("cuda"))
        for model, optim in zip(models, optims):
            # Eagerly initialize the states
            for param in model.parameters():
                if param.requires_grad:
                    t = torch.zeros_like(param)
                    param.grad = torch.autograd.Variable(t)
            optim.step()
            loss = model(batch).sum()
            loss.backward()
            optim.step()
            state_dicts.append(deepcopy(FSDP.optim_state_dict(model, optim)))

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

        # Put optims[1] into a different state.
        for i in range(5):
            batch = torch.rand(5, 8).cuda()
            loss = models[1](batch).sum()
            loss.backward()
            optims[1].step()

        # Load the state back to see if load_optim_state_dict works.
        state_dict_to_load = FSDP.optim_state_dict_to_load(
            models[1], optims[1], state_dicts[1], is_named_optimizer=True
        )
        optims[1].load_state_dict(state_dict_to_load)
        state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1])

        self._check_same_param_groups(
            state_dicts[0], state_dicts[1], check_same_param_keys=False
        )
        self._check_same_state(
            state_dicts[0], state_dicts[1], check_same_param_keys=True
        )

    @skip_if_lt_x_gpu(2)
    def test_optim_state_without_param_groups(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                torch.manual_seed(0)
                self.net1 = nn.Sequential(nn.Linear(2, 4), nn.ReLU())

            def forward(self, x):
                return self.net1(x)

        model = FSDP(SimpleModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        # Train one step to save the original optimizer state dict and the
        # original optimizer param groups.
        batch = torch.rand(3, 2, device=torch.device("cuda"))
        for param in model.parameters():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        optim.step()
        loss = model(batch).sum()
        loss.backward()

        original_osd = deepcopy(optim.state_dict())
        original_osd_no_param_groups = deepcopy(original_osd)
        # Manually remove `param_groups` from the optimizer state dict
        original_param_groups = deepcopy(
            original_osd_no_param_groups.pop("param_groups")
        )
        # Pass the osd without `param_groups` to FSDP
        original_fsdp_optim_state_dict = deepcopy(
            FSDP.optim_state_dict(
                model, optim, optim_state_dict=original_osd_no_param_groups
            )
        )
        # Check that the state dict sharded by FSDP does not contain `param_groups`.
        self.assertEqual(None, original_fsdp_optim_state_dict.get("param_groups"))

        # Train another step to put the optimizer in a different state.
        for param in model.parameters():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        optim.step()
        loss = model(batch).sum()
        loss.backward()

        state_dict_to_load = FSDP.optim_state_dict_to_load(
            model, optim, original_fsdp_optim_state_dict
        )
        # Manually add `param_groups` to `state_dict_to_load` before loading
        # the optimizer state
        state_dict_to_load["param_groups"] = original_param_groups
        optim.load_state_dict(state_dict_to_load)
        self.assertEqual(original_osd, optim.state_dict())

        fsdp_optim_state = FSDP.optim_state_dict(model, optim)
        self._check_same_state(
            original_fsdp_optim_state_dict, fsdp_optim_state, check_same_param_keys=True
        )
        self.assertEqual(original_param_groups, optim.state_dict()["param_groups"])

    @skip_if_lt_x_gpu(2)
    def test_with_empty_optimizer_state(self):
        model = FSDP(TestDummyModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        state_dict = optim.state_dict()
        gathered_state_dict = FSDP.optim_state_dict(model, optim)
        self.assertEqual(gathered_state_dict["state"], state_dict["state"])

    def _test_load_optim_state_with_optim_state_dict(
        self,
        model_class: _ModelClass,
        state_dict_settings: StateDictSettings,
        use_multiple_param_groups: bool,
        halve_world_size: bool,
        use_diff_optim_inputs: bool,
        num_iters: int,
        **new_model_kwargs,
    ):
        """
        (1) Runs a model with full world size for K iterations to generate a
        full/sharded optimizer state dict;
        (2) initializes a model with halved world size and possibly different
        FSDP wrapping scheme (based on ``new_model_kwargs``);
        (3) loads the full/sharded optimizer state dict from (1) according to the
        halved-world-size model;
        (4) runs the halved-world-size model for K iterations; and
        (5) checks that the sharded optimizer state dict from (3) matches the
        halved-world-size model's local optimizer state dict, meaning that the
        former could have equivalently been loaded into the local optimizer.
        """
        initializer = self._model_class[model_class]

        # First, run a wrapped model with full world size for a few iterations
        model1, optim1, optim_input1 = initializer(
            wrap=True,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        FSDP.set_state_dict_type(
            model1,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model1, optim1, num_iters=num_iters)
        fsdp_osd1 = FSDP.optim_state_dict(model1, optim1)
        if halve_world_size:
            # Create a new process group with halved world size
            new_group_ranks = [r for r in range(self.world_size) if r % 2 == 0]
            new_group = dist.new_group(ranks=new_group_ranks)
            if self.rank not in new_group_ranks:
                return
        else:
            # Continue using the same group and hence world size
            new_group = dist.distributed_c10d._get_default_group()
        # Second, run a wrapped model with (possibly) halved world size and
        # (possibly) differing `optim_input` across ranks
        model2, optim2, optim_input2 = initializer(
            wrap=True,
            group=new_group,
            use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
            **new_model_kwargs,  # specify `wrap_alt` to change wrapping
        )
        FSDP.set_state_dict_type(
            model2,
            state_dict_settings.state_dict_type,
            state_dict_settings.state_dict_config,
            state_dict_settings.optim_state_dict_config,
        )
        self._step_model(model2, optim2, num_iters=num_iters)
        fsdp_osd2 = FSDP.optim_state_dict(model2, optim2, group=new_group)
        # Compute two sharded optim state dicts: (1) for the first model
        # according to the second model and (2) for the second model according
        # to the second model
        sharded_osd2 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd2, group=new_group
        )

        # As a sanity check, check that sharding the second model's full/sharded
        # optimizer state dict according to itself is equivalent to its local
        # optimizer's state dict
        local_osd2 = optim2.state_dict()
        self._check_same_param_groups(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd2,
            local_osd2,
            check_same_param_keys=True,
        )
        # Check that sharding the first model's full/sharded optimizer state dict
        # according to the second model is equivalent to the second model's
        # local optimizer state dict
        sharded_osd1 = FSDP.optim_state_dict_to_load(
            model2, optim2, fsdp_osd1, group=new_group
        )
        self._check_same_param_groups(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        self._check_same_state(
            sharded_osd1,
            local_osd2,
            check_same_param_keys=True,
        )
        # As a sanity check, check that we can load and run a few iterations
        optim2.load_state_dict(sharded_osd2)
        self._step_model(model2, optim2, num_iters=num_iters)

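    # NOTE (editorial sketch, not part of the original test): the
    # world-size-halving above is the standard recipe for simulating "save on
    # N ranks, resume on N/2 ranks" within a single job. Roughly, assuming
    # `world_size` and `rank` come from the launcher:
    #
    #     ranks = [r for r in range(world_size) if r % 2 == 0]
    #     group = dist.new_group(ranks=ranks)  # even ranks only
    #     if rank not in ranks:
    #         return  # odd ranks sit out
    #
    # The state dict produced under the full group is then resharded for the
    # smaller group by `FSDP.optim_state_dict_to_load(..., group=group)`.
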
    @skip_if_lt_x_gpu(2)
    def test_interface_arguments(self):
        model = FSDP(TestDummyModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def step():
            loss = model(model.get_input())
            loss.backward(loss)
            optim.step()

        step()
        original_osd = deepcopy(optim.state_dict())
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        self._check_same_state(
            FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
        )
        step()
        osd_to_load = FSDP.optim_state_dict_to_load(
            model, optim, osd, load_directly=True
        )
        self._check_same_state(
            optim.state_dict(), original_osd, check_same_param_keys=True
        )

        # Test the default setting.
        osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
        for state in osd["state"].values():
            for s in state.values():
                self.assertFalse(isinstance(s, ShardedTensor))
                self.assertFalse(s.is_cuda)

        # Test sharded state_dict without offload_to_cpu
        with FSDP.state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            ShardedStateDictConfig(),
            ShardedOptimStateDictConfig(offload_to_cpu=False),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            for state in osd["state"].values():
                for s in state.values():
                    if s.dim() == 0:
                        continue
                    self.assertTrue(isinstance(s, ShardedTensor))
                    if s._local_shards[0]:
                        self.assertTrue(s._local_shards[0].tensor.is_cuda)

        # Test full state_dict with rank0_only
        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(),
            FullOptimStateDictConfig(
                offload_to_cpu=True,
                rank0_only=True,
            ),
        ):
            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            if dist.get_rank() > 0:
                self.assertEqual(osd, {})
            else:
                for state in osd["state"].values():
                    for s in state.values():
                        if s.dim() == 0:
                            continue
                        self.assertFalse(s.is_cuda)
                        self.assertFalse(isinstance(s, ShardedTensor))

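    # NOTE (editorial sketch, not part of the original test): the three cases
    # above mirror how a training script would pick a flavor of optimizer
    # state dict via the `FSDP.state_dict_type` context manager, e.g. a
    # rank0-only, CPU-offloaded full state dict for checkpointing:
    #
    #     with FSDP.state_dict_type(
    #         model,
    #         StateDictType.FULL_STATE_DICT,
    #         FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    #         FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    #     ):
    #         osd = FSDP.optim_state_dict(model, optim)
    #     if dist.get_rank() == 0:
    #         torch.save(osd, "optim.pt")  # hypothetical path
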
    @skip_if_lt_x_gpu(2)
    def test_state_dict_with_none_tensor_state(self):
        def _run_test(use_orig_params, optimizer_has_tensor_state):
            model = FSDP(TestDummyModel().cuda(), use_orig_params=use_orig_params)
            optimizer_cls = (
                torch.optim.Adam if optimizer_has_tensor_state else torch.optim.SGD
            )
            optim = optimizer_cls(model.parameters(), lr=1e-2)

            def step():
                loss = model(model.get_input())
                loss.backward(loss)
                optim.step()

            step()
            original_osd = deepcopy(optim.state_dict())
            for state in original_osd["state"].values():
                # Add customized value
                state["value1"] = 2.74
                state["value2"] = None

            osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
            for state in osd_to_load["state"].values():
                self.assertEqual(state["value1"], 2.74)
                self.assertEqual(state["value2"], None)

        self.run_subtests(
            {
                "use_orig_params": [False, True],
                "optimizer_has_tensor_state": [False, True],
            },
            _run_test,
        )

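    # NOTE (editorial sketch, not part of the original test): the subtests above
    # verify that non-tensor entries stashed into optimizer state (a float and
    # a None here) pass through `optim_state_dict` / `optim_state_dict_to_load`
    # unchanged, for an optimizer with tensor state (Adam) and one without
    # (SGD), with and without `use_orig_params`.
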
    @skip_if_lt_x_gpu(2)
    def test_with_no_shard(self):
        def _run_test(use_orig_params: bool) -> None:
            model = FSDP(
                TestDummyModel().cuda(),
                sharding_strategy=ShardingStrategy.NO_SHARD,
                use_orig_params=use_orig_params,
            )
            optim = torch.optim.Adam(model.parameters(), lr=1e-2)

            def step():
                loss = model(model.get_input())
                loss.backward(loss)
                optim.step()

            step()

            original_osd = deepcopy(optim.state_dict())

            osd = FSDP.optim_state_dict(model, optim)
            osd_to_load = FSDP.optim_state_dict_to_load(model, optim, osd)
            optim.load_state_dict(osd_to_load)

            new_osd = optim.state_dict()

            self.assertEqual(original_osd, new_osd)

        self.run_subtests({"use_orig_params": [False, True]}, _run_test)

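    # NOTE (editorial sketch, not part of the original test): with
    # `ShardingStrategy.NO_SHARD`, optimizer state is not sharded across ranks,
    # so the round trip above is expected to reproduce the optimizer state dict
    # exactly (`assertEqual` on the whole dicts), a stricter check than the
    # key-wise comparisons used for the sharded strategies.
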
    @skip_if_lt_x_gpu(2)
    def test_no_grad(self):
        model = TestDummyModel(no_grad=True).cuda()
        fsdp_model = FSDP(deepcopy(model), use_orig_params=True)
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)

        for i in range(5):
            if i % 2 == 1:
                fsdp_model.net1[0].weight.requires_grad = True
                fsdp_model.net1[0].bias.requires_grad = True
            else:
                fsdp_model.net1[0].weight.requires_grad = False
                fsdp_model.net1[0].bias.requires_grad = False
            batch = fsdp_model.get_input()
            loss = fsdp_model(batch).sum()
            loss.backward()
            fsdp_optim.step()
            orig_state_dict = deepcopy(fsdp_optim.state_dict())
            optim_state_dict = FSDP.optim_state_dict(fsdp_model, fsdp_optim)
            FSDP.optim_state_dict_to_load(
                fsdp_model,
                fsdp_optim,
                FSDP.optim_state_dict(fsdp_model, fsdp_optim),
                load_directly=True,
            )

            self._check_same_state(
                fsdp_optim.state_dict(),
                orig_state_dict,
                check_same_param_keys=True,
            )

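    # NOTE (editorial sketch, not part of the original test): the loop above
    # flips `requires_grad` on and off between iterations so that only some
    # parameters accumulate optimizer state; `load_directly=True` makes
    # `optim_state_dict_to_load` call `optim.load_state_dict` itself, which the
    # per-iteration comparison against the saved state dict relies on.
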

instantiate_parametrized_tests(TestFSDPOptimState)

if __name__ == "__main__":
    run_tests()