# Owner(s): ["oncall: distributed"]

import io
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)

# Simple and boring model to test the interface and some corner cases that do
# not require a complicated wrapping strategy.
class DenseModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(4, 8, device="cuda")

# TODO: Consolidate DeviceMesh based FSDP and HSDP test cases.
class TestHSDPWithDeviceMeshAndDTensor(DTensorTestBase):
    def _create_model(self, device_mesh=None):
        if device_mesh:
            model = FSDP(
                DenseModel().cuda(),
                device_mesh=device_mesh,
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )
        else:
            mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
            intra_node_pg = mesh_2d.get_group(mesh_dim=1)
            inter_node_pg = mesh_2d.get_group(mesh_dim=0)
            model = FSDP(
                DenseModel().cuda(),
                process_group=(intra_node_pg, inter_node_pg),
                sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            )

        optim = torch.optim.Adam(model.parameters(), lr=0.1)
        model(model.get_input()).sum().backward()
        optim.step()

        return model, optim
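
    # With HYBRID_SHARD on a 2D mesh, parameters are replicated across the
    # inter-node dimension (mesh dim 0) and sharded across the intra-node
    # dimension (mesh dim 1), so every state dict value should come back as a
    # DTensor with (Replicate(), Shard(0)) placements on the same mesh.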
    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_hsdp_init_with_device_mesh(self):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = model.state_dict()
        optim_state_dict = FSDP.optim_state_dict(model, optim)

        for v in state_dict.values():
            self.assertEqual(type(v), DTensor)
            self.assertEqual(len(v.placements), 2)
            self.assertEqual(v.placements, (Replicate(), Shard(0)))
            self.assertEqual(v.device_mesh, mesh_2d)

        for state in optim_state_dict["state"].values():
            for k, v in state.items():
                if k != "step":
                    self.assertEqual(type(v), DTensor)
                    self.assertEqual(len(v.placements), 2)
                    self.assertEqual(v.placements, (Replicate(), Shard(0)))
                    self.assertEqual(v.device_mesh, mesh_2d)

        state_dict_type = model.get_state_dict_type(model)
        # If device_mesh is used when initializing FSDP, the field _use_dtensor
        # will automatically be set to True.
        self.assertEqual(state_dict_type.state_dict_config._use_dtensor, True)
        self.assertEqual(state_dict_type.optim_state_dict_config._use_dtensor, True)
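
    # The DTensor-based state dict (device_mesh init path) and the
    # ShardedTensor-based state dict (process_group init path) should carry
    # identical local data; only the wrapper tensor type differs.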
    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_tensor_state_dict_identical(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )
        dtensor_sd = model.state_dict()
        dtensor_osd = FSDP.optim_state_dict(model, optim)

        ref_model, ref_optim = self._create_model()
        FSDP.set_state_dict_type(
            ref_model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )
        sharded_tensor_sd = ref_model.state_dict()
        sharded_tensor_osd = FSDP.optim_state_dict(ref_model, ref_optim)

        # Check that dtensor and sharded_tensor model state dict values are identical.
        for dtensor_sd_item, sharded_tensor_sd_item in zip(
            dtensor_sd.items(), sharded_tensor_sd.items()
        ):
            k1, v1 = dtensor_sd_item
            k2, v2 = sharded_tensor_sd_item
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), ShardedTensor)
            # Check that the local tensors are the same.
            self.assertEqual(v1.to_local(), v2.local_tensor())
            # Check that the devices are the same.
            self.assertEqual(v1.to_local().device, v2.local_tensor().device)

        # Check that dtensor and sharded_tensor optim state dict values are identical.
        for dtensor_osd_state, sharded_tensor_osd_state in zip(
            dtensor_osd["state"].items(), sharded_tensor_osd["state"].items()
        ):
            # Check that the FQNs are the same.
            self.assertEqual(dtensor_osd_state[0], sharded_tensor_osd_state[0])
            for dtensor_hyper_param, sharded_tensor_hyper_param in zip(
                dtensor_osd_state[1].items(),
                sharded_tensor_osd_state[1].items(),
            ):
                k1, v1 = dtensor_hyper_param
                k2, v2 = sharded_tensor_hyper_param
                self.assertEqual(k1, k2)

                if k1 != "step":
                    self.assertEqual(type(v1), DTensor)
                    self.assertEqual(type(v2), ShardedTensor)
                    # Check that the local tensors are the same.
                    self.assertEqual(v1.to_local(), v2.local_tensor())
                    # Check that the devices are the same.
                    self.assertEqual(v1.to_local().device, v2.local_tensor().device)
                else:
                    self.assertEqual(v1, v2)
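
    # Saving the DTensor-based optimizer state dict and loading it back through
    # FSDP.optim_state_dict_to_load() should round-trip exactly, even after the
    # live optimizer state has moved on.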
    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            optim_state_dict_config=ShardedOptimStateDictConfig(
                offload_to_cpu=offload_to_cpu
            ),
        )

        checkpoint = io.BytesIO()
        torch.save(FSDP.optim_state_dict(model, optim), checkpoint)
        # Deepcopy to save the current optim_state_dict to compare with the
        # optim_state_dict loaded back below.
        ref_optim_state_dict = deepcopy(FSDP.optim_state_dict(model, optim))

        # Update the parameters so FSDP.optim_state_dict() will differ from
        # ref_optim_state_dict.
        model(model.get_input()).sum().backward()
        optim.step()

        # Load ref_optim_state_dict back.
        checkpoint.seek(0)
        load_ref_optim_state_dict = torch.load(checkpoint)
        optim.load_state_dict(
            FSDP.optim_state_dict_to_load(model, optim, load_ref_optim_state_dict)
        )
        new_optim_state_dict = FSDP.optim_state_dict(model, optim)

        # Check whether new_optim_state_dict is the same as ref_optim_state_dict.
        for new_optim_state_dict_item, ref_optim_state_dict_item in zip(
            new_optim_state_dict["state"].items(),
            ref_optim_state_dict["state"].items(),
        ):
            # Check that the FQNs are the same.
            self.assertEqual(new_optim_state_dict_item[0], ref_optim_state_dict_item[0])
            for new_optim_hyper_param, ref_optim_hyper_param in zip(
                new_optim_state_dict_item[1].items(),
                ref_optim_state_dict_item[1].items(),
            ):
                k1, v1 = new_optim_hyper_param
                k2, v2 = ref_optim_hyper_param
                # Check that the keys are the same.
                self.assertEqual(k1, k2)
                # Check that the DTensor values are the same.
                self.assertEqual(v1, v2)

                if k1 != "step":
                    self.assertEqual(type(v1), DTensor)
                    self.assertEqual(type(v2), DTensor)
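
    # Same round-trip check as above, but for the model state dict rather than
    # the optimizer state dict.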
    @with_comms
    @skip_if_lt_x_gpu(4)
    @parametrize("offload_to_cpu", [True, False])
    def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model, optim = self._create_model(mesh_2d)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=offload_to_cpu),
        )

        checkpoint = io.BytesIO()
        torch.save(model.state_dict(), checkpoint)
        # Deepcopy to save the current state_dict to compare with the state_dict
        # loaded back below.
        ref_state_dict = deepcopy(model.state_dict())

        # Update the parameters so model.state_dict() will differ from ref_state_dict.
        model(model.get_input()).sum().backward()
        optim.step()

        # Load ref_state_dict back.
        checkpoint.seek(0)
        load_ref_state_dict = torch.load(checkpoint)
        model.load_state_dict(load_ref_state_dict)
        new_state_dict = model.state_dict()

        # Check whether new_state_dict is the same as ref_state_dict.
        for (k1, v1), (k2, v2) in zip(ref_state_dict.items(), new_state_dict.items()):
            # Check that the FQNs are the same.
            self.assertEqual(k1, k2)

            self.assertEqual(type(v1), DTensor)
            self.assertEqual(type(v2), DTensor)
            # Check that the DTensor values are the same.
            self.assertEqual(v1, v2)
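
    # Even when the FSDP instance is nested inside a plain nn.Module root,
    # FSDP.optim_state_dict() should still return DTensor state for the
    # FSDP-managed ("dense") parameters and plain tensors for the rest.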
    @with_comms
    @skip_if_lt_x_gpu(4)
    def test_root_module_is_not_FSDP(self):
        class FakeMPModel(torch.nn.Module):
            def __init__(self, device_mesh):
                super().__init__()
                torch.manual_seed(0)
                self.dense = FSDP(
                    DenseModel().cuda(),
                    use_orig_params=True,
                    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
                    device_mesh=device_mesh,
                )
                if dist.get_rank() == 0:
                    self.sparse0 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
                else:
                    self.sparse1 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

            def forward(self, x):
                if dist.get_rank() == 0:
                    sparse = self.sparse0(x)
                else:
                    sparse = self.sparse1(x)
                dist.all_reduce(sparse)
                return self.dense(sparse)

        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FakeMPModel(device_mesh=mesh_2d).cuda()
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        batch = torch.rand(5, 8, device=torch.device("cuda"))
        model(batch).sum().backward()
        optim.step()
        osd = optim.state_dict()

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            osd = FSDP.optim_state_dict(model, optim, osd)

        for param, state in osd["state"].items():
            if "dense" in param:
                self.assertIsInstance(state["exp_avg"], DTensor)
                self.assertIsInstance(state["exp_avg_sq"], DTensor)
                self.assertEqual(state["exp_avg"].placements, (Replicate(), Shard(0)))
                self.assertEqual(
                    state["exp_avg_sq"].placements, (Replicate(), Shard(0))
                )
            else:
                self.assertIsInstance(state["exp_avg"], torch.Tensor)
                self.assertIsInstance(state["exp_avg_sq"], torch.Tensor)
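
# instantiate_parametrized_tests() expands each @parametrize decorator above
# into concrete test variants (offload_to_cpu=True / offload_to_cpu=False).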
instantiate_parametrized_tests(TestHSDPWithDeviceMeshAndDTensor)
if __name__ == "__main__":
    run_tests()