# Owner(s): ["oncall: distributed"]

import torch

import torch.distributed.checkpoint as DCP
import torch.nn as nn
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class FsdpOptimStateCheckpoint(DTensorTestBase):
    def _create_model(self):
        # make weight tensor dim_0 as large as the world size for scaling test
        layer1_weight_dim = self.world_size
        layer2_weight_dim = self.world_size * 2
        layer3_weight_dim = self.world_size * 3

        class TestDummyModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net1 = nn.Sequential(nn.Linear(8, layer1_weight_dim), nn.ReLU())
                self.net2 = nn.Sequential(
                    nn.Linear(layer1_weight_dim, layer2_weight_dim), nn.ReLU()
                )
                self.net3 = nn.Sequential(
                    nn.Linear(layer2_weight_dim, layer3_weight_dim), nn.ReLU()
                )

            def forward(self, x):
                return self.net3(self.net2(self.net1(x)))

            def get_input(self):
                return torch.rand(8, 8, device="cuda")

        model = TestDummyModel().cuda()
        return model
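
    # hybrid backend: gloo handles CPU tensors, NCCL handles CUDA tensors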
    @property
    def backend(self):
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    @parametrize("pass_planner", [True, False])
    def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
        CHECKPOINT_DIR = self.temp_dir
        planner = DCP.DefaultLoadPlanner() if pass_planner else None
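
        # wrap the model with FSDP so its parameters (and the optimizer state
        # created for them) are sharded across ranks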
        model = self._create_model()
        model = FSDP(model)
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        # step ahead to initialize the optimizer
        model(model.get_input()).sum().backward()
        optim.step()
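
        # SHARDED_STATE_DICT makes state_dict/optim_state_dict return
        # ShardedTensors holding only the local shard, instead of gathering
        # full tensors on rank 0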
        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        optim_osd = FSDP.optim_state_dict(model, optim)
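
        # each rank saves only its own shards; FileSystemWriter also writes a
        # .metadata file describing the global checkpoint layout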
        state_dict = {
            "model": model.state_dict(),
            "optim": optim_osd,
        }
        DCP.save_state_dict(
            state_dict=state_dict,
            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
        )

        # now load the model and ensure the values are the same
        model_2 = self._create_model()
        model_2 = FSDP(model_2)
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model_2,
            StateDictType.SHARDED_STATE_DICT,
        )
        # Adam lazily creates its state
        self.assertEqual(0, len(optim_2.state))
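
        # DCP loads in place into the tensors already allocated in state_dict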
        state_dict = {
            "model": model_2.state_dict(),
            # cannot load the optimizer together with the model
        }
        DCP.load_state_dict(
            state_dict=state_dict,
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
        )
        model_2.load_state_dict(state_dict["model"])
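
        # load the optimizer state separately: load_sharded_optimizer_state_dict
        # uses the model state dict to reshard the saved optimizer state to
        # match the current model's layout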
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
            planner=planner,
        )
        flattened_osd = FSDP.optim_state_dict_to_load(
            model_2, optim_2, optim_state["optim"]
        )
        optim_2.load_state_dict(flattened_osd)
        osd_after_load = FSDP.optim_state_dict(model_2, optim_2)

        # Compare optim_state_dict prior to save and after load
        before_optim_state = optim_osd["state"]
        after_optim_state = osd_after_load["state"]
        self.assertEqual(len(before_optim_state), len(after_optim_state))
        for fqn, states in before_optim_state.items():
            for state_name, state in states.items():
                state2 = after_optim_state.get(fqn).get(state_name)
                if isinstance(state, ShardedTensor):
                    self.assertTrue(isinstance(state2, ShardedTensor))
                    self.assertTrue(torch.allclose(state, state2))
                else:
                    self.assertEqual(state, state2)


instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
if __name__ == "__main__":
    run_tests()