# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.optim as optim
from torch.distributed._shard import shard_parameter, sharded_tensor
from torch.distributed._shard.sharded_optim import ShardedOptimizer
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


class MyShardedModel(torch.nn.Module):
    def __init__(self, spec=None, group=None):
        super().__init__()
        # Use same seed.
        torch.manual_seed(0)
        self.param = torch.nn.Parameter(torch.rand(5, 10))
        if spec is not None:
            self.sharded_param = torch.nn.Parameter(
                sharded_tensor.rand(
                    spec, 20, 10, requires_grad=True, process_group=group
                )
            )
        else:
            self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))

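    # With the 4-rank rowwise spec used in the tests below, the (20, 10)
    # sharded tensor is chunked into four (5, 10) local shards, so each
    # rank's shard has the same shape as the dense (5, 10) fallback above.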
    def forward(self, input):
        if isinstance(self.sharded_param, sharded_tensor.ShardedTensor):
            return self.param + self.sharded_param.local_shards()[0].tensor + input
        else:
            return self.sharded_param + self.param + input


class MyShardedLinear(torch.nn.Module):
    def __init__(self, rank=None):
        super().__init__()
        # Use same seed.
        torch.manual_seed(0)
        self.linear1 = torch.nn.Linear(17, 12)
        self.linear2 = torch.nn.Linear(12, 29)
        self.gelu = torch.nn.GELU()

        # `if rank:` would wrongly skip rank 0, so test for None explicitly.
        if rank is not None:
            self.linear1.cuda(rank)
            self.linear2.cuda(rank)

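    # Shards linear1's weight rowwise (dim=0) and linear2's weight colwise
    # (dim=1) across four GPUs; the biases stay dense Parameters.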
    def shard_parameter(self):
        rowwise_sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        colwise_sharding_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

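        # shard_parameter() replaces the named parameter on the module in
        # place with a ShardedTensor laid out according to the given spec.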
        shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
        shard_parameter(self.linear2, "weight", colwise_sharding_spec)

    def forward(self, inp):
        return self.linear2(self.gelu(self.linear1(inp)))


class TestShardedOptimizer(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_optim(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        local_model = MyShardedModel().cuda()
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda()

        # Copy the parameter from the local model so both models start from
        # identical values.
        sharded_model.sharded_param.local_shards()[0].tensor = (
            local_model.sharded_param.detach().clone().requires_grad_()
        )

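        # ShardedOptimizer takes a dict of named parameters (which may include
        # ShardedTensors) and an optimizer class; on each rank it steps the
        # dense params and the local shards with that optimizer.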
        local_optim = optim.SGD(local_model.parameters(), lr=0.1)
        sharded_model_params = dict(sharded_model.named_parameters())
        sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

        local_optim.zero_grad()
        sharded_optim.zero_grad()

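        # Snapshot the parameters before the update so the assertions below
        # can verify that step() actually changed them.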
        before_update = deepcopy(sharded_optim.named_params)

        inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

        # run forward
        local_output = local_model(inp)
        sharded_output = sharded_model(inp)
        # backward
        local_output.sum().backward()
        sharded_output.sum().backward()

        # optimizer update
        local_optim.step()
        sharded_optim.step()

        # Make sure the parameters (including the sharded param) get updated
        # by the optimizer, and that the updated local params match the
        # sharded params.
        for key, val in before_update.items():
            new_val = sharded_optim.named_params[key]
            if isinstance(val, sharded_tensor.ShardedTensor):
                self.assertNotEqual(
                    val.local_shards()[0].tensor, new_val.local_shards()[0].tensor
                )
                self.assertEqual(
                    new_val.local_shards()[0].tensor, local_model.sharded_param
                )
            else:
                self.assertNotEqual(val, new_val)
                self.assertEqual(new_val, local_model.param)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_named_params_with_sharded_tensor(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda()
        sharded_model_params = dict(sharded_model.named_parameters())
        param_keys = list(sharded_model_params.keys())
        self.assertEqual(len(param_keys), 2)
        self.assertTrue("param" in param_keys)
        self.assertTrue("sharded_param" in param_keys)

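        # After shard_parameter(), the sharded weights should still appear in
        # named_parameters() under their original module-qualified names.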
        sharded_linear = MyShardedLinear(rank=self.rank).cuda()
        sharded_linear.shard_parameter()
        sharded_linear_params = dict(sharded_linear.named_parameters())
        param_keys = list(sharded_linear_params.keys())
        self.assertEqual(len(param_keys), 4)
        self.assertTrue("linear1.bias" in param_keys)
        self.assertTrue("linear2.bias" in param_keys)
        self.assertTrue("linear1.weight" in param_keys)
        self.assertTrue("linear2.weight" in param_keys)
        self.assertFalse("bias" in param_keys)


if __name__ == "__main__":
    run_tests()