# Owner(s): ["oncall: distributed"]

import sys
import unittest

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

HAS_CUDA = torch.cuda.is_available()


class TestFakePG(TestCase):
    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def test_all_reduce(self):
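        # No real communication happens on the fake backend; the test only
        # checks that all_reduce runs and the output keeps its shape.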
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_allgather(self):
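        # all_gather fills a pre-allocated list of per-rank output tensors;
        # only their shapes are checked since the fake backend moves no data.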
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        input_tensor = torch.ones(3, 3) * dist.get_rank()
        output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
        dist.all_gather(output_tensors, input_tensor)
        for out_tensor in output_tensors:
            self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_reduce_scatter(self):
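        # reduce_scatter takes a list with one input tensor per rank and writes
        # a single reduced shard into output_tensor.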
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=1, world_size=2, store=store)

        to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        self.assertEqual(tuple(output_tensor.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_construct_fsdp(self):
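        # Wrapping a module in FSDP should succeed on a fake process group even
        # though no other ranks actually exist.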
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        FSDP(nn.Linear(2, 3, device="cuda"))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_fake_e2e(self):
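        # Runs a full forward/backward/optimizer step under FSDP with the fake
        # backend standing in for real collectives.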
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        my_module = nn.Sequential(
            nn.Linear(2, 3, device="cuda"),
            nn.ReLU(),
            nn.Linear(3, 2, device="cuda"),
        )
        sharded_module = FSDP(my_module, use_orig_params=True)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        input = torch.randn(2, 2)
        x = sharded_module(input)
        loss = x.sum()
        loss.backward()
        optim.step()

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fake_pg_tracing(self):
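        # make_fx should trace a functional all_gather issued against the
        # default (fake) process group and record all_gather/wait_tensor nodes
        # in the FX graph.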
        store = dist.HashStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        default_pg = dist.distributed_c10d._get_default_group()

        def allgather_fn(tensor):
            return funcol.all_gather_tensor(tensor, 0, default_pg)

        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
        FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

    def test_broadcast(self):
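        # Broadcast is exercised both as the source rank (src == rank) and as a
        # receiver (src != rank); only the output shape is checked.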
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_scatter(self):
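        # When this rank is the source it supplies the scatter list; otherwise
        # the list argument must be None.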
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        # src == rank
        output = torch.ones(3, 3)
        to_scatter = [torch.ones(3, 3) * rank for rank in range(2)]
        dist.scatter(output, to_scatter)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.scatter(output, None, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall(self):
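        # List-based all_to_all with one input and one output tensor per peer
        # rank.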
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output_list = [torch.ones(3, 3) for _ in range(2)]
        input_list = [torch.ones(3, 3) for _ in range(2)]
        dist.all_to_all(output_list, input_list)
        self.assertEqual(len(output_list), 2)
        for output in output_list:
            self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall_base(self):
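        # all_to_all_single on a single tensor with explicit per-rank
        # output/input split sizes.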
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        out_tensor = torch.ones(3, 3)
        in_tensor = torch.ones(3, 3)
        output_split = [1, 1]
        input_split = [1, 1]
        dist.all_to_all_single(out_tensor, in_tensor, output_split, input_split)
        self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_send(self):
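        # Point-to-point send to rank 1; the fake backend completes it without
        # a matching recv on the other side.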
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        tensor = torch.ones(3, 3)
        dist.send(tensor, 1)
        self.assertEqual(tuple(tensor.shape), (3, 3))

    def test_recv(self):
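        # Point-to-point recv from rank 1; the fake backend completes it
        # without an actual sender.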
        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)

        output = torch.ones(3, 3)
        dist.recv(output, 1)
        self.assertEqual(tuple(output.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_tp_fake_e2e(self):
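        # 2D parallelism smoke test: tensor parallelism over the "tp" mesh
        # dimension composed with FSDP over the "dp" dimension, driven entirely
        # by the fake backend on a single device.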
        world_size = 4
        tp_size = 2

        store = dist.HashStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=world_size, store=store
        )

        device_mesh = DeviceMesh("cuda", torch.arange(0, world_size).view(-1, tp_size))
        device_mesh = init_device_mesh(
            "cuda", (world_size // tp_size, tp_size), mesh_dim_names=["dp", "tp"]
        )

        sequence_parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0)),
            "net2": RowwiseParallel(output_layouts=Shard(0)),
        }
        pairwise_parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        for parallel_plan in [sequence_parallelize_plan, pairwise_parallelize_plan]:
            my_module = parallelize_module(
                MLPModule(device="cuda"),
                device_mesh["tp"],
                parallel_plan,
            )

            sharded_module = FSDP(
                my_module, use_orig_params=True, device_mesh=device_mesh["dp"]
            )
            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

            for i in range(10):
                dp_rank = dist.get_rank()
                torch.manual_seed(i + dp_rank)
                input = torch.randn(20, 10).cuda(dp_rank)
                x = sharded_module(input)
                loss = x.sum()
                loss.backward()
                optim.step()


if __name__ == "__main__":
    run_tests()