# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr

        return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test FSDP with an uneven division of parameter shards."""
        # An odd number of parameters (3x3 = 9, no bias) generally does not
        # divide evenly across ranks, exercising FSDP's uneven-shard
        # padding path.
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        # One forward/backward/step with the FSDP-wrapped model.
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

        # Gather the full (unsharded) parameters and compare against the
        # locally computed reference output and updated weight.
        with model.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()