pytorch

Форк
0
/
test_c10d_gloo.py 
2559 строк · 92.4 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import copy
4
import logging
5
import math
6
import operator
7
import os
8
import random
9
import sys
10
import tempfile
11
from datetime import timedelta
12
from functools import reduce
13
from itertools import groupby
14

15
import torch
16
import torch.distributed as c10d
17

18
if not (c10d.is_available() and c10d.is_gloo_available()):
    # Nothing below can run without the distributed package and its Gloo
    # backend, so exit with status 0 instead of failing at import time.
    print("c10d GLOO not available, skipping tests", file=sys.stderr)
    sys.exit(0)
21

22
import test_c10d_common
23
import torch.distributed as dist
24
import torch.nn.functional as F
25
import torch.testing._internal.common_utils as common
26
from test_c10d_common import (
27
    gpus_for_rank,
28
    LOOPBACK,
29
    ModuleForDdpCommHook,
30
    SparseGradientModule,
31
    Task,
32
)
33
from torch import nn
34
from torch.distributed._shard.sharded_tensor import (
35
    init_from_local_shards,
36
    Shard,
37
    ShardedTensor,
38
    ShardMetadata,
39
)
40
from torch.nn.parallel import DistributedDataParallel
41
from torch.testing._internal.common_distributed import (
42
    create_device,
43
    MultiProcessTestCase,
44
    requires_gloo,
45
    simple_sparse_reduce_tests,
46
    skip_if_lt_x_gpu,
47
    skip_if_win32,
48
    verify_ddp_error_logged,
49
)
50
from torch.testing._internal.common_utils import (
51
    retry_on_connect_failures,
52
    run_tests,
53
    skip_but_pass_in_sandcastle,
54
    TestCase,
55
)
56

57

58
def simple_reduce_tests(rank, world_size):
    """Build single-tensor reduction cases for one rank.

    Each entry is ``(op, input_tensor, expected_tensor)``: ``input_tensor`` is
    what ``rank`` contributes and ``expected_tensor`` is the value expected
    after reducing with ``op`` across ``world_size`` ranks. Float cases cover
    SUM/PRODUCT/MIN/MAX; int32 cases cover BAND/BOR/BXOR.
    """

    def int32_case(op, contributed, expected):
        # One-element int32 tensors for the bitwise reductions.
        return (
            op,
            torch.tensor([contributed], dtype=torch.int32),
            torch.tensor([expected], dtype=torch.int32),
        )

    # Rank r contributes r + 1, so across ranks the values are 1..world_size
    # and every arithmetic reduction has a closed form.
    arithmetic = [
        (c10d.ReduceOp.SUM, float(world_size * (world_size + 1) / 2)),
        (c10d.ReduceOp.PRODUCT, float(math.factorial(world_size))),
        (c10d.ReduceOp.MIN, 1.0),
        (c10d.ReduceOp.MAX, float(world_size)),
    ]
    cases = [
        (op, torch.tensor([rank + 1.0]), torch.tensor([expected]))
        for op, expected in arithmetic
    ]

    # BAND: every rank ORs one shared bit into its rank value. Rank 0
    # contributes only that bit, so the AND isolates it. The bit moves on each
    # iteration to check that the output follows it.
    cases.extend(
        int32_case(c10d.ReduceOp.BAND, rank | (1 << bit), 1 << bit)
        for bit in range(4)
    )

    # BOR / BXOR: emulate a larger world by having each rank pre-fold `width`
    # consecutive values, so the full reduction spans range(world_size * width).
    folds = ((c10d.ReduceOp.BOR, operator.or_), (c10d.ReduceOp.BXOR, operator.xor))
    for op, fold in folds:
        for width in range(1, 5):
            start = rank * width
            local = reduce(fold, range(start, start + width))
            overall = reduce(fold, range(world_size * width))
            cases.append(int32_case(op, local, overall))

    return cases
125

126

127
def simple_coalesced_reduce_tests(rank, world_size):
    """Build two-tensor cases for coalesced all-reduce.

    Each entry is ``(op, inputs, expected_outputs)``. Both lists hold two
    one-element float tensors. The i-th expected tensor is the i-th input
    reduced with ``op`` across ``world_size`` ranks.
    """
    n = world_size
    base = rank + 1.0
    triangular = float(n * (n + 1) / 2)
    sum_of_squares = float(n * (n + 1) * (2 * n + 1) / 6)
    return [
        (
            c10d.ReduceOp.SUM,
            [torch.tensor([base]), torch.tensor([base ** 2])],
            [torch.tensor([triangular]), torch.tensor([sum_of_squares])],
        ),
        (
            c10d.ReduceOp.PRODUCT,
            [torch.tensor([rank + 1.0]), torch.tensor([rank + 2.0])],
            [
                torch.tensor([float(math.factorial(n))]),
                torch.tensor([float(math.factorial(n + 1))]),
            ],
        ),
        (
            c10d.ReduceOp.MIN,
            [torch.tensor([rank + offset]) for offset in (0.0, 1.0)],
            [torch.tensor([0.0]), torch.tensor([1.0])],
        ),
        (
            c10d.ReduceOp.MAX,
            [torch.tensor([rank + offset]) for offset in (1.0, 2.0)],
            [torch.tensor([float(n)]), torch.tensor([n + 1.0])],
        ),
    ]
158

159

160
def simple_multi_input_reduce_tests(rank, world_size):
    """Build cases where each rank passes two tensors to one reduction.

    Each entry is ``(op, inputs, expected)``. ``inputs`` are the two
    one-element float tensors this rank contributes. ``expected`` is the
    single tensor every reduced output should equal. Across all ranks the
    inputs enumerate ``2 * world_size`` consecutive values.
    """

    def pair(first, second):
        # Fresh tensors per case, so in-place collectives cannot leak between cases.
        return [torch.tensor([first]), torch.tensor([second])]

    first = 2 * rank
    total_values = 2 * world_size
    return [
        (
            c10d.ReduceOp.SUM,
            pair(first + 0.0, first + 1.0),
            torch.tensor([float(world_size * (total_values - 1))]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            pair(first + 1.0, first + 2.0),
            torch.tensor([float(math.factorial(total_values))]),
        ),
        (
            c10d.ReduceOp.MIN,
            pair(first + 1.0, first + 2.0),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            pair(first + 1.0, first + 2.0),
            torch.tensor([2.0 * world_size]),
        ),
    ]
183

184

185
class RendezvousEnvTest(TestCase):
    """Checks for ``env://`` rendezvous with the gloo backend."""

    @requires_gloo()
    @retry_on_connect_failures
    def test_logging_init(self):
        """Initializing a gloo group via ``env://`` must not alter root logging.

        Snapshots the root logger's handlers, initializes a single-rank group
        from environment variables, and checks that the handlers are unchanged.
        """
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        # Take a copy: ``logging.root.handlers`` is the logger's live list,
        # and ``addHandler`` appends to it in place. Keeping a bare reference
        # would make the comparison below compare the list with itself, so an
        # added handler could never be detected.
        previous_handlers = list(logging.root.handlers)

        c10d.init_process_group(backend="gloo", init_method="env://")
        try:
            current_handlers = logging.root.handlers
            self.assertEqual(len(previous_handlers), len(current_handlers))
            for current, previous in zip(current_handlers, previous_handlers):
                self.assertEqual(current, previous)
        finally:
            # Tear the default group down even when an assertion fails, so a
            # later initialization in this process starts clean.
            c10d.destroy_process_group()
204

205

206
class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    """Runs the shared default-store-timeout check against the gloo backend."""

    @requires_gloo()
    @retry_on_connect_failures
    def test_default_store_timeout_gloo(self):
        backend_name = "gloo"
        self._test_default_store_timeout(backend_name)
211

212

213
class ProcessGroupGlooTest(MultiProcessTestCase):
214
    def _create_process_group_gloo(self, store, rank, world_size, opts):
        """Create a ``ProcessGroupGloo`` and run a barrier on it.

        Args:
            store: rendezvous store shared by all ranks.
            rank: this process's rank in the new group.
            world_size: number of ranks in the new group.
            opts: ``ProcessGroupGloo._Options`` (timeout, devices, threads).

        Returns:
            The constructed process group, after ``dist.barrier`` on it has
            returned.
        """
        # Use the arguments rather than silently reading self.rank /
        # self.world_size. Every existing caller passes exactly those values,
        # so behaviour for current callers is unchanged.
        pg = c10d.ProcessGroupGloo(store, rank, world_size, opts)
        dist.barrier(group=pg)
        return pg
218

219
    def setUp(self):
        """Run base-class setup, then start worker processes via ``_spawn_processes``."""
        super(ProcessGroupGlooTest, self).setUp()
        self._spawn_processes()
222

223
    def opts(self, threads=2):
        """Return Gloo options bound to a single loopback device.

        Args:
            threads: value assigned to the options' ``_threads`` field.

        Returns:
            ``ProcessGroupGloo._Options`` with ``_timeout`` set to 50.0.
        """
        gloo_options = c10d.ProcessGroupGloo._Options()
        gloo_options._timeout = 50.0
        gloo_options._devices = [create_device(interface=LOOPBACK)]
        gloo_options._threads = threads
        return gloo_options
229

230
    @requires_gloo()
    def test_multi_device_constructor(self):
        """A group built over two loopback devices can run allreduces."""
        store = c10d.FileStore(self.file_name, self.world_size)
        opts = c10d.ProcessGroupGloo._Options()
        opts._timeout = 5.0
        opts._devices = [create_device(interface=LOOPBACK) for _ in range(2)]
        pg = self._create_process_group_gloo(store, self.rank, self.world_size, opts)

        # Execute 2x the number of operations to ensure we use every device.
        # All operations are launched before any of them is awaited.
        futures = [
            pg.allreduce(torch.ones(size)).get_future() for size in range(1, 5)
        ]
        for future in futures:
            future.wait()
244

245
    @requires_gloo()
    def test_empty_tensors(self):
        """Broadcasting a zero-element tensor succeeds and stays empty."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        empty = torch.FloatTensor([])
        future = pg.broadcast([empty]).get_future()
        future.wait()
        received = future.value()[0]
        self.assertEqual(0, received.numel())
        self.assertEqual(empty, received)
258

259
    @requires_gloo()
    def test_broadcast_checks(self):
        """Invalid broadcast arguments are rejected with descriptive errors."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        # (expected error regex, root rank, root tensor index, tensor list)
        bad_calls = [
            ("invalid root rank", -1, 0, [t1]),
            ("invalid root rank", self.world_size, 0, [t1]),
            ("invalid root tensor", self.rank, -1, [t1]),
            ("invalid root tensor", self.rank, 1, [t1]),
            ("invalid root tensor", self.rank, 0, []),
            ("invalid tensor type", self.rank, 0, [t1, t2]),  # mixed dtypes
            ("invalid tensor size", self.rank, 0, [t1, t3]),  # mixed sizes
        ]
        for pattern, root_rank, root_tensor, tensors in bad_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.BroadcastOptions()
                opts.rootRank = root_rank
                opts.rootTensor = root_tensor
                pg.broadcast(tensors, opts)
311

312
    def _test_broadcast_basics(self, fn):
        """Check broadcast results for every root rank and every root tensor.

        Args:
            fn: maps a freshly built CPU tensor to the tensor actually passed
                to the collective (e.g. a clone, or a clone moved to CUDA).
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor: every rank must receive the root's rank.
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0])

            # Run with 2 input tensors: rank r holds [2r, 2r + 1]. Selecting
            # root tensor j on root i must yield 2i + j in both slots.
            num = 2
            for j in range(num):
                xs = [
                    fn(torch.tensor([self.rank * num + 0.0])),
                    fn(torch.tensor([self.rank * num + 1.0])),
                ]
                output = broadcast(xs, i, j)
                expected = torch.tensor([i * num + j], dtype=torch.float32)
                self.assertEqual(expected, output[0])
                self.assertEqual(expected, output[1])

        # Test overloaded convenience function. The input goes through `fn`
        # like every other case, so device-specific variants (e.g. CUDA) also
        # exercise this overload instead of silently staying on CPU.
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])
351

352
    @requires_gloo()
    def test_broadcast_basics(self):
        """Run the broadcast basics on CPU clones of the inputs."""
        self._test_broadcast_basics(torch.Tensor.clone)
355

356
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_broadcast_basics_cuda(self):
        """Run the broadcast basics on CUDA clones of the inputs."""

        def to_cuda_copy(t):
            return t.clone().cuda()

        self._test_broadcast_basics(to_cuda_copy)
360

361
    def _test_broadcast_stress(self, inputs):
        """Issue many concurrent broadcasts with a rotating root, then verify.

        Callers in this file pass ``inputs[i] == i * world_size + rank``.
        Broadcast ``i`` is rooted at ``i % world_size``, so afterwards
        ``inputs[i]`` should hold ``i * world_size + i % world_size`` on every
        rank.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        ws = self.world_size
        # Launch every broadcast before waiting on any, so they overlap.
        handles = [
            pg.broadcast(tensor, root=(i % ws)) for i, tensor in enumerate(inputs)
        ]
        for i, (handle, tensor) in enumerate(zip(handles, inputs)):
            handle.wait()
            self.assertEqual(
                torch.tensor([(i * ws) + (i % ws)]),
                tensor,
                msg=f"Mismatch in iteration {i}",
            )
377

378
    @requires_gloo()
    def test_broadcast_stress(self):
        """Stress broadcast with 1000 CPU tensors."""
        ws, rank = self.world_size, self.rank
        self._test_broadcast_stress(
            [torch.tensor([i * ws + rank]) for i in range(1000)]
        )
382

383
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_broadcast_stress_cuda(self):
        """Stress broadcast with 1000 CUDA tensors."""
        ws, rank = self.world_size, self.rank
        self._test_broadcast_stress(
            [torch.tensor([i * ws + rank]).cuda() for i in range(1000)]
        )
390

391
    @requires_gloo()
    def test_allreduce_checks(self):
        """Invalid allreduce inputs are rejected with descriptive errors."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        # (expected error regex, tensor list)
        bad_calls = [
            ("requires non-empty tensor list", []),
            ("invalid tensor type", [t1, t2]),
            ("invalid tensor size", [t1, t3]),
        ]
        for pattern, tensors in bad_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.AllreduceOptions()
                pg.allreduce(tensors, opts)
413

414
    def _test_allreduce_basics(self, fn):
        """Validate allreduce for single-input, multi-input and default-op calls.

        Args:
            fn: maps each prepared CPU tensor to the tensor handed to the
                collective (e.g. a clone, or a clone moved to CUDA).
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        def allreduce_with(op, tensors):
            # Run one allreduce with an explicit op and return its results.
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            future = pg.allreduce(tensors, opts).get_future()
            future.wait()
            return future.value()

        # Single input tests
        for op, sample, expected in simple_reduce_tests(self.rank, self.world_size):
            reduced = allreduce_with(op, [fn(sample)])
            self.assertEqual(expected, reduced[0])

        # Multi input tests: every returned tensor must hold the expected value.
        multi_cases = simple_multi_input_reduce_tests(self.rank, self.world_size)
        for op, samples, expected in multi_cases:
            reduced = allreduce_with(op, [fn(sample) for sample in samples])
            for tensor in reduced:
                self.assertEqual(expected, tensor)

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        future = pg.allreduce(x).get_future()
        future.wait()
        summed = future.value()
        ws = self.world_size
        self.assertEqual(torch.tensor([float(ws * (ws + 1) / 2)]), summed[0])
452

453
    @requires_gloo()
    def test_allreduce_basics(self):
        """Run the allreduce basics on CPU clones of the inputs."""
        self._test_allreduce_basics(torch.Tensor.clone)
456

457
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allreduce_basics_cuda(self):
        """Run the allreduce basics on CUDA clones of the inputs."""

        def to_cuda_copy(t):
            return t.clone().cuda()

        self._test_allreduce_basics(to_cuda_copy)
461

462
    def _test_allreduce_stress(self, inputs):
        """Run many concurrent default-op allreduces and verify each result.

        Callers in this file pass ``inputs[i] == i + rank``, so the expected
        result is ``i * world_size + world_size * (world_size - 1) // 2``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        ws = self.world_size
        rank_sum = ws * (ws - 1) // 2
        # Launch every allreduce before waiting on any, so they overlap.
        futures = [pg.allreduce(tensor).get_future() for tensor in inputs]
        for i, future in enumerate(futures):
            future.wait()
            self.assertEqual(
                torch.tensor([(i * ws) + rank_sum]),
                future.value()[0],
                msg=f"Mismatch in iteration {i}",
            )
482

483
    @requires_gloo()
    def test_allreduce_stress(self):
        """Stress allreduce with 1000 CPU tensors."""
        rank = self.rank
        self._test_allreduce_stress([torch.tensor([i + rank]) for i in range(1000)])
487

488
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allreduce_stress_cuda(self):
        """Stress allreduce with 1000 CUDA tensors."""
        rank = self.rank
        self._test_allreduce_stress(
            [torch.tensor([i + rank]).cuda() for i in range(1000)]
        )
493

494
    @requires_gloo()
    def test_allreduce_coalesced_checks(self):
        """Invalid coalesced-allreduce inputs are rejected with descriptive errors."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros(1, dtype=torch.float32)
        t2 = torch.zeros(1, dtype=torch.float64)
        t3 = torch.sparse_coo_tensor([[0]], [1], size=(1,))

        # (expected error regex, tensor list)
        bad_calls = [
            ("requires non-empty tensor list", []),
            ("tensors must all have the same type", [t1, t2]),
            ("invalid tensor layout at index", [t1, t3]),  # dense mixed with sparse
            ("unsupported layout", [t3, t3.clone()]),  # all-sparse input
        ]
        for pattern, tensors in bad_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.AllreduceCoalescedOptions()
                pg.allreduce_coalesced(tensors, opts)
520

521
    @skip_if_lt_x_gpu(1)
    @requires_gloo()
    def test_allreduce_coalesced_checks_cuda(self):
        """Coalesced allreduce rejects CUDA tensors in this backend."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        host_tensor = torch.zeros(1, dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "unsupported device type"):
            opts = c10d.AllreduceCoalescedOptions()
            device_tensors = [host_tensor.cuda(), host_tensor.cuda()]
            pg.allreduce_coalesced(device_tensors, opts)
534

535
    def _test_allreduce_coalesced_basics(self, fn):
        """Check coalesced allreduce against the simple two-tensor cases.

        Args:
            fn: maps each prepared CPU tensor to the tensor handed to the
                collective.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        cases = simple_coalesced_reduce_tests(self.rank, self.world_size)
        for op, inputs, outputs in cases:
            opts = c10d.AllreduceCoalescedOptions()
            opts.reduceOp = op
            prepared = [fn(sample) for sample in inputs]
            future = pg.allreduce_coalesced(prepared, opts).get_future()
            future.wait()
            reduced = future.value()
            for got, want in zip(reduced, outputs):
                self.assertEqual(got, want)
551

552
    @requires_gloo()
    def test_allreduce_coalesced_basics(self):
        """Run the coalesced allreduce basics on CPU clones of the inputs."""
        self._test_allreduce_coalesced_basics(torch.Tensor.clone)
555

556
    def _expected_output(self, i):
        """Return the expected two-slot coalesced-allreduce result for iteration ``i``.

        Callers in this file contribute ``i + rank`` in both slots, so each
        slot sums to ``i * ws + ws * (ws - 1) // 2``. The returned list holds
        the same tensor object twice.
        """
        ws = self.world_size
        expected = torch.tensor([(i * ws) + (ws * (ws - 1) // 2)])
        return [expected, expected]
559

560
    def _test_allreduce_coalesced_stress(self, inputs):
        """Run many concurrent coalesced allreduces and check each result.

        ``inputs`` is a list of tensor lists. Iteration ``i`` is compared
        against ``self._expected_output(i)``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        # Launch every operation before waiting on any, so they overlap.
        futures = [pg.allreduce_coalesced(batch).get_future() for batch in inputs]
        for i, future in enumerate(futures):
            future.wait()
            self.assertEqual(
                self._expected_output(i),
                future.value(),
                msg=f"Mismatch in iteration {i}",
            )
576

577
    @requires_gloo()
    def test_allreduce_coalesced_stress(self):
        """Stress coalesced allreduce with 1000 two-slot batches."""
        inputs = []
        for i in range(1000):
            # Both slots intentionally reference the same tensor object.
            shared = torch.tensor([i + self.rank])
            inputs.append([shared, shared])
        self._test_allreduce_coalesced_stress(inputs)
581

582
    @requires_gloo()
    def test_allreduce_coalesced_async(self):
        """Asynchronous ``all_reduce_coalesced`` through the default group.

        Both returned futures are awaited together first, then each result is
        compared with ``_expected_output``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )

        batches = []
        for i in range(2):
            # Both slots intentionally reference the same tensor object.
            shared = torch.tensor([i + self.rank])
            batches.append([shared, shared])
        futures = [
            c10d.all_reduce_coalesced(batch, async_op=True) for batch in batches
        ]
        torch.futures.wait_all(futures)
        for i, future in enumerate(futures):
            self.assertEqual(
                self._expected_output(i),
                future.wait(),
                msg=f"Mismatch in iteration {i}",
            )
598

599
    @requires_gloo()
    def test_sparse_allreduce_checks(self):
        """Invalid sparse allreduce inputs and ops are rejected."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1])
        t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,))
        t3 = torch.sparse_coo_tensor([[0]], [1], size=(4,))

        # (expected error regex, tensor list)
        bad_calls = [
            ("requires non-empty tensor list", []),
            ("invalid tensor layout", [t1, t2]),  # dense mixed with sparse
            ("invalid tensor size", [t2, t3]),  # sparse tensors of different sizes
        ]
        for pattern, tensors in bad_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.AllreduceOptions()
                pg.allreduce(tensors, opts)

        # Sparse allreduce only works with c10d.ReduceOp.SUM.
        unsupported = (c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX)
        for op in unsupported:
            with self.assertRaisesRegex(RuntimeError, "unsupported reduction operation"):
                opts = c10d.AllreduceOptions()
                opts.reduceOp = op
                pg.allreduce([t3], opts)
628

629
    def _test_sparse_allreduce_basics(self, fn):
        """Check sparse allreduce with one and two inputs per rank.

        Both the future's value and the tensors that were passed in are
        compared with the expected outputs. The test therefore expects the
        collective to update its inputs in place.

        Args:
            fn: maps each prepared sparse tensor to the tensor handed to the
                collective.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        for per_rank in (1, 2):
            cases = simple_sparse_reduce_tests(
                self.rank, self.world_size, num_inputs=per_rank
            )
            for inputs, outputs in cases:
                prepared = [fn(sample) for sample in inputs]
                future = pg.allreduce(prepared).get_future()
                future.wait()
                reduced = future.value()
                self.assertEqual(prepared, outputs)
                self.assertEqual(reduced, outputs)
646

647
    @requires_gloo()
    def test_sparse_allreduce_basics(self):
        """Run the sparse allreduce basics on the inputs as-is (no copy)."""

        def as_is(t):
            return t

        self._test_sparse_allreduce_basics(as_is)
650

651
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sparse_allreduce_basics_cuda(self):
        """Run the sparse allreduce basics on CUDA clones of the inputs."""

        def to_cuda_copy(t):
            return t.clone().cuda()

        self._test_sparse_allreduce_basics(to_cuda_copy)
655

656
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sparse_allreduce_cuda_dispatched(self):
        """``dist.all_reduce`` on a sparse CUDA tensor through the default group."""
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        cases = simple_sparse_reduce_tests(self.rank, self.world_size, num_inputs=1)
        for inputs, outputs in cases:
            sparse_gpu = inputs[-1].clone().cuda()
            work = dist.all_reduce(sparse_gpu, async_op=True)
            work.wait()
            self.assertEqual([sparse_gpu], outputs)
669

670
    @requires_gloo()
    def test_allgather_into_tensor_coalesced(self):
        """Coalesced allgather stacks every rank's input along dim 0."""
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        # Same seed on every rank, so all ranks generate identical inputs.
        torch.manual_seed(42)
        ws = self.world_size
        in_shapes = [(5, 5), (10, 10), (15, 15)]

        outputs = [torch.empty((shape[0] * ws, *shape[1:])) for shape in in_shapes]
        inputs = [torch.rand(shape) for shape in in_shapes]
        work = dist.group.WORLD.allgather_into_tensor_coalesced(outputs, inputs)
        work.wait()

        for gathered, local in zip(outputs, inputs):
            self.assertTrue(torch.allclose(gathered, torch.cat([local] * ws)))
691

692
    @requires_gloo()
    def test_reduce_scatter_tensor(self):
        """reduce_scatter_tensor gives each rank the summed slice for its rank."""
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        # Same seed on every rank, so all ranks hold the same input and each
        # reduced slice equals that slice times world_size.
        torch.manual_seed(42)
        ws = self.world_size
        out_shape = (20, 20)
        in_shape = (out_shape[0] * ws, *out_shape[1:])

        output = torch.empty(out_shape)
        stacked = torch.rand(in_shape)
        dist.reduce_scatter_tensor(output, stacked, async_op=True).wait()

        my_slice = stacked.view(ws, *out_shape).chunk(ws)[self.rank]
        self.assertTrue(torch.allclose(output, my_slice * ws))
713

714
    @requires_gloo()
    def test_reduce_scatter_tensor_coalesced(self):
        """Coalesced reduce-scatter gives each rank its summed slice per tensor."""
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        # Same seed on every rank, so all ranks hold identical inputs.
        torch.manual_seed(42)
        ws = self.world_size
        out_shapes = [(5, 5), (10, 10), (15, 15)]

        outputs = [torch.empty(shape) for shape in out_shapes]
        inputs = [torch.rand((shape[0] * ws, *shape[1:])) for shape in out_shapes]
        dist.group.WORLD.reduce_scatter_tensor_coalesced(outputs, inputs).wait()

        for reduced, stacked in zip(outputs, inputs):
            my_slice = stacked.view(ws, *reduced.shape).chunk(ws)[self.rank]
            self.assertTrue(torch.allclose(reduced, my_slice * ws))
736

737
    @requires_gloo()
    def test_scatter_checks(self):
        """Invalid scatter arguments are rejected with descriptive errors."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        ws = self.world_size
        me = self.rank
        size_err = "Incorrect input list size {}. Input list size should be {}"

        # (expected error regex, root rank, output tensor list, input lists)
        bad_calls = [
            ("invalid root rank", -1, [t1], []),
            ("invalid root rank", ws, [t1], []),
            ("requires a single-element output tensor list", 0, [], []),
            ("requires a single-element output tensor list", 0, [t1, t1], []),
            # The root must pass exactly one input list.
            ("requires a single-element input list", me, [t1], []),
            (
                "requires a single-element input list",
                me,
                [t1],
                [[t1] * ws, [t1] * ws],
            ),
            # The root's input list must hold one tensor per rank.
            (size_err.format(ws - 1, ws), me, [t1], [[t1] * (ws - 1)]),
            (size_err.format(ws + 1, ws), me, [t1], [[t1] * (ws + 1)]),
            ("invalid tensor type", me, [t1], [[t2] * ws]),
            ("invalid tensor size", me, [t1], [[t3] * ws]),
            # Non-root ranks must not supply inputs.
            ("requires empty input on non-root", (me + 1) % ws, [t1], [[t1] * ws]),
        ]
        for pattern, root, outputs, inputs in bad_calls:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.ScatterOptions()
                opts.rootRank = root
                pg.scatter(outputs, inputs, opts)
814

815
    def _test_scatter_basics(self, fn):
        """Scatter from every rank in turn and check what each rank received.

        ``fn`` maps a freshly built CPU tensor to the tensor handed to the
        process group (e.g. a plain clone, or a clone moved to CUDA).
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # When this rank is root, it scatters world_size copies of its rank id.
        scatter_list = [fn(torch.tensor([self.rank])) for _ in range(self.world_size)]
        # One receive buffer per root, pre-filled with a -1 sentinel.
        recv_buffers = [fn(torch.tensor([-1])) for _ in range(self.world_size)]

        futures = []
        for root in range(self.world_size):
            opts = c10d.ScatterOptions()
            opts.rootRank = root
            # Only the root passes input lists; every other rank passes [].
            input_lists = [scatter_list] if root == self.rank else []
            work = pg.scatter([recv_buffers[root]], input_lists, opts)
            futures.append(work.get_future())

        # From root r every rank should have received the value r.
        for root, fut in enumerate(futures):
            fut.wait()
            received = fut.value()
            self.assertEqual(torch.tensor([root]), received[0])
840

841
    @requires_gloo()
    def test_scatter_basics(self):
        """Round-robin scatter check using CPU tensors."""

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_scatter_basics(copy_on_cpu)
844

845
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_scatter_basics_cuda(self):
        """Round-robin scatter check using tensors moved to CUDA."""

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_scatter_basics(copy_to_cuda)
849

850
    def _test_scatter_stress(self, inputs, fn):
        """Issue many async scatters (every rank as root, per input) and verify them.

        Args:
            inputs: one entry per iteration. Each entry is a list of world_size
                tensors that this rank scatters when it is root. The expected
                value below assumes entry ``i`` holds ``i + self.rank``, as the
                callers in this file build it.
            fn: maps each input/sentinel tensor to the tensor handed to the
                process group (e.g. a clone, or a clone moved to CUDA).
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        # outputs[i][root] receives root's scatter for iteration i.
        outputs = [
            [fn(torch.tensor([-1])) for _ in range(self.world_size)]
            for _ in range(len(inputs))
        ]
        future_handles = []
        for i, per_rank_inputs in enumerate(inputs):
            for root in range(self.world_size):
                opts = c10d.ScatterOptions()
                opts.rootRank = root
                if root == self.rank:
                    fut = pg.scatter(
                        [outputs[i][root]], [[fn(e) for e in per_rank_inputs]], opts
                    ).get_future()
                else:
                    fut = pg.scatter([outputs[i][root]], [], opts).get_future()
                future_handles.append(fut)

        # Handles were appended iteration-major, root-minor, so divmod
        # recovers (iteration, root). Avoid shadowing the builtin `iter`.
        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            iteration, root = divmod(i, self.world_size)
            result = future_handle.value()

            self.assertEqual(
                torch.tensor([iteration + root]),
                result[0],
                msg=("Mismatch in iteration %d for root %d" % (iteration, root)),
            )
883

884
    @requires_gloo()
    def test_set_gloo_pg_timeout(self):
        """Changing the default timeout of a live gloo group updates its options."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        pg.allreduce(torch.rand(10))
        # NOTE(review): 50s presumably comes from self.opts() (defined outside
        # this view) - confirm if that helper's timeout changes.
        initial_timeout = timedelta(seconds=50)
        self.assertEqual(pg.options._timeout, initial_timeout)
        updated_timeout = timedelta(seconds=23)
        pg._set_default_timeout(updated_timeout)
        self.assertEqual(pg.options._timeout, updated_timeout)
894

895
    @requires_gloo()
    def test_scatter_stress(self):
        """Stress scatter on CPU with 1000 iterations of every-rank-as-root."""
        inputs = []
        for i in range(1000):
            value = i + self.rank
            inputs.append([torch.tensor([value]) for _ in range(self.world_size)])

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_scatter_stress(inputs, copy_on_cpu)
902

903
    @skip_but_pass_in_sandcastle(
        "Test is flaky, see https://github.com/pytorch/pytorch/issues/15963"
    )
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_scatter_stress_cuda(self):
        """Stress scatter with CUDA tensors (currently skipped as flaky)."""
        inputs = []
        for i in range(1000):
            value = i + self.rank
            inputs.append([torch.tensor([value]) for _ in range(self.world_size)])

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_scatter_stress(inputs, copy_to_cuda)
914

915
    @requires_gloo()
    def test_gather_checks(self):
        """Malformed gather arguments must raise RuntimeError with a specific message."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)  # wrong dtype relative to t1
        t3 = torch.zeros([2], dtype=torch.float32)  # wrong size relative to t1

        ws = self.world_size
        me = self.rank
        other = (me + 1) % ws
        size_err = "Incorrect output list size {}. Output list size should be {}"

        # (expected error regex, root rank, output tensor lists, input tensors)
        cases = [
            ("invalid root rank", -1, [], [t1]),
            ("invalid root rank", ws, [], [t1]),
            ("requires a single-element input tensor list", 0, [], []),
            ("requires a single-element input tensor list", 0, [], [t1, t1]),
            ("requires a single-element output list", me, [], [t1]),
            (
                "requires a single-element output list",
                me,
                [[t1] * ws, [t1] * ws],
                [t1],
            ),
            (size_err.format(ws - 1, ws), me, [[t1] * (ws - 1)], [t1]),
            (size_err.format(ws + 1, ws), me, [[t1] * (ws + 1)], [t1]),
            ("invalid tensor type", me, [[t2] * ws], [t1]),
            ("invalid tensor size", me, [[t3] * ws], [t1]),
            # A non-root rank must not supply output lists.
            ("requires empty output on non-root", other, [[t1] * ws], [t1]),
        ]
        for pattern, root, output_lists, input_tensors in cases:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.GatherOptions()
                opts.rootRank = root
                pg.gather(output_lists, input_tensors, opts)
996

997
    def _test_gather_basics(self, fn):
        """Gather to every rank in turn; the root must see each rank's id in order.

        ``fn`` maps a freshly built CPU tensor to the tensor handed to the
        process group.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # This rank always contributes its own rank id.
        send = [fn(torch.tensor([self.rank]))]
        # Receive buffers, only handed to the group when this rank is root.
        gather_buffers = [fn(torch.tensor([-1])) for _ in range(self.world_size)]

        futures = []
        for root in range(self.world_size):
            opts = c10d.GatherOptions()
            opts.rootRank = root
            output_lists = [gather_buffers] if root == self.rank else []
            futures.append(pg.gather(output_lists, send, opts).get_future())

        expected = [fn(torch.tensor([rank])) for rank in range(self.world_size)]
        for root, fut in enumerate(futures):
            fut.wait()
            result = fut.value()
            # Only the root's future carries the gathered tensors worth checking.
            if root == self.rank:
                self.assertEqual(expected, result)
1024

1025
    @requires_gloo()
    def test_gather_basics(self):
        """Round-robin gather check using CPU tensors."""

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_gather_basics(copy_on_cpu)
1028

1029
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gather_basics_cuda(self):
        """Round-robin gather check using tensors moved to CUDA."""

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_gather_basics(copy_to_cuda)
1033

1034
    @requires_gloo()
    def test_gather_noncontiguous_input(self):
        """Gather must handle inputs whose memory is not dense."""

        def first_column(tensor):
            # Column 0 of a contiguous 2x2 tensor is a strided (non-dense) view.
            square = tensor.expand(2, 2).contiguous()
            return square[:, 0]

        self._test_gather_basics(first_column)
1038

1039
    def _test_gather_stress(self, inputs, fn):
        """Issue many async gathers (every rank as root, per input) and verify roots.

        Args:
            inputs: one tensor per iteration. The expected values assume entry
                ``i`` holds ``i + self.rank``.
            fn: maps each tensor to the one handed to the process group.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        outputs = [
            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
            for _ in range(len(inputs))
        ]
        expected_outputs = [
            [[torch.tensor([i + j]) for j in range(self.world_size)]]
            for i in range(len(inputs))
        ]
        for i, tensor in enumerate(inputs):
            for root in range(self.world_size):
                opts = c10d.GatherOptions()
                opts.rootRank = root
                # Only the root supplies output lists.
                output_lists = outputs[i] if root == self.rank else []
                work = pg.gather(output_lists, [fn(tensor)], opts)
                future_handles.append(work.get_future())

        for idx, future_handle in enumerate(future_handles):
            future_handle.wait()
            iteration, root = divmod(idx, self.world_size)
            if root != self.rank:
                continue
            result = future_handle.value()
            self.assertEqual(
                expected_outputs[iteration],
                [result],
                msg=("Mismatch in iteration %d for root %d" % (iteration, root)),
            )
1074

1075
    @requires_gloo()
    def test_gather_stress(self):
        """Stress gather on CPU over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]))

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_gather_stress(inputs, copy_on_cpu)
1079

1080
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gather_stress_cuda(self):
        """Stress gather with CUDA tensors over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]).cuda())

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_gather_stress(inputs, copy_to_cuda)
1085

1086
    @requires_gloo()
    def test_allgather_checks(self):
        """Malformed allgather arguments must raise RuntimeError with a specific message."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)  # wrong dtype relative to t1
        t3 = torch.zeros([2], dtype=torch.float32)  # wrong size relative to t1

        ws = self.world_size
        same_len = "requires input/output tensor lists to have the same length"

        # (expected error regex, output tensor lists, input tensors)
        cases = [
            ("requires non-empty input tensor list", [], []),
            (same_len, [], [t1]),
            (same_len, [[t1] * ws, [t1] * ws], [t1]),
            ("invalid output tensor list", [[t1] * (ws - 1)], [t1]),
            ("invalid output tensor list", [[t1] * (ws + 1)], [t1]),
            ("invalid tensor type", [[t1, t1] * ws, [t1, t1] * ws], [t1, t2]),
            ("invalid tensor size", [[t1, t1] * ws, [t1, t1] * ws], [t1, t3]),
            # Mismatches hidden inside the output list itself.
            ("invalid tensor type", [([t1, t2] * ws)[:ws]], [t1]),
            ("invalid tensor size", [([t1, t3] * ws)[:ws]], [t1]),
        ]
        for pattern, output_lists, input_tensors in cases:
            with self.assertRaisesRegex(RuntimeError, pattern):
                pg.allgather(output_lists, input_tensors)
1131

1132
    def _test_allgather_basics(self, fn):
        """Allgather with 1, 2 and 3 input tensors per rank and check the results.

        Rank r contributes the values ``n*r .. n*r + n - 1``, so every output
        list should read ``0 .. n*world_size - 1`` in order.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        for n in (1, 2, 3):
            total = n * self.world_size
            local = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            gathered = [
                [fn(torch.tensor([-1])) for _ in range(total)] for _ in range(n)
            ]
            expected = [[fn(torch.tensor([i])) for i in range(total)] for _ in range(n)]
            fut = pg.allgather(gathered, local).get_future()
            fut.wait()
            result = fut.value()
            # Wrap the single-input result so its nesting matches `expected`.
            if n == 1:
                result = [result]
            self.assertEqual(expected, result)
1155

1156
    @requires_gloo()
    def test_allgather_basics(self):
        """Allgather check using CPU tensors."""

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_allgather_basics(copy_on_cpu)
1159

1160
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allgather_basics_cuda(self):
        """Allgather check using tensors moved to CUDA."""

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_allgather_basics(copy_to_cuda)
1164

1165
    @requires_gloo()
    def test_allgather_noncontiguous_input(self):
        """Allgather must handle inputs whose memory is not dense."""

        def first_column(tensor):
            # Column 0 of a contiguous 2x2 tensor is a strided (non-dense) view.
            square = tensor.expand(2, 2).contiguous()
            return square[:, 0]

        self._test_allgather_basics(first_column)
1169

1170
    def _test_allgather_stress(self, inputs, fn):
        """Issue one async allgather per input and verify every result.

        Args:
            inputs: one tensor per iteration. The expected values assume entry
                ``i`` holds ``i + self.rank``.
            fn: maps each tensor to the one handed to the process group.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        outputs = [
            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
            for _ in range(len(inputs))
        ]
        expected_outputs = [
            [[torch.tensor([i + j]) for j in range(self.world_size)]]
            for i in range(len(inputs))
        ]
        input_holder = {}
        for i, tensor in enumerate(inputs):
            # Note that this works around the data race discussed in
            # https://github.com/pytorch/pytorch/issues/75529, but we should
            # actually be able to pass the list directly into allgather when
            # that race is fixed.
            input_holder[i] = [fn(tensor)]
            work = pg.allgather(outputs[i], input_holder[i])
            future_handles.append(work.get_future())

        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            result = future_handle.value()
            self.assertEqual(
                expected_outputs[i],
                [result],
                msg=("Mismatch in iteration %d" % i),
            )
1202

1203
    @requires_gloo()
    def test_allgather_stress(self):
        """Stress allgather on CPU over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]))

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_allgather_stress(inputs, copy_on_cpu)
1207

1208
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allgather_stress_cuda(self):
        """Stress allgather with CUDA tensors over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]).cuda())

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_allgather_stress(inputs, copy_to_cuda)
1213

1214
    @requires_gloo()
    def test_allgather_coalesced_checks(self):
        """all_gather_coalesced must reject malformed output tensor lists."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        single_input = [torch.zeros([1], dtype=torch.float32)]
        output_lists = [
            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
        ]

        # First output list holds a zero-element tensor (size mismatch with input).
        output_lists[0] = [torch.zeros([0], dtype=torch.float32)]
        with self.assertRaisesRegex(
            RuntimeError, "invalid size of output tensor at index 0"
        ):
            c10d.all_gather_coalesced(output_lists, single_input, pg)

        # First output list holds a float64 tensor (dtype mismatch with input).
        output_lists[0] = [torch.zeros([1], dtype=torch.float64)]
        with self.assertRaisesRegex(RuntimeError, "invalid tensor type at index 0"):
            c10d.all_gather_coalesced(output_lists, single_input, pg)

        # One more output list than there are ranks.
        too_many_lists = [
            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size + 1)
        ]
        with self.assertRaisesRegex(
            RuntimeError, "output lists should be equal to world size"
        ):
            c10d.all_gather_coalesced(too_many_lists, single_input, pg)

        # A flat list of tensors instead of a list of lists.
        flat_outputs = [torch.zeros([0], dtype=torch.float32)]
        with self.assertRaisesRegex(
            TypeError, "Invalid function argument.*output_tensor_lists"
        ):
            c10d.all_gather_coalesced(flat_outputs, single_input, pg)
1252

1253
    @requires_gloo()
    def test_allgather_coalesced_async(self):
        """Two overlapping async all_gather_coalesced calls produce correct results."""
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )

        num_calls = 2
        # Each call sends the same tensor object twice: value i + rank.
        inputs_per_call = [
            2 * [torch.tensor([i + self.rank])] for i in range(num_calls)
        ]
        outputs_per_call = [
            [[torch.zeros_like(x) for x in sent] for _ in range(self.world_size)]
            for sent in inputs_per_call
        ]
        futs = [
            c10d.all_gather_coalesced(received, sent, async_op=True)
            for sent, received in zip(inputs_per_call, outputs_per_call)
        ]

        # expected outputs
        expected_per_call = [
            [2 * [torch.tensor([i + r])] for r in range(self.world_size)]
            for i in range(num_calls)
        ]

        torch.futures.wait_all(futs)
        for received, expected in zip(outputs_per_call, expected_per_call):
            # one call; then one output list per rank; then one tensor per entry
            for got_list, want_list in zip(received, expected):
                for got, want in zip(got_list, want_list):
                    self.assertEqual(got, want)

        # Added to address https://github.com/pytorch/pytorch/issues/65231
        # In the failed tests, all assertEqual are passed on all processes.
        # However, one of the processes didn't call ProcessGroupGloo
        # destructor before exiting program. This is not surprising as the only
        # guarantee that Python makes is that garbage collection MAY happen
        # before the program exits. If GC didn't happen, the two threads in
        # ProcessGroup might be destructed before joined.
        # FIXME: it's still unclear why only this test require explicit
        # destroy_process_group()
        c10d.destroy_process_group()
1286

1287
    @requires_gloo()
    def test_reduce_checks(self):
        """Malformed reduce options/arguments must raise RuntimeError with a specific message."""
        store = c10d.FileStore(self.file_name, self.world_size)
        # Single assignment (the original had a redundant `pg = pg = ...`).
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)

        # (expected error regex, root rank, root tensor index, tensor list)
        cases = [
            ("invalid root rank", -1, 0, [t1]),
            ("invalid root rank", self.world_size, 0, [t1]),
            # Only one tensor is passed, so index 1 is out of range.
            ("invalid root tensor", self.rank, 1, [t1]),
            ("requires a single-element tensor list", self.rank, 0, [t1, t1]),
        ]
        for pattern, root_rank, root_tensor, tensors in cases:
            with self.assertRaisesRegex(RuntimeError, pattern):
                opts = c10d.ReduceOptions()
                opts.rootRank = root_rank
                opts.rootTensor = root_tensor
                pg.reduce(tensors, opts)
1321

1322
    def _test_reduce_basics(self, fn):
        """Run each simple_reduce_tests case with every rank as root.

        Only the root's result is compared against the expected value.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        cases = simple_reduce_tests(self.rank, self.world_size)
        for op, value, expected in cases:
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                # A fresh tensor per round, since reduce writes into it.
                operand = fn(value)
                fut = pg.reduce([operand], opts).get_future()
                fut.wait()
                reduced = fut.value()
                if root == self.rank:
                    self.assertEqual(expected, reduced[0])
1338

1339
    @requires_gloo()
    def test_reduce_basics(self):
        """Reduce check using CPU tensors."""

        def copy_on_cpu(tensor):
            return tensor.clone()

        self._test_reduce_basics(copy_on_cpu)
1342

1343
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_reduce_basics_cuda(self):
        """Reduce check using tensors moved to CUDA."""

        def copy_to_cuda(tensor):
            cloned = tensor.clone()
            return cloned.cuda()

        self._test_reduce_basics(copy_to_cuda)
1347

1348
    def _test_reduce_stress(self, inputs):
        """Issue many async SUM-default reduces (every rank as root) and verify roots.

        Args:
            inputs: one tensor per iteration. The expected value assumes entry
                ``i`` holds ``i + self.rank``, so the sum over ranks is
                ``i * world_size + world_size * (world_size - 1) // 2``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        # Holds every per-op clone; nothing reads it afterwards (presumably it
        # keeps the tensors referenced while the async ops run).
        outputs = []
        for tensor in inputs:
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.rootRank = root
                operand = tensor.clone()
                outputs.append(operand)
                future_handles.append(pg.reduce([operand], opts).get_future())

        ws = self.world_size
        for idx, future_handle in enumerate(future_handles):
            future_handle.wait()
            result = future_handle.value()
            iteration, root = divmod(idx, ws)
            if root != self.rank:
                continue
            expected_sum = iteration * ws + ws * (ws - 1) // 2
            self.assertEqual(
                torch.tensor([expected_sum]),
                result[0],
                msg=("Mismatch in iteration %d with root rank %d" % (iteration, root)),
            )
1380

1381
    @requires_gloo()
    def test_reduce_stress(self):
        """Stress reduce on CPU over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]))
        self._test_reduce_stress(inputs)
1385

1386
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_reduce_stress_cuda(self):
        """Stress reduce with CUDA tensors over 1000 iterations."""
        inputs = []
        for i in range(1000):
            inputs.append(torch.tensor([i + self.rank]).cuda())
        self._test_reduce_stress(inputs)
1391

1392
    @requires_gloo()
    def test_send_recv_all_to_all(self):
        """Every rank sends its id to, and receives from, every other rank (tag 0)."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        peers = [peer for peer in range(self.world_size) if peer != self.rank]

        # Post all sends before any receive.
        send_work = [pg.send([inputs[peer]], peer, 0) for peer in peers]
        recv_work = [pg.recv([outputs[peer]], peer, 0) for peer in peers]

        # Wait on sends first, then receives.
        for work in send_work + recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Each peer's slot must now hold that peer's rank; our own slot is untouched.
        for peer in peers:
            self.assertEqual(torch.tensor([peer]), outputs[peer])
1432

1433
    @requires_gloo()
    def test_barrier_implies_wait(self):
        """A barrier completes only after previously issued allreduces have finished."""
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        shape = (100, 100)
        tensors = [torch.full(shape, float(i)) for i in range(16)]
        for tensor in tensors:
            # The work handle is deliberately dropped; only the barrier is awaited.
            pg.allreduce(tensor)

        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            expected = torch.full(shape, float(i * self.world_size))
            self.assertEqual(expected, tensor)
1453

1454
    @skip_if_win32()
    @requires_gloo()
    def test_round_robin(self):
        """Broadcasts through a round-robin wrapper over two gloo subgroups."""
        num_process_groups = 2
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(backend="gloo", store=store, rank=self.rank, world_size=self.world_size)
        subgroups = [
            c10d.new_group(pg_options=self.opts()) for _ in range(num_process_groups)
        ]
        pg = c10d._round_robin_process_groups(subgroups)

        # One more collective than there are subgroups so each one gets used.
        for _ in range(num_process_groups + 1):
            tensor = torch.full([100, 100], float(self.rank))
            pg.broadcast(tensor, root=0).wait()
            self.assertEqual(torch.full([100, 100], 0.0), tensor)
1472

1473
    @skip_if_win32()
    @requires_gloo()
    def test_round_robin_create_destroy(self):
        """Create, use and drop a round-robin group of gloo subgroups, twice."""
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(backend="gloo", store=store, rank=self.rank, world_size=self.world_size)

        def create(num):
            # Wrap `num` freshly created subgroups in a round-robin group.
            # (The old unused `prefix` parameter was dropped.)
            return c10d._round_robin_process_groups(
                [c10d.new_group(pg_options=self.opts()) for _ in range(num)]
            )

        num_process_groups = 2
        # Run create/use/destroy twice
        for _attempt in range(2):
            pg = create(num=num_process_groups)
            for _ in range(3):
                tensor = torch.ones([10, 10])
                pg.allreduce(tensor).wait()
                self.assertEqual(torch.full([10, 10], float(self.world_size)), tensor)
            del pg
1496

1497

1498
class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
    """DistributedDataParallel tests that run over the gloo backend."""

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        """Initialize the default gloo group from ``self._get_store()`` and return it."""
        store = self._get_store()
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    def _test_gloo_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        """Run the shared DDP checks over a gloo default process group.

        Args:
            devices: devices the test module is placed on; the last one is used
                to look up the backend.
            device_ids: forwarded to ``_test_ddp_with_process_group``.
            multi_device: set by callers that pass more than one device.
            gradient_as_bucket_view: forwarded to ``_test_ddp_with_process_group``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        device = devices[-1]
        backend = process_group._get_backend(device)
        # NOTE(review): the returned device is discarded; confirm this call is needed.
        backend.create_device(interface=LOOPBACK)
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_gloo()
    def test_gloo_backend_cpu_module(self):
        self._test_gloo_backend([torch.device("cpu")], None)

    @requires_gloo()
    def test_gloo_backend_cpu_module_grad_is_view(self):
        self._test_gloo_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_gloo_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, int_devices)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_gloo_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, devices)

    @requires_gloo()
    @skip_if_lt_x_gpu(4)
    def test_gloo_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, None, multi_device=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(8)
    def test_gloo_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating a multi-task training, this test is to make sure:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                # Rank 0 only uses t0, other ranks only use t1; task_unused is
                # never used on any rank.
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # Now locally unused parameter should have grad updated on all ranks.
            # However the globally unused parameter should still have None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        accordingly where the forward pass in some process uses all parameters.
        This unit test creates a module that uses all parameters in rank = 0, and
        has unused parameters in other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            for t_p in model.module.task_parameters():
                self.assertIsNone(t_p.grad)

            # Run backward
            output.mean().backward()

            # Now locally unused parameter should have grad updated on all ranks.
            for t_p in model.module.task_parameters():
                self.assertIsNotNone(t_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_gloo()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_gloo()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                # fc3 is intentionally never used in forward().
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_ignored_sharded_tensor(self):
        class MyModule(nn.Module):
            def __init__(self, shard_tensor: ShardedTensor) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.st = nn.Parameter(shard_tensor)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                return F.softmax(x, dim=1)

        pg = dist.init_process_group(
            "gloo",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )
        device = torch.device(f"cuda:{self.rank}")
        local_shard_metadata = ShardMetadata(
            shard_offsets=[(self.rank % 2) * 5, 0],
            shard_sizes=[5, 10],
            placement=f"rank:{self.rank}/cuda:{self.rank}"
        )
        local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)]
        st = init_from_local_shards(local_shards, [10, 10])
        m = MyModule(st)
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=m,
            params_and_buffers_to_ignore={'st'}
        )
        # Make sure the DDP constructor does not fail when the module includes
        # an ignored ShardedTensor. With ``st`` ignored, DDP only manages
        # ``fc1``, which was never moved off the CPU, so no ``device_ids`` are
        # passed (DDP treats the module as a CPU module).
        DistributedDataParallel(
            m,
            device_ids=None,
            process_group=pg,
            gradient_as_bucket_view=True,
            broadcast_buffers=False,
            static_graph=True,
        )

    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        """Check that DDP on a per-rank slice matches a local model on the full batch.

        The vanilla model sees the whole batch; each DDP rank sees only its own
        ``mult``-sized slice. After backward, the first parameter of each model
        must have identical (coalesced) sparse gradients.
        """
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with entire batch against single process version
        criterion(vanilla_model(input), target).backward()

        # Run with partial batch against multi process version
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

        # Check that the gradients are sparse and identical
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce())

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        """Training with a save/load checkpoint in the middle must match training straight through."""
        dist.init_process_group(
            "gloo",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            # ``input``, ``target`` and ``criterion`` are bound further down in
            # this test, before the first call to train_loop.
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure that all the three models start with the same set of parameters.
        # By default they are randomized on construction.
        for model in (ddp_withload, model_withload, ddp_withoutload):
            for p in model.parameters():
                with torch.no_grad():
                    p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # Run the model for 6 iterations, with a checkpoint in the middle.
        train_loop(ddp_withload, optimizer_withload, 3)

        checkpoint_path = os.path.join(tempfile.gettempdir(), "model.checkpoint")
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        dist.barrier()
        # The checkpoint was written by rank 0; remap rank 0's CUDA device onto
        # the device this rank's model actually lives on.
        saved_device_id = gpus_for_rank(self.world_size)[0][0]
        map_location = {f"cuda:{saved_device_id}": f"cuda:{device_id}"}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)

        # Zero out parameters of both the DDP and non-DDP models, then reload
        # them from the DDP state dict.
        for model in (ddp_withload, model_withload):
            for p in model.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # The non-DDP model needs the "module." prefix removed from the DDP state dict first.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # Re-run the model with the same inputs for 6 iterations with no checkpoint.
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        """Compare a DDP-wrapped SparseGradientModule against a local copy."""
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    @requires_gloo()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    @requires_gloo()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_gloo()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register DDP Communication Hook
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Check whether the grads are equal to what the then-callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        """Wrap ModuleForDdpCommHook in DDP on this rank's first GPU.

        If ``hook`` is given, it is registered with ``state`` via
        ``register_comm_hook``.
        """
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_gloo(self):
        """
        This unit test verifies whether the Future object is passed properly using gloo backend.
        The hook callback function creates a Future object and sets a value to it.
        """
        process_group = self._get_process_group()

        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # Check whether the grads are equal to what simple_hook's then-callback returns.
        # Without the comm_hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_gloo()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the format
        of hook defined by user. The Python hook must be callable. This test also
        checks whether bucket annotation checked properly if defined.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_gloo()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether return annotation checked properly if defined. It also
        checks whether an internal error is thrown if return type is incorrect and user
        hasn't specified any return type annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = "Communication hook: return annotation should be torch.futures.Future"
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_gloo()
    def test_ddp_comm_hook_register_just_once(self):
        """
        DDP communication hook can only be registered once. This test validates whether
        the error is thrown properly when register_comm_hook is called more than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    @requires_gloo()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs "test_sparse_gradients" unit test with DDP communication hook. We define a
        simple hook that does allreduce and works with gloo backend for this test.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_gloo(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the allreduced result by world_size.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by running an async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)

        ddp_model.register_comm_hook(None, allreduce_hook_gloo)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
2152

2153

2154
class ReducerModule(nn.Module):
    """Small three-layer bias-free MLP used by ReducerTest.

    Each hidden activation is cast back to float32 with ``.float()``, so the
    module still runs when a layer (e.g. ``fc1``) has been cast to double.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes parameter order (fc1, fc2, fc3), which the
        # reducer bucket indices in ReducerTest rely on.
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 4, bias=False)
        self.fc3 = nn.Linear(4, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x, use_fc3=True):
        hidden = self.relu(self.fc1(x)).float()
        hidden = self.relu(self.fc2(hidden)).float()
        # Skipping fc3 leaves its parameters out of the autograd graph.
        logits = self.fc3(hidden).float() if use_fc3 else hidden
        return F.softmax(logits, dim=1)
2168

2169

2170
class ReducerTest(TestCase):
    """Exercises ``dist.Reducer`` directly on a single-rank gloo process group."""

    def setUp(self):
        # delete=False keeps the file on disk after the handle is closed.
        # FileStore only needs the path, so the Python handle is closed
        # immediately instead of being leaked for the lifetime of the test.
        self.file = tempfile.NamedTemporaryFile(delete=False)
        self.file.close()
        world_size = 1
        self.store = c10d.FileStore(self.file.name, world_size)
        c10d.init_process_group(backend="gloo", store=self.store, rank=0, world_size=world_size)
        self.process_group = c10d.distributed_c10d._get_default_group()

    def tearDown(self):
        c10d.destroy_process_group()
        try:
            os.remove(self.file.name)
        except OSError as e:
            print(str(e))

    @requires_gloo()
    def test_single_dtype_single_bucket(self):
        model = ReducerModule()
        parameters = list(model.parameters())
        buckets = [list(range(len(parameters)))]
        dist.Reducer(parameters, buckets, [dist._DEFAULT_FIRST_BUCKET_BYTES], self.process_group)

    def _create_mixed_precision_model(self):
        """Return a ReducerModule whose fc1 is float64 and other layers float32."""
        model = ReducerModule()
        model.float()
        model.fc1.double()
        return model

    @requires_gloo()
    def test_multi_dtype_single_bucket(self):
        model = self._create_mixed_precision_model()

        # Raise if there are multiple types per bucket.
        # In this case we create one bucket for all parameters.
        with self.assertRaises(RuntimeError):
            parameters = list(model.parameters())
            buckets = [list(range(len(parameters)))]
            dist.Reducer(
                parameters,
                buckets,
                [dist._DEFAULT_FIRST_BUCKET_BYTES],
                self.process_group
            )

    @requires_gloo()
    def test_multi_dtype_multi_bucket(self):
        # One bucket per dtype must be accepted by the Reducer.
        model = self._create_mixed_precision_model()
        self._create_reducer_for_models([model])

    def _create_reducer_for_models(self, models, find_unused_parameters=False):
        """Build a Reducer for a single model with one bucket per dtype run.

        ``itertools.groupby`` only groups *consecutive* parameters of equal
        dtype. That is enough for the mixed-precision model, whose parameters
        are fc1 (float64) followed by fc2 and fc3 (float32), so every bucket
        holds a single dtype.
        """
        self.assertEqual(len(models), 1)
        parameters = list(models[0].parameters())
        group_by_dtype = groupby(
            range(len(parameters)), key=lambda i: parameters[i].dtype
        )
        buckets = [list(indices) for _, indices in group_by_dtype]
        return dist.Reducer(
            parameters,
            buckets,
            [dist._DEFAULT_FIRST_BUCKET_BYTES] * len(buckets),
            self.process_group,
            find_unused_parameters=find_unused_parameters,
        )

    @requires_gloo()
    def test_forward_backward(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model])
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.double)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
        output = loss(model(input), target)
        reducer.prepare_for_backward(output)
        output.backward()

    @requires_gloo()
    def test_forward_backward_unused_parameters(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.double)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
        output = loss(model(input, use_fc3=False), target)

        # Check that the grad of fc3 is not set.
        self.assertIsNone(model.fc3.weight.grad)

        # Compute and accumulate gradients.
        reducer.prepare_for_backward(output)
        output.backward()

        # The reducer will have marked the grad of fc3 as ready, because
        # it doesn't show up in the autograd graph of `output`. Since fc3.weight
        # is considered being globally unused, it will be kept untouched as None.
        self.assertIsNone(model.fc3.weight.grad)

    @requires_gloo()
    def test_forward_backward_optimizer(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        for i in range(3):
            input = torch.rand([batch_size, 2], dtype=torch.double)
            target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

            # The `zero_grad` function calls `detach_` and `zero_` on the grad
            # tensors of model parameters. If we tried to set the grad tensors
            # to a view of the reducer's bucket tensors, this would blow up.
            optimizer.zero_grad()

            # Unused parameter only in the first iteration.
            output = loss(model(input, use_fc3=(i > 0)), target)
            reducer.prepare_for_backward(output)
            output.backward()
            optimizer.step()
2303

2304

2305
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    """Gloo instantiation of the shared c10d communication tests.

    Worker processes are spawned in ``setUp``. Tests that rendezvous do so
    through a ``FileStore`` at ``self.file_name``, which ``tearDown`` removes.
    """

    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the FileStore file; a missing file is fine.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _test_broadcast_coalesced(self, process_group, device, root_rank):
        """Broadcast a batch of tensors from ``root_rank`` and verify all ranks.

        Args:
            process_group: group to broadcast over.
            device: ``torch.device`` the tensors are allocated on.
            root_rank: rank whose tensors are the broadcast source.

        The payload mixes several dtypes and is larger than the 256-byte
        ``buffer_size`` passed to ``c10d._broadcast_coalesced``.
        """
        # No float16 for CPU tensors: substitute float32 there.
        half = torch.float32 if device == torch.device("cpu") else torch.float16
        dtypes = (half, torch.float32, half, torch.float64, half, torch.float32)
        # Six 60-element ranges, each split into 5 chunks -> tuple of 30 tensors.
        target = tuple(
            chunk
            for dtype in dtypes
            for chunk in torch.arange(60, dtype=dtype, device=device).chunk(5)
        )

        if self.rank == root_rank:
            tensors = [tensor.clone() for tensor in target]
        else:
            # Non-root ranks start from zeros so the broadcast must overwrite
            # them; make sure they really differ from the target up front.
            tensors = [torch.zeros_like(tensor) for tensor in target]
            self.assertNotEqual(tensors, target)

        c10d._broadcast_coalesced(
            process_group, tensors, buffer_size=256, src=root_rank
        )

        # Every rank, the root included, must now hold the target values.
        self.assertEqual(tensors, target)

    def _run_broadcast_coalesced_from_every_root(self, device):
        """Init a default gloo group, then run the broadcast check from each rank.

        Args:
            device: ``torch.device`` used for the backend lookup and tensors.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        backend = process_group._get_backend(device)
        # NOTE(review): the returned device object is discarded, exactly as in
        # the original tests; confirm whether this call is still needed.
        backend.create_device(interface=LOOPBACK)
        for root_rank in range(self.world_size):
            self._test_broadcast_coalesced(process_group, device, root_rank)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_broadcast_coalesced_gloo_cuda(self):
        self._run_broadcast_coalesced_from_every_root(
            torch.device(f"cuda:{self.rank}")
        )

    @requires_gloo()
    def test_broadcast_coalesced_gloo_cpu(self):
        self._run_broadcast_coalesced_from_every_root(torch.device("cpu"))

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_gloo(self):
        self._test_sequence_num_set_default_pg(backend="gloo")

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_gloo_new_group(self):
        self._test_sequence_num_set_new_group(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_default(self):
        self._test_sequence_num_incremented_default_group("gloo")

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_subgroup(self):
        if self.world_size < 4:
            # skip_but_pass_in_sandcastle() returns a decorator; calling it in
            # the body and returning the result skipped nothing, so the test
            # silently passed. skipTest() raises unittest.SkipTest instead
            # (presumably surfaced as a skip by MultiProcessTestCase's
            # subprocess skip handling -- confirm).
            self.skipTest("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gloo_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gloo_rank_membership(self):
        self._test_rank_membership(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="gloo")

    @requires_gloo()
    def test_bool_tensors(self):
        self._test_bool_tensors(backend="gloo")
2423

2424
class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGroupWithDispatchedCollectivesTests):
    """Run the shared dispatched-collectives tests against the gloo backend."""

    def _init_gloo_default_group(self):
        """Initialize the default process group with gloo over a FileStore."""
        file_store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo",
            world_size=self.world_size,
            rank=self.rank,
            store=file_store,
        )

    @requires_gloo()
    def test_collectives(self):
        self._test_collectives(backend="gloo")

    @requires_gloo()
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="gloo")

    @requires_gloo()
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="gloo")

    @requires_gloo()
    def test_allgather_coalesced(self):
        self._init_gloo_default_group()
        source = torch.ones(10, 10, dtype=torch.float32)
        # One output list holding one slot shaped like the single input.
        gathered = [torch.zeros_like(source)]
        dist.all_gather_coalesced([gathered], [source])
        self.assertEqual(gathered, [source])

    @requires_gloo()
    def test_monitored_barrier(self):
        self._init_gloo_default_group()
        dist.monitored_barrier()
2461

2462
class CompilerTest(test_c10d_common.CompilerTest):
    """Gloo flavour of the shared work/wait compiler tests, run on two ranks."""

    @property
    def world_size(self):
        return 2

    def _get_default_group(self):
        """Initialize the default gloo process group over a FileStore and return it."""
        file_store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=file_store,
        )
        return dist.distributed_c10d._get_default_group()

    def _rank_filled_cpu_tensor(self):
        """Return a 2x2 CPU tensor whose elements all equal this process's rank."""
        return torch.ones(2, 2) * self.rank

    def _rank_filled_gpu_tensor(self):
        """Return a 2x2 tensor on device index ``self.rank`` filled with the rank."""
        return torch.ones(2, 2, device=self.rank) * self.rank

    # --- allreduce -------------------------------------------------------
    def test_allreduce_work_wait_cpu(self):
        self._test_allreduce_work_wait(self._rank_filled_cpu_tensor())

    @skip_if_lt_x_gpu(2)
    def test_allreduce_work_wait_gpu(self):
        self._test_allreduce_work_wait(self._rank_filled_gpu_tensor())

    # --- allgather -------------------------------------------------------
    def test_allgather_work_wait_cpu(self):
        self._test_allgather_work_wait(self._rank_filled_cpu_tensor())

    @skip_if_lt_x_gpu(2)
    def test_allgather_work_wait_gpu(self):
        self._test_allgather_work_wait(self._rank_filled_gpu_tensor())

    # --- broadcast -------------------------------------------------------
    def test_broadcast_work_wait_cpu(self):
        self._test_broadcast_work_wait(self._rank_filled_cpu_tensor())

    @skip_if_lt_x_gpu(2)
    def test_broadcast_work_wait_gpu(self):
        self._test_broadcast_work_wait(self._rank_filled_gpu_tensor())

    # --- scatter ---------------------------------------------------------
    def test_scatter_work_wait_cpu(self):
        self._test_scatter_work_wait(self._rank_filled_cpu_tensor())

    @skip_if_lt_x_gpu(2)
    def test_scatter_work_wait_gpu(self):
        self._test_scatter_work_wait(self._rank_filled_gpu_tensor())

    # --- composite -------------------------------------------------------
    def test_nested_comm_tensor_wrapping(self):
        self._test_nested_comm_tensor_wrapping(self._rank_filled_cpu_tensor())

    def test_consecutive_comm_work_wait_cpu(self):
        self._test_consecutive_comm_work_wait(self._rank_filled_cpu_tensor())

    @skip_if_lt_x_gpu(2)
    def test_consecutive_comm_work_wait_gpu(self):
        self._test_consecutive_comm_work_wait(self._rank_filled_gpu_tensor())
2525

2526
class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
    """Gloo instantiation of the shared ``new_group`` local-sync tests."""

    @property
    def device(self):
        return torch.device("cpu")

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # Best-effort cleanup of the FileStore file; a missing file is fine.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @requires_gloo()
    def test_new_group_local_sync(self):
        self._test_new_group_local_sync(backend="gloo")

    @requires_gloo()
    def test_new_group_local_sync_sanity_check(self):
        self._test_new_group_local_sync_sanity_check(backend="gloo")

    @requires_gloo()
    def test_new_group_local_sync_duplicate_pg(self):
        self._test_new_group_local_sync_duplicate_pg(backend="gloo")
2553

2554
if __name__ == "__main__":
    # The tests spawn worker processes (see `_spawn_processes` in setUp); the
    # parent process is required not to have initialized CUDA beforehand.
    cuda_already_initialized = torch.cuda._initialized
    assert not cuda_already_initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )
    run_tests()
2560

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.