# pytorch
#
# Fork
# 0
# /
# test_fake_pg.py
# 213 lines · 7.3 KB
1
# Owner(s): ["oncall: distributed"]
2

3
import sys
4
import unittest
5

6
import torch
7
import torch.distributed as dist
8
import torch.distributed._functional_collectives as funcol
9
import torch.nn as nn
10
from torch.distributed._tensor import DeviceMesh, init_device_mesh, Shard
11
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
12
from torch.distributed.tensor.parallel import (
13
    ColwiseParallel,
14
    parallelize_module,
15
    RowwiseParallel,
16
)
17
from torch.fx.experimental.proxy_tensor import make_fx
18
from torch.testing import FileCheck
19
from torch.testing._internal.common_utils import run_tests, TestCase
20
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
21
from torch.testing._internal.distributed.fake_pg import FakeStore
22

23

24
if not dist.is_available():
    # Exit with status 0 (success) so that a PyTorch build compiled without
    # torch.distributed does not count this module as a failure.
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Gates the tests below that place modules or tensors on a CUDA device.
HAS_CUDA = torch.cuda.is_available()
29

30

31
class TestFakePG(TestCase):
    """Single-process smoke tests for the ``"fake"`` process-group backend.

    Each test initialises the default process group with ``backend="fake"``
    while claiming to be one rank of a larger world. It then runs
    collectives, point-to-point ops, or distributed wrappers (FSDP, tensor
    parallelism, ``make_fx`` tracing) against that group. Assertions check
    tensor shapes, or for the tracing test the ops in the traced graph.
    Communicated values are never checked.
    """

    def tearDown(self):
        super().tearDown()
        # Guarded so that a test which failed before (or inside)
        # ``init_process_group`` reports its own error rather than a second
        # one raised while destroying a default group that was never created.
        if dist.is_initialized():
            dist.destroy_process_group()

    def _init_fake_pg(self, rank, world_size=2, store=None):
        """Initialise the default process group on the fake backend.

        Args:
            rank: the rank this process pretends to be.
            world_size: the pretended number of ranks (default 2).
            store: rendezvous store passed to ``init_process_group``; a new
                ``FakeStore`` is created when omitted.
        """
        if store is None:
            store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=rank, world_size=world_size, store=store
        )

    def test_all_reduce(self):
        self._init_fake_pg(rank=1)

        output = torch.ones(3, 3) * dist.get_rank()
        dist.all_reduce(output)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_allgather(self):
        self._init_fake_pg(rank=1)

        input_tensor = torch.ones(3, 3) * dist.get_rank()
        # all_gather expects one output slot per rank in the group.
        output_tensors = [
            torch.empty_like(input_tensor) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(output_tensors, input_tensor)
        for out_tensor in output_tensors:
            self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_reduce_scatter(self):
        self._init_fake_pg(rank=1)

        # One input chunk per rank; this rank receives a single reduced chunk.
        to_reduce_scatter = [
            torch.ones(3, 3) * rank for rank in range(dist.get_world_size())
        ]
        output_tensor = torch.empty(3, 3)

        dist.reduce_scatter(output_tensor, to_reduce_scatter)
        self.assertEqual(tuple(output_tensor.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_construct_fsdp(self):
        self._init_fake_pg(rank=0)
        FSDP(nn.Linear(2, 3, device="cuda"))

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fsdp_fake_e2e(self):
        self._init_fake_pg(rank=0, store=dist.HashStore())
        my_module = nn.Sequential(
            nn.Linear(2, 3, device="cuda"),
            nn.ReLU(),
            nn.Linear(3, 2, device="cuda"),
        )
        sharded_module = FSDP(my_module, use_orig_params=True)
        optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
        # NOTE(review): the input is created on CPU while the module lives on
        # CUDA. This presumably relies on FSDP moving forward inputs to its
        # compute device; confirm before changing.
        inp = torch.randn(2, 2)
        loss = sharded_module(inp).sum()
        loss.backward()
        optim.step()

    @unittest.skipIf(not HAS_CUDA, "No CUDA")
    def test_fake_pg_tracing(self):
        self._init_fake_pg(rank=0, store=dist.HashStore())

        default_pg = dist.distributed_c10d._get_default_group()

        def allgather_fn(tensor):
            return funcol.all_gather_tensor(tensor, 0, default_pg)

        gm = make_fx(allgather_fn)(torch.randn(2, 2, device="cuda"))
        # The traced graph must contain the collective, followed by its wait.
        FileCheck().check("all_gather").check("wait_tensor").run(str(gm.graph))

    def test_broadcast(self):
        self._init_fake_pg(rank=0)

        # src == rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank
        output = torch.ones(3, 3)
        dist.broadcast(output, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_scatter(self):
        self._init_fake_pg(rank=0)

        # src == rank: the source supplies one chunk per rank.
        output = torch.ones(3, 3)
        to_scatter = [torch.ones(3, 3) * rank for rank in range(dist.get_world_size())]
        dist.scatter(output, to_scatter, src=0)
        self.assertEqual(tuple(output.shape), (3, 3))

        # src != rank: non-source ranks pass no scatter list.
        output = torch.ones(3, 3)
        dist.scatter(output, None, src=1)
        self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall(self):
        self._init_fake_pg(rank=0)

        world_size = dist.get_world_size()
        output_list = [torch.ones(3, 3) for _ in range(world_size)]
        input_list = [torch.ones(3, 3) for _ in range(world_size)]
        dist.all_to_all(output_list, input_list)
        self.assertEqual(len(output_list), world_size)
        for output in output_list:
            self.assertEqual(tuple(output.shape), (3, 3))

    def test_alltoall_base(self):
        self._init_fake_pg(rank=0)

        out_tensor = torch.ones(3, 3)
        in_tensor = torch.ones(3, 3)
        # One split per rank; each split list must sum to dim 0 of its tensor
        # (3 here) for the call to be valid on a real backend as well.
        output_split = [1, 2]
        input_split = [1, 2]
        dist.all_to_all_single(out_tensor, in_tensor, output_split, input_split)
        self.assertEqual(tuple(out_tensor.shape), (3, 3))

    def test_send(self):
        self._init_fake_pg(rank=0)

        tensor = torch.ones(3, 3)
        dist.send(tensor, 1)
        self.assertEqual(tuple(tensor.shape), (3, 3))

    def test_recv(self):
        self._init_fake_pg(rank=0)

        output = torch.ones(3, 3)
        dist.recv(output, 1)
        self.assertEqual(tuple(output.shape), (3, 3))

    @unittest.skipIf(not HAS_CUDA, "No CUDA or TP+FSDP")
    def test_fsdp_tp_fake_e2e(self):
        world_size = 4
        tp_size = 2

        self._init_fake_pg(rank=0, world_size=world_size, store=dist.HashStore())

        # 2-D mesh of shape (world_size // tp_size, tp_size). FSDP shards over
        # the outer "dp" dimension; tensor parallelism uses the inner "tp" one.
        device_mesh = init_device_mesh(
            "cuda", (world_size // tp_size, tp_size), mesh_dim_names=("dp", "tp")
        )
        dp_mesh = device_mesh["dp"]
        tp_mesh = device_mesh["tp"]

        sequence_parallelize_plan = {
            "net1": ColwiseParallel(input_layouts=Shard(0)),
            "net2": RowwiseParallel(output_layouts=Shard(0)),
        }
        pairwise_parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        # Seed by the rank within the data-parallel dimension (not the global
        # rank), so each dp replica draws its own batch.
        dp_rank = dp_mesh.get_local_rank()

        for parallel_plan in (sequence_parallelize_plan, pairwise_parallelize_plan):
            my_module = parallelize_module(
                MLPModule(device="cuda"),
                tp_mesh,
                parallel_plan,
            )

            sharded_module = FSDP(my_module, use_orig_params=True, device_mesh=dp_mesh)
            optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

            for step in range(10):
                torch.manual_seed(step + dp_rank)
                # Same device spec as the module above.
                inp = torch.randn(20, 10, device="cuda")
                loss = sharded_module(inp).sum()
                loss.backward()
                optim.step()
210

211

212
if __name__ == "__main__":
    # Hand control to PyTorch's shared test-runner entry point.
    run_tests()
214

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.