pytorch
76 строк · 2.1 Кб
1# Owner(s): ["oncall: distributed"]
2
3import sys
4
5import torch
6from torch import distributed as dist
7from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
8from torch.nn import Linear, Module
9from torch.optim import SGD
10from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
11from torch.testing._internal.common_fsdp import FSDPTest
12from torch.testing._internal.common_utils import (
13instantiate_parametrized_tests,
14parametrize,
15run_tests,
16subtest,
17TEST_WITH_DEV_DBG_ASAN,
18)
19
# Bail out early, with exit status 0, on configurations where this suite
# cannot run. At most one message is printed because sys.exit raises.
if not dist.is_available():
    # torch.distributed is unavailable in this build, so FSDP cannot be used.
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
elif TEST_WITH_DEV_DBG_ASAN:
    # Per the message below, dev-ASAN runs have known issues with
    # torch + multiprocessing spawn, so the whole module is skipped.
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
30
31
32class TestInput(FSDPTest):
33@property
34def world_size(self):
35return 1
36
37@skip_if_lt_x_gpu(1)
38@parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
39def test_input_type(self, input_cls):
40"""Test FSDP with input being a list or a dict, only single GPU."""
41
42class Model(Module):
43def __init__(self):
44super().__init__()
45self.layer = Linear(4, 4)
46
47def forward(self, input):
48if isinstance(input, list):
49input = input[0]
50else:
51assert isinstance(input, dict), input
52input = input["in"]
53return self.layer(input)
54
55model = FSDP(Model()).cuda()
56optim = SGD(model.parameters(), lr=0.1)
57
58for _ in range(5):
59in_data = torch.rand(64, 4).cuda()
60in_data.requires_grad = True
61if input_cls is list:
62in_data = [in_data]
63else:
64self.assertTrue(input_cls is dict)
65in_data = {"in": in_data}
66
67out = model(in_data)
68out.sum().backward()
69optim.step()
70optim.zero_grad()
71
72
# Expand the @parametrize-decorated methods of TestInput into concrete tests.
instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's common test runner.
    run_tests()
77