pytorch

Форк
0
/
test_c10d_nccl.py 
4490 строк · 174.2 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import copy
4
import json
5
import os
6
import pickle
7
import random
8
import re
9
import signal
10
import sys
11
import tempfile
12
import threading
13
import time
14
import warnings
15
from contextlib import contextmanager
16
from datetime import datetime, timedelta
17
from enum import auto, Enum
18
from itertools import chain, product
19
from unittest import mock, SkipTest
20

21
import torch
22
import torch.distributed as c10d
23

24

25
# Every test below needs the NCCL flavour of c10d. When it is not available,
# report that on stderr and exit with status 0 instead of failing.
if not (c10d.is_available() and c10d.is_nccl_available()):
    print("c10d NCCL not available, skipping tests", file=sys.stderr)
    sys.exit(0)
28

29
from typing import Dict, List
30

31
import test_c10d_common
32
from test_c10d_common import ConvNet, DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook
33

34
import torch.distributed as dist
35
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
36
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
37
import torch.nn.functional as F
38
import torch.testing._internal.common_utils as common
39
from torch import nn
40
from torch._C._distributed_c10d import OpType
41
from torch.nn.parallel import DistributedDataParallel
42
from torch.testing._internal.common_cuda import TEST_MULTIGPU
43
from torch.testing._internal.common_distributed import (
44
    get_timeout,
45
    init_multigpu_helper,
46
    MultiProcessTestCase,
47
    requires_gloo,
48
    requires_nccl,
49
    requires_nccl_version,
50
    skip_if_lt_x_gpu,
51
    skip_if_rocm,
52
    TEST_SKIPS,
53
    with_dist_debug_levels,
54
    with_nccl_blocking_wait,
55
)
56
from torch.testing._internal.common_utils import (
57
    instantiate_parametrized_tests,
58
    parametrize,
59
    retry_on_connect_failures,
60
    run_tests,
61
    skip_but_pass_in_sandcastle,
62
    skip_but_pass_in_sandcastle_if,
63
    TEST_CUDA,
64
    TEST_WITH_DEV_DBG_ASAN,
65
    TEST_WITH_ROCM,
66
    TestCase,
67
)
68

69

70
# torch + multiprocessing spawn have known issues under dev/debug ASAN, so the
# whole module exits (status 0) before any test class is defined.
if TEST_WITH_DEV_DBG_ASAN:
    _asan_skip_message = "Skip ASAN as torch + multiprocessing spawn have known issues"
    print(_asan_skip_message, file=sys.stderr)
    sys.exit(0)
75

76
# bfloat16 is only supported by CUDA 11+; a HIP build (torch.version.hip set)
# is accepted as well. Either way a usable CUDA device is required.
BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
    torch.version.hip is not None
    or (
        torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 11
    )
)
81

82

83
class RendezvousEnvTest(TestCase):
    """Exercises ``env://`` rendezvous with the NCCL backend in a single process."""

    @retry_on_connect_failures
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
    def test_common_errors(self):
        # Each scenario below runs with os.environ replaced by a subset of
        # `base_env`, then checks either the ValueError raised by the env://
        # rendezvous or that explicit arguments / URL query parameters fill in
        # the missing variable.
        base_env = {
            "WORLD_SIZE": "1",
            "RANK": "0",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
        }

        def env_without(*missing):
            # mock.patch.dict is already a context manager: with clear=True the
            # environment contains exactly `remaining` inside the `with` block
            # and the previous environment is restored on exit.
            remaining = {k: v for k, v in base_env.items() if k not in missing}
            return mock.patch.dict(os.environ, remaining, clear=True)

        def assert_single_rank_group():
            # Every successful init in this test yields a one-process group.
            self.assertEqual(c10d.get_rank(), 0)
            self.assertEqual(c10d.get_world_size(), 1)
            c10d.destroy_process_group()

        with env_without("WORLD_SIZE"):
            self.assertIsNone(os.environ.get("WORLD_SIZE"))
            with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"):
                gen = c10d.rendezvous("env://")
                next(gen)
            c10d.init_process_group(backend="nccl", world_size=1)
            assert_single_rank_group()

        with env_without("RANK"):
            self.assertIsNone(os.environ.get("RANK"))
            with self.assertRaisesRegex(ValueError, "RANK expected"):
                gen = c10d.rendezvous("env://")
                next(gen)
            c10d.init_process_group(backend="nccl", rank=0)
            assert_single_rank_group()

        with env_without("RANK", "WORLD_SIZE"):
            self.assertIsNone(os.environ.get("RANK"))
            self.assertIsNone(os.environ.get("WORLD_SIZE"))
            c10d.init_process_group(backend="nccl", rank=0, world_size=1)
            assert_single_rank_group()

        with env_without():
            c10d.init_process_group(backend="nccl")
            assert_single_rank_group()

        with env_without("MASTER_ADDR"):
            self.assertIsNone(os.environ.get("MASTER_ADDR"))
            with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"):
                gen = c10d.rendezvous("env://")
                next(gen)

        with env_without("MASTER_PORT"):
            self.assertIsNone(os.environ.get("MASTER_PORT"))
            with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"):
                gen = c10d.rendezvous("env://")
                next(gen)

        with env_without("WORLD_SIZE"):
            self.assertIsNone(os.environ.get("WORLD_SIZE"))
            gen = c10d.rendezvous("env://?world_size=1")
            _, _, size = next(gen)
            self.assertEqual(size, 1)

        with env_without("RANK"):
            self.assertIsNone(os.environ.get("RANK"))
            gen = c10d.rendezvous("env://?rank=0")
            _, rank, _ = next(gen)
            self.assertEqual(rank, 0)

        with env_without("RANK", "WORLD_SIZE"):
            self.assertIsNone(os.environ.get("RANK"))
            self.assertIsNone(os.environ.get("WORLD_SIZE"))
            gen = c10d.rendezvous("env://?rank=0&world_size=1")
            _, rank, size = next(gen)
            self.assertEqual(rank, 0)
            self.assertEqual(size, 1)
181

182

183
class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    """Runs the inherited default-store-timeout check for the "nccl" backend."""

    @requires_nccl()
    @retry_on_connect_failures
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
    def test_default_store_timeout_nccl(self):
        # The actual check lives in the abstract base; only the backend name
        # is supplied here.
        backend_name = "nccl"
        self._test_default_store_timeout(backend_name)
189

190

191
class ProcessGroupNCCLNoGPUTest(TestCase):
    """Checks that constructing ProcessGroupNCCL without any GPU raises ValueError."""

    MAIN_PROCESS_RANK = 0

    def setUp(self):
        self.rank = self.MAIN_PROCESS_RANK
        self.world_size = 1
        # delete=False keeps the file on disk after the handle is closed; the
        # test passes only its name to FileStore. tearDown cleans it up.
        self.file = tempfile.NamedTemporaryFile(delete=False)

    def tearDown(self):
        # Release the handle opened in setUp and remove the backing file so
        # neither a descriptor nor a temp file is leaked per test run.
        self.file.close()
        try:
            os.remove(self.file.name)
        except OSError:
            # Best effort: the file may already have been removed.
            pass

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(TEST_CUDA, "GPUs are available, skipping test")
    def test_init_no_gpus(self):
        # Only meaningful on a GPU-less machine (skipped when TEST_CUDA is set).
        store = c10d.FileStore(self.file.name, self.world_size)
        with self.assertRaisesRegex(
            ValueError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"
        ):
            c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
210

211

212
class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
213
    def _create_process_group_nccl(self, store, opts, device_id=None):
        """Initialize the default process group with the "nccl" backend and return it.

        Args:
            store: store object forwarded to ``init_process_group``.
            opts: backend options forwarded as ``pg_options``.
            device_id: optional value forwarded as ``device_id``.

        Returns:
            The default process group as reported by ``_get_default_group()``.
        """
        init_kwargs = {
            "world_size": self.world_size,
            "rank": self.rank,
            "store": store,
            "pg_options": opts,
            "device_id": device_id,
        }
        c10d.init_process_group("nccl", **init_kwargs)
        return c10d.distributed_c10d._get_default_group()
225

226
    def opts(self, high_priority_stream=False):
        """Return a fresh ``ProcessGroupNCCL.Options`` with ``is_high_priority_stream`` set.

        Args:
            high_priority_stream: value assigned to ``is_high_priority_stream``.
        """
        nccl_options = c10d.ProcessGroupNCCL.Options()
        nccl_options.is_high_priority_stream = high_priority_stream
        return nccl_options
230

231
    def setUp(self):
        super().setUp()
        # Return code checking is skipped for the NaN-assert variants because
        # their child processes don't exit cleanly in some cuda versions.
        nan_assert_variants = (
            self.test_nan_assert_float16,
            self.test_nan_assert_float32,
            self.test_nan_assert_float64,
            self.test_nan_assert_bfloat16,
        )
        self.skip_return_code_checks = [
            variant.__wrapped__ for variant in nan_assert_variants
        ]

        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        # Spawn last so the attributes and env var set above are already in place.
        self._spawn_processes()
247

248
    def tearDown(self):
        """Run the base-class teardown, then remove ``self.file_name`` if possible."""
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Deliberately best-effort: a failure to remove the file is ignored.
            pass
254

255
    @property
    def world_size(self):
        """Number of ranks used by the tests in this class (always 2)."""
        return 2
258

259
    @property
    def rank_to_GPU(self):
        """Rank-to-GPU map from ``init_multigpu_helper`` for this world size and "nccl"."""
        rank_gpu_map = init_multigpu_helper(self.world_size, "nccl")
        return rank_gpu_map
263

264
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 1 GPU")
    @skip_if_lt_x_gpu(1)
    def test_nccl_dist_backend_error(self):
        # Verifies that an NCCL failure surfaces as dist.DistBackendError and
        # that this exception is also a DistError and a RuntimeError.
        store = c10d.FileStore(self.file_name, self.world_size)
        self._create_process_group_nccl(store, self.opts())

        # Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage
        with self.assertRaises(dist.DistBackendError) as cm:
            dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0)

        # assertIsInstance reports the actual exception type on failure, which
        # assertTrue(isinstance(...)) does not.
        self.assertIsInstance(cm.exception, dist.DistError)
        self.assertIsInstance(cm.exception, RuntimeError)
277

278
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_abort_pg(self):
        # ASYNC_ERROR_HANDLING is turned off for this test so that the process
        # group can be aborted programmatically.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

        file_store = c10d.FileStore(self.file_name, self.world_size)
        self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]

        payload = torch.rand(10, 10, device=device)
        # First allreduce to initialize state.
        dist.all_reduce(payload)

        def shutdown_default_backend():
            default_pg = c10d.distributed_c10d._get_default_group()
            default_pg._get_backend(torch.device(device))._shutdown()

        # Initialize DDP to ensure "destroy_process_group" will not call
        # ProcessGroupNCCL destructor since DDP holds a reference to process group.
        # Run a single iteration of DDP to initialize state.
        ddp_model = DistributedDataParallel(
            torch.nn.Linear(10, 10).to(device), device_ids=[device]
        )
        ddp_model(payload).sum().backward()

        # Only rank 0 goes on to simulate the stuck collective.
        if self.rank != 0:
            return

        # Now simulate collective getting stuck and abort gets us unstuck
        dist.all_reduce(payload)

        # Schedule thread before we get stuck to abort pg.
        abort_thread = threading.Thread(target=shutdown_default_backend)
        abort_thread.start()

        # We would get stuck here due to d2h if we didn't abort.
        host_copy = payload.cpu()

        abort_thread.join()
318

319
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_close_pg(self):
        # ASYNC_ERROR_HANDLING is turned off for this test so that the process
        # group can be aborted programmatically.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]

        payload = torch.rand(10, 10, device=device)
        # First allreduce to initialize state.
        pg.allreduce(payload)

        # Once the group is destroyed, a further collective on the stale
        # handle must raise DistBackendError.
        dist.destroy_process_group()
        with self.assertRaises(dist.DistBackendError):
            pg.allreduce([payload])

        del pg
340

341
    # True when a CUDA device is usable and the CUDA major version torch
    # reports is 12 or higher; used below to gate test_nan_assert.
    CUDA_12_AND_ABOVE = (
        torch.cuda.is_available()
        and torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 12
    )
344

345
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(
        not (TEST_MULTIGPU and CUDA_12_AND_ABOVE),
        "NCCL test requires 2+ GPUs and Device side assert could cause unexpected errors in lower versions of CUDA",
    )
    @parametrize("type", [torch.float16, torch.float32, torch.float64, torch.bfloat16])
    @skip_if_rocm
    def test_nan_assert(self, type):
        # Expecting a device-side error when NaN is detected
        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
        try:
            store = c10d.FileStore(self.file_name, self.world_size)
            pg = self._create_process_group_nccl(store, self.opts())
            device = self.rank_to_GPU[self.rank][0]
            size = (10, 10)
            nan_tensor = torch.full(size, self.rank, dtype=type, device=device)
            # randomly pick an nan element
            i = random.randint(0, nan_tensor.size(0) - 1)
            j = random.randint(0, nan_tensor.size(1) - 1)
            nan_tensor[i, j] = float("nan")
            with self.assertRaises(RuntimeError):
                pg.allreduce(nan_tensor)
            dist.destroy_process_group()
        finally:
            # reset env, even when an assertion or collective above raised
            os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
369

370
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nan_rank_filter(self):
        # Putting NaN at recv buffer, program should not fail as NaN checker
        # should not check on receive buffer
        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
        try:
            store = c10d.FileStore(self.file_name, self.world_size)
            device = torch.device(f"cuda:{self.rank}")
            c10d.init_process_group(
                backend="nccl", store=store, rank=self.rank, world_size=self.world_size
            )
            t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
            if self.rank != 0:
                # Putting NaN at recv buffer
                t[1, 1] = float("nan")
            # Against broadcast
            c10d.broadcast(t, 0)
            # Against P2P
            if self.rank == 0:
                c10d.send(t, 1)
            elif self.rank == 1:
                c10d.recv(t, 0)
            c10d.destroy_process_group()
        finally:
            # reset env, even when a collective above raised
            os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
395

396
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nan_check(self):
        # Not expecting an error, NaN check should not make legit code fail
        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
        try:
            store = c10d.FileStore(self.file_name, self.world_size)
            device = torch.device(f"cuda:{self.rank}")
            c10d.init_process_group(
                backend="nccl", store=store, rank=self.rank, world_size=self.world_size
            )
            x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
            t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
            c10d.broadcast(x, src=0)
            c10d.all_reduce(t)
            c10d.barrier()
            c10d.destroy_process_group()
        finally:
            # reset env, even when a collective above raised
            os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
414

415
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_destruct_before_terminate_pg(self):
        # ASYNC_ERROR_HANDLING is turned off for this test so that the process
        # group can be aborted programmatically.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]

        payload = torch.rand(10, 10, device=device)
        # First allreduce to initialize state.
        pg.allreduce(payload)
        # Drop the local reference without calling destroy_process_group first:
        # force destruction before terminating comms, destructor would terminate comms
        del pg
430

431
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_abort_in_destroy_pg(self):
        # ASYNC_ERROR_HANDLING is turned off for this test so that the process
        # group can be aborted programmatically.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"

        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]

        payload = torch.rand(10, 10, device=device)
        # First allreduce to initialize state.
        pg.allreduce(payload)

        # After destroy_process_group the comms are shut down, so the old
        # handle must NOT be in working condition any more.
        dist.destroy_process_group()
        with self.assertRaises(dist.DistBackendError):
            pg.allreduce([payload])
451

452
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(
        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
    )
    def test_close_multi_pg_unordered(self):
        # Two sub-groups are created and then destroyed in opposite orders on
        # rank 0 and rank 1 before the default group is destroyed.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]
        warmup = torch.rand(10, 10, device=device)
        # First allreduce to initialize default PG's communicator.
        pg.allreduce(warmup).wait()
        new_pg1 = c10d.new_group([0, 1])
        new_pg2 = c10d.new_group([0, 1])
        if self.rank in (0, 1):
            t1 = torch.rand(10, 10, device=device)
            t2 = torch.rand(10, 10, device=device)
            new_pg1.allreduce(t1).wait()
            new_pg2.allreduce(t2).wait()
        if self.rank == 0:
            # force destruction of pg2 first
            dist.destroy_process_group(new_pg2)
            del new_pg2
            dist.destroy_process_group(new_pg1)
            del new_pg1
        elif self.rank == 1:
            # force destruction of pg1 first
            c10d.destroy_process_group(new_pg1)
            del new_pg1
            dist.destroy_process_group(new_pg2)
            del new_pg2
        dist.destroy_process_group()
483

484
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(
        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
    )
    def test_abort_in_destroy_multi_pgs(self):
        # Both sub-groups run a collective, so the default backend is expected
        # to report two communicator splits before everything is destroyed.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]
        warmup = torch.rand(10, 10, device=device)
        # First allreduce to initialize default PG's communicator.
        pg.allreduce(warmup).wait()
        new_pg1 = c10d.new_group([0, 1])
        new_pg2 = c10d.new_group([0, 1])
        t1 = torch.rand(10, 10, device=device)
        t2 = torch.rand(10, 10, device=device)
        new_pg1.allreduce(t1).wait()
        new_pg2.allreduce(t2).wait()
        default_backend = pg._get_backend(torch.device(device))
        # default PG's backend should have a split count of 2
        self.assertEqual(default_backend.comm_split_count(), 2)
        # shutdown all NCCL PGs in one shot
        dist.destroy_process_group()
506

507
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(
        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
    )
    def test_abort_in_destroy_mixed_empty_pgs(self):
        # Only one of the two sub-groups runs a collective, so the default
        # backend is expected to report a single communicator split.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())
        device = self.rank_to_GPU[self.rank][0]
        warmup = torch.rand(10, 10, device=device)
        # First allreduce to initialize default PG's communicator.
        pg.allreduce(warmup).wait()
        # PG1 is an PG without comms initialized, since we don't call collective on it
        new_pg1 = c10d.new_group([0, 1])
        new_pg2 = c10d.new_group([0, 1])
        t2 = torch.rand(10, 10, device=device)

        new_pg2.allreduce(t2).wait()
        default_backend = pg._get_backend(torch.device(device))
        # default PG's backend should have a split count of 1
        self.assertEqual(default_backend.comm_split_count(), 1)
        # shutdown all NCCL PGs in one shot
        dist.destroy_process_group()
529

530
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(
        torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs"
    )
    def test_file_store_check(self):
        env_overrides = {
            "TORCH_NCCL_ASYNC_ERROR_HANDLING": "0",
            "TORCH_NCCL_ENABLE_MONITORING": "0",
            # FileStore check() would be executed
            "TORCH_NCCL_DUMP_ON_TIMEOUT": "1",
            "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC": "0",
        }
        for name, value in env_overrides.items():
            os.environ[name] = value

        # self.file_name is created using "delete=False"
        # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        file_store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=file_store,
        )
        default_pg = dist.distributed_c10d._get_default_group()
        self.assertEqual(default_pg.rank(), self.rank)
        self.assertEqual(default_pg.size(), self.world_size)
        # give enough time for check() to be executed multiple times
        time.sleep(2)
        dist.destroy_process_group()
553

554
    def _check_nccl_timeout(self, expected_timeout):
        """Assert that the default group's backend for ``cuda:<rank>`` has this timeout.

        Args:
            expected_timeout: value compared against ``options._timeout``.
        """
        rank_device = torch.device(f"cuda:{self.rank}")
        default_pg = dist.distributed_c10d._get_default_group()
        backend_options = default_pg._get_backend(rank_device).options
        self.assertEqual(backend_options._timeout, expected_timeout)
558

559
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_CUDA, "No GPUs available, skipping test")
    def test_init_process_group_nccl_timeout(self):
        # nccl is handled 'specially' inside init_process_group and its options class is different from the options
        # used by the other PG's.  There are specific edge cases for nccl that need to be tested.

        file_store = c10d.FileStore(self.file_name, self.world_size)
        common_kwargs = {
            "backend": "nccl",
            "store": file_store,
            "rank": self.rank,
            "world_size": self.world_size,
        }

        # Case 1: no timeout given -> the `init_process_group` kwarg default applies
        dist.init_process_group(**common_kwargs)
        self._check_nccl_timeout(torch.distributed.constants.default_pg_nccl_timeout)
        dist.destroy_process_group()

        # Case 2: an explicit `timeout` kwarg takes effect
        explicit_timeout = timedelta(seconds=123)
        dist.init_process_group(**common_kwargs, timeout=explicit_timeout)
        self._check_nccl_timeout(explicit_timeout)
        dist.destroy_process_group()

        # Case 3: a timeout set only on `pg_options` is ignored (with a warning);
        # the 'timeout' kwarg default takes precedence
        ignored_opts = dist.ProcessGroupNCCL.Options()
        ignored_opts._timeout = timedelta(seconds=123)
        with warnings.catch_warnings(record=True) as w:
            dist.init_process_group(**common_kwargs, pg_options=ignored_opts)
            # TODO(whc) i verified that we are indeed emitting this warning, and i can't figure out why i can't catch it.
            # self.assertEqual(len(w), 1)
            # self.assertTrue("pg_options._timeout was specified" in str(w[-1].message))
        self._check_nccl_timeout(torch.distributed.constants.default_pg_nccl_timeout)
        dist.destroy_process_group()

        # Case 4: a timeout on `pg_options` plus an explicit 'timeout' kwarg ->
        # the kwarg takes precedence
        ignored_opts = dist.ProcessGroupNCCL.Options()
        ignored_opts._timeout = timedelta(seconds=123)
        dist.init_process_group(
            **common_kwargs, pg_options=ignored_opts, timeout=timedelta(seconds=1240)
        )
        self._check_nccl_timeout(timedelta(seconds=1240))
        dist.destroy_process_group()
602

603
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("backend", [None, "nccl"])
    def test_set_nccl_pg_timeout(self, backend):
        # The timeout is checked after init, after _set_default_timeout on the
        # backend, and after _set_pg_timeout on the group.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        init_kwargs = {
            "backend": backend,
            "store": file_store,
            "rank": self.rank,
            "world_size": self.world_size,
            "timeout": timedelta(seconds=123),
        }
        dist.init_process_group(**init_kwargs)
        pg = dist.distributed_c10d._get_default_group()
        pg.allreduce(torch.rand(10).cuda(self.rank))
        self._check_nccl_timeout(timedelta(seconds=123))

        rank_device = torch.device(f"cuda:{self.rank}")
        pg._get_backend(rank_device)._set_default_timeout(timedelta(seconds=23))
        self._check_nccl_timeout(timedelta(seconds=23))

        pg.allreduce(torch.rand(10).cuda(self.rank))
        c10d.distributed_c10d._set_pg_timeout(timedelta(seconds=252), pg)
        self._check_nccl_timeout(timedelta(seconds=252))
626

627
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("backend", [None, "nccl"])
    def test_extend_nccl_pg_timeout(self, backend):
        # Rank 0 sleeps between collectives while the other rank adds ephemeral
        # timeout extensions and verifies the timeout attached to each work.
        torch.cuda.set_device(self.rank)
        file_store = c10d.FileStore(self.file_name, self.world_size)
        init_kwargs = {
            "backend": backend,
            "store": file_store,
            "rank": self.rank,
            "world_size": self.world_size,
            "timeout": timedelta(seconds=123),
        }
        dist.init_process_group(**init_kwargs)
        pg = dist.distributed_c10d._get_default_group()
        nccl_backend = pg._get_backend(torch.device(f"cuda:{self.rank}"))

        work = pg.allreduce(torch.rand(10).cuda(self.rank))
        self.assertTrue(nccl_backend._verify_work_timeout(work, timedelta(seconds=123)))
        work.wait()

        nccl_backend._set_default_timeout(timedelta(seconds=3))
        if self.rank == 0:
            # Ideally we want to sleep for a very long time, but this is not
            # feasible in unit test. So this is only a very tiny case.
            time.sleep(5)
            pg.allreduce(torch.rand(10).cuda(self.rank))
            time.sleep(5)
            pg.allreduce(torch.rand(5).cuda(self.rank))
            work = pg.allreduce(torch.rand(10).cuda(self.rank))
            self.assertTrue(
                nccl_backend._verify_work_timeout(work, timedelta(seconds=3))
            )
            work.wait()
        else:
            dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                timedelta(seconds=10)
            )
            first_work = pg.allreduce(torch.rand(10).cuda(self.rank))
            second_work = pg.allreduce(torch.rand(5).cuda(self.rank))
            # Expected value is the 3s default plus the 10s extension.
            for pending in (first_work, second_work):
                self.assertTrue(
                    nccl_backend._verify_work_timeout(pending, timedelta(seconds=13))
                )
            first_work.wait()
            dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                timedelta(seconds=5)
            )
            # Since we are not block wait so use a sync here to leave enough time
            # for watchdog to reset first timeout extension.
            torch.cuda.synchronize(torch.device(f"cuda:{self.rank}"))
            work = pg.allreduce(torch.rand(10).cuda(self.rank))
            # Expected value is the 3s default plus the 5s extension.
            self.assertTrue(
                nccl_backend._verify_work_timeout(work, timedelta(seconds=8))
            )
            work.wait()
675

676
    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_split_optimization(self):
        # Test the optimization of new groups that contain all world
        # ranks use the "transparent" `ncclCommSplit` optimization.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_nccl(file_store, self.opts())

        # Test lazy splitting behavior across each per-device backend.
        for device in self.rank_to_GPU[self.rank]:
            device_backend = pg._get_backend(torch.device(device))

            # split doesn't happen unless the original process group has lazily
            # created communicators, so first verify we haven't split even when
            # making the new group and running an operation on the original pg.
            world_group = c10d.new_group()
            rank_tensor = torch.tensor([self.rank]).cuda(device)
            pg.broadcast(rank_tensor, 0)
            self.assertEqual(device_backend.comm_split_count(), 0)

            # The new group will force a split of the original on first use.
            world_group.broadcast(rank_tensor, 0)
            self.assertEqual(device_backend.comm_split_count(), 1)
699

700
    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @skip_but_pass_in_sandcastle_if(
        torch.cuda.nccl.version()[-1] == "x", "NCCL test not for NCCLX"
    )
    def test_comm_split_subgroup(self):
        # Test `ncclCommSplit` for smaller subgroups of the world when
        # we've passed a specific device_id to init_process_group.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        pg = self._create_process_group_nccl(
            file_store, self.opts(), device_id=device
        )
        parent_backend = pg._get_backend(torch.device(device))

        rank_tensor = torch.full((1,), self.rank).cuda(device)
        snapshot = rank_tensor.clone()
        rank0_group = c10d.new_group([0])

        # comm split happens eagerly since device_id is passed to init_process_group.
        self.assertEqual(parent_backend.comm_split_count(), 1)
        if self.rank == 0:
            dist.broadcast(rank_tensor, 0, group=rank0_group)

        # no additional comm split happens after a collective.
        self.assertEqual(parent_backend.comm_split_count(), 1)
        # The tensor is expected to be unchanged on every rank.
        self.assertEqual(rank_tensor, snapshot)
        dist.destroy_process_group()
726

727
    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_comm_split_group(self):
        # Test `ncclCommSplit` for smaller subgroups of the world when
        # we've passed a specific device_id to init_process_group.
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f"cuda:{self.rank}")
        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
        backend = pg._get_backend(torch.device(device))

        tensor = torch.full((1,), self.rank).cuda(device)
        ng1 = c10d.split_group(pg, [[0, 1]])
        # Fetch the backend of the *child* group. Reading it from `pg` again
        # would compare the parent's options with themselves and make the
        # parent/child checks below vacuous.
        backend1 = ng1._get_backend(torch.device(device))

        # check basic options are the same between parent and child
        self.assertEqual(backend.options._timeout, backend1.options._timeout)
        self.assertEqual(
            backend.options.is_high_priority_stream,
            backend1.options.is_high_priority_stream,
        )
        self.assertEqual(ng1.group_desc, "default_pg:split:0")

        # comm split happens eagerly since device_id is passed to init_process_group.
        self.assertEqual(backend.comm_split_count(), 1)
        dist.broadcast(tensor, 0, group=ng1)
        self.assertEqual(tensor, torch.full((1,), 0))

        # A second split of the same parent gets the next description index
        # and bumps the parent's split count.
        ng2 = c10d.split_group(pg, [[0, 1]])
        self.assertEqual(ng2.group_desc, "default_pg:split:1")
        self.assertEqual(backend.comm_split_count(), 2)

        dist.destroy_process_group()
759

760
    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_non_blocking_init(self):
        # Create a pg in nonblocking mode without eager init (no device_id).
        os.environ.update(
            {
                "TORCH_NCCL_USE_COMM_NONBLOCKING": "1",
                "TORCH_NCCL_NONBLOCKING_TIMEOUT": "100",
            }
        )
        file_store = c10d.FileStore(self.file_name, self.world_size)
        gpu = self.rank_to_GPU[self.rank][0]
        world_pg = self._create_process_group_nccl(file_store, self.opts())
        nccl_backend = world_pg._get_backend(torch.device(gpu))
        self.assertEqual(nccl_backend.comm_split_count(), 0)

        # The allreduce should trigger comm initialization for the world pg.
        to_reduce = torch.rand(10, 10, device=gpu)
        world_pg.allreduce(to_reduce).wait()

        # Even after the world pg has run a collective, the new pg's comm is
        # not initialized until the new pg issues its own collective.
        second_pg = c10d.new_group()
        self.assertEqual(nccl_backend.comm_split_count(), 0)

        to_broadcast = torch.tensor([self.rank]).cuda(gpu)
        second_pg.broadcast(to_broadcast, 0).wait()
        self.assertEqual(nccl_backend.comm_split_count(), 1)
        dist.destroy_process_group()
781

782
    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_non_blocking_with_eager_init(self):
        # Create a pg eagerly in nonblocking mode by passing a specific
        # device_id to init_process_group.
        os.environ.update(
            {
                "TORCH_NCCL_USE_COMM_NONBLOCKING": "1",
                "TORCH_NCCL_NONBLOCKING_TIMEOUT": "100",
            }
        )
        file_store = c10d.FileStore(self.file_name, self.world_size)
        cuda_dev = torch.device(f"cuda:{self.rank}")
        # Binding the device triggers eager init mode.
        world_pg = self._create_process_group_nccl(
            file_store, self.opts(), device_id=cuda_dev
        )
        nccl_backend = world_pg._get_backend(torch.device(cuda_dev))
        self.assertEqual(nccl_backend.comm_split_count(), 0)

        # Comm initialization should already be under way; the allreduce is
        # issued to the CUDA stream only after initialization has succeeded.
        to_reduce = torch.rand(10, 10, device=cuda_dev)
        world_pg.allreduce(to_reduce).wait()

        # The new pg's comm is initialized eagerly, so the split is counted
        # before the new pg runs any collective...
        second_pg = c10d.new_group()
        self.assertEqual(nccl_backend.comm_split_count(), 1)

        # ...and running a collective on it does not split again.
        to_broadcast = torch.tensor([self.rank]).cuda(cuda_dev)
        second_pg.broadcast(to_broadcast, 0).wait()
        self.assertEqual(nccl_backend.comm_split_count(), 1)
        dist.destroy_process_group()
806

807
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_get_uid(self):
        # The default group is expected to report uid 0 and the next group
        # created afterwards uid 1.
        from torch.distributed.distributed_c10d import _get_process_group_uid

        file_store = c10d.FileStore(self.file_name, self.world_size)
        cuda_dev = torch.device(f"cuda:{self.rank}")
        default_pg = self._create_process_group_nccl(
            file_store, self.opts(), device_id=cuda_dev
        )
        self.assertEqual(_get_process_group_uid(default_pg), 0)

        second_pg = c10d.new_group([0, 1])
        self.assertEqual(_get_process_group_uid(second_pg), 1)
818

819
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_set_process_group_desc(self):
        # Checks `group_desc` for the default group, for a group created with
        # an explicit description, and for one created without.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        cuda_dev = torch.device(f"cuda:{self.rank}")
        default_pg = self._create_process_group_nccl(
            file_store, self.opts(), device_id=cuda_dev
        )
        self.assertEqual(default_pg.group_desc, "default_pg")

        described_pg = c10d.new_group([0, 1], group_desc="test_purpose")
        self.assertEqual(described_pg.group_desc, "test_purpose")

        undescribed_pg = c10d.new_group([0, 1])
        self.assertEqual(undescribed_pg.group_desc, "undefined")
832

833

834
class DistributedDataParallelTest(
835
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
836
):
837
    def setUp(self):
        """Run the base setUp, enable NCCL async error handling, spawn workers."""
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT takes precedence over
        # TORCH_NCCL_ASYNC_ERROR_HANDLING, so tests that opt into blocking
        # wait still exercise it as expected.
        os.environ.update({"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1"})
        self._spawn_processes()
843

844
    def _get_process_group(self):
        """Initialize the default NCCL process group and return it."""
        c10d.init_process_group(
            backend="nccl",
            store=self._get_store(),
            rank=self.rank,
            world_size=self.world_size,
        )
        default_group = c10d.distributed_c10d._get_default_group()
        return default_group
850

851
    def _test_nccl_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        """Create the default NCCL group and run the shared DDP check on it.

        All arguments are forwarded unchanged to
        ``_test_ddp_with_process_group``.
        """
        nccl_pg = self._get_process_group()
        self._test_ddp_with_process_group(
            nccl_pg, devices, device_ids, multi_device, gradient_as_bucket_view
        )
858

859
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_propagate_error_reason(self):
        # TORCH_NCCL_BLOCKING_WAIT is required instead of ASYNC_ERROR_HANDLING;
        # otherwise the process would be taken down and errors could not be
        # inspected. TORCH_NCCL_DUMP_ON_TIMEOUT must also be disabled,
        # otherwise this test times out.
        os.environ.update(
            {
                "TORCH_NCCL_ASYNC_ERROR_HANDLING": "0",
                "TORCH_NCCL_BLOCKING_WAIT": "1",
                "TORCH_NCCL_DUMP_ON_TIMEOUT": "0",
            }
        )
        file_store = c10d.FileStore(self.file_name, self.world_size)
        # Give NCCL comm initialization a sufficient timeout.
        nccl_pg = c10d.ProcessGroupNCCL(
            file_store, self.rank, self.world_size, timeout=timedelta(seconds=15)
        )
        gloo_pg = c10d.ProcessGroupGloo(file_store, self.rank, self.world_size)
        nccl_pg.barrier().wait(timedelta(seconds=5))

        # Rank 0 is made to look stuck: it parks in a gloo barrier that the
        # other ranks only enter at the very end.
        if self.rank == 0:
            gloo_pg.barrier().wait()
        ones = torch.ones(1).cuda(self.rank)

        if self.rank != 0:
            # Rank 0 never joins this allreduce, so it must time out.
            with self.assertRaises(dist.DistBackendError):
                nccl_pg.allreduce([ones]).wait(timedelta(seconds=5))

            # A later use of the communicator on a nonzero rank should report
            # the original failure reason.
            try:
                nccl_pg.allreduce([torch.ones(2).cuda(self.rank)]).wait()
            except dist.DistBackendError as err:
                self.assertTrue("aborted" in str(err))
            else:
                self.fail("Expected error to be raised!")

            # Release rank 0 from its gloo barrier.
            gloo_pg.barrier().wait()

        # TODO: We could also check that rank 0, when it tries to use the
        # communicator, errors out with the information that it was aborted
        # because of a timeout on another rank. That only holds once the
        # watchdog has run on that rank, and there is no reliable way to
        # confirm it has run.
901

902
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_multi_device_ids_not_allowed(self):
        # Passing every visible GPU as device_ids is expected to be rejected.
        gpu_indices = list(range(torch.cuda.device_count()))
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        expected_msg = "device_ids can only be None or contain a single element."
        with self.assertRaisesRegex(ValueError, expected_msg):
            self._test_nccl_backend(gpu_devices, gpu_indices)
911

912
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_single_device_module_device_ids_None(self):
        # Neither devices nor device_ids are specified.
        self._test_nccl_backend(devices=None, device_ids=None)
916

917
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_single_device_module_empty_device_ids(self):
        # Backward-compatibility check: an empty list is still accepted as
        # `device_ids`. It is no longer documented; the default `None` is
        # preferred because it is consistent with multi-device and CPU modules.
        self._test_nccl_backend(devices=None, device_ids=[])
924

925
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_nccl_backend_multi_device_module_device_ids_None(self):
        # Two GPUs for this rank, device_ids left as None.
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:2]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        self._test_nccl_backend(gpu_devices, None, multi_device=True)
931

932
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_1gpu_module_device_ids_integer_list(self):
        # device_ids is given as a one-element list of integer indices.
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:1]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        self._test_nccl_backend(gpu_devices, gpu_indices)
938

939
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_nccl_backend_1gpu_module_device_ids_torch_device_list(self):
        # device_ids is given as a one-element list of torch.device objects.
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:1]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        self._test_nccl_backend(gpu_devices, gpu_devices)
945

946
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_nccl_backend_2gpu_module(self):
        # Module spread over the first two GPUs assigned to this rank.
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:2]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        self._test_nccl_backend(gpu_devices, None, multi_device=True)
952

953
    @requires_nccl()
    @skip_if_lt_x_gpu(8)
    def test_nccl_backend_4gpu_module(self):
        # Module spread over the first four GPUs assigned to this rank.
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:4]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        self._test_nccl_backend(gpu_devices, None, multi_device=True)
959

960
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_ddp_multi_device_module_config(self):
        # Checks the ValueErrors DDP is expected to raise for inconsistent
        # device arguments on a module built across two GPUs.
        rank_gpus = gpus_for_rank(self.world_size)[self.rank]

        self.assertTrue(len(rank_gpus) >= 2, "expecting at least 2 gpus per process")

        nccl_pg = self._get_process_group()

        gpu_pair = rank_gpus[:2]
        net = DoubleGpuNet(gpu_pair)

        single_element_msg = "device_ids can only be None or contain a single element."

        # output_device combined with the two-GPU module.
        with self.assertRaisesRegex(
            ValueError,
            "DistributedDataParallel device_ids and output_device arguments only work with "
            "single-device/multiple-device GPU modules or CPU modules",
        ):
            DistributedDataParallel(
                net, output_device=gpu_pair[1], process_group=nccl_pg
            )

        # More than one entry in device_ids.
        with self.assertRaisesRegex(ValueError, single_element_msg):
            DistributedDataParallel(net, device_ids=gpu_pair, process_group=nccl_pg)

        # One submodule moved to CPU while the rest stays where it was.
        with self.assertRaisesRegex(
            ValueError, "input module must be on the same type of devices"
        ):
            net.fc1 = net.fc1.cpu()
            DistributedDataParallel(net, process_group=nccl_pg)

        # Whole module on CPU, but device_ids still lists two GPUs.
        net = net.cpu()
        with self.assertRaisesRegex(ValueError, single_element_msg):
            DistributedDataParallel(net, device_ids=gpu_pair, process_group=nccl_pg)
1001

1002
    def _test_fp16(self, gradient_as_bucket_view=False):
        """Run one half-precision DDP step and assert no gradient is inf.

        ``gradient_as_bucket_view`` is forwarded to DistributedDataParallel.
        """
        nccl_pg = self._get_process_group()

        primary_gpu = gpus_for_rank(self.world_size)[self.rank][0]
        linear = nn.Linear(1, 1, bias=False).cuda(primary_gpu).half()
        nn.init.constant_(linear.weight, 1)
        wrapped = DistributedDataParallel(
            linear,
            device_ids=[primary_gpu],
            process_group=nccl_pg,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # With an input of 2**15 the gradients overflow for a world_size of 2
        # unless they are normalized by world_size before the reduction.
        big_input = torch.tensor([[2**15]]).cuda(primary_gpu).half()

        # Single training step.
        wrapped.train()
        wrapped(big_input).sum().backward()

        found_inf = any(torch.isinf(p.grad).any() for p in wrapped.parameters())
        self.assertFalse(found_inf)
1028

1029
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_fp16(self):
        # fp16 check with the helper's default (gradient_as_bucket_view off).
        self._test_fp16(gradient_as_bucket_view=False)
1033

1034
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_fp16_grad_is_view(self):
        # Same fp16 check, with gradient_as_bucket_view turned on.
        use_bucket_view = True
        self._test_fp16(gradient_as_bucket_view=use_bucket_view)
1038

1039
    def _test_arbitrary_forward_return_value(self, gradient_as_bucket_view=False):
        """
        Run backward through DDP for several container shapes of the forward
        return value (tuple, list, dict and nested combinations).

        Note: this test can be sped up by only running it on a CPU module
        once DistributedDataParallel supports them.
        """
        nccl_pg = self._get_process_group()

        class ForwardReturnValueModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x, fn):
                hidden = self.relu(self.fc1(x))
                hidden = self.relu(self.fc2(hidden))
                # Only the second softmax has fc3 in its autograd graph. If
                # just the first output tensor were handed to the reducer, the
                # gradient for fc3 would be marked ready (it does not show
                # up). Differentiating later against the second output would
                # then still deliver a gradient and a callback for it,
                # resulting in a crash.
                without_fc3 = F.softmax(hidden, dim=1)
                with_fc3 = F.softmax(self.fc3(hidden), dim=1)
                return fn(without_fc3, with_fc3)

        gpu = gpus_for_rank(self.world_size)[self.rank][0]
        ddp_net = DistributedDataParallel(
            ForwardReturnValueModule().float().to(gpu),
            device_ids=[gpu],
            process_group=nccl_pg,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        batch_size = 4
        loss_fn = nn.CrossEntropyLoss()
        batch = torch.rand([batch_size, 2], dtype=torch.float)
        labels = torch.LongTensor(
            [random.randrange(4) for _ in range(batch_size)]
        ).to(gpu)

        # Each entry is (box, unbox): `box` packs the two outputs into some
        # container and `unbox` pulls the second output back out of it.
        cases = (
            # identity return value
            (lambda x, y: (x, y), lambda obj: obj[1]),
            # list return value
            (lambda x, y: ["foo", x, "bar", y], lambda obj: obj[3]),
            # tuple return value
            (lambda x, y: ("foo", x, "bar", y), lambda obj: obj[3]),
            # dict return value
            (lambda x, y: {"foo": "bar", "a": x, "b": y}, lambda obj: obj["b"]),
            # list with dict return value
            (lambda x, y: ["foo", "bar", {"a": x, "b": y}], lambda obj: obj[2]["b"]),
            # dict with list return value
            (
                lambda x, y: {"foo": "bar", "list": [0, x, 1, y]},
                lambda obj: obj["list"][3],
            ),
        )

        # "backward" is always run so that autograd invokes the reducer. If
        # the output tensors were not captured correctly from the return
        # value, the reducer would not see a hook for the unused parameter and
        # would throw an error. Correct capture is what is being tested here.
        for box, unbox in cases:
            boxed_output = ddp_net(batch, fn=box)
            loss_fn(unbox(boxed_output), labels).backward()
1128

1129
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_arbitrary_forward_return_value(self):
        # Helper default: gradient_as_bucket_view off.
        self._test_arbitrary_forward_return_value(gradient_as_bucket_view=False)
1133

1134
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_arbitrary_forward_return_value_grad_is_view(self):
        # Same scenario, with gradient_as_bucket_view turned on.
        use_bucket_view = True
        self._test_arbitrary_forward_return_value(
            gradient_as_bucket_view=use_bucket_view
        )
1138

1139
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_with_lazy_parameters(self):
        # Wrapping a LazyLinear module in DDP is expected to raise.
        nccl_pg = self._get_process_group()
        expected_msg = "Modules with uninitialized parameters"
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            DistributedDataParallel(torch.nn.LazyLinear(10), process_group=nccl_pg)
1149

1150
    def _test_find_unused_parameters_kwarg(self, gradient_as_bucket_view=False):
        """
        Exercise the `find_unused_parameters` kwarg of DDP: True (an error is
        expected), False, and the default.

        Note: this test can be sped up by only running it on a CPU module
        once DistributedDataParallel supports them.
        """
        torch.cuda.set_device(self.rank)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )
        process_group = c10d.distributed_c10d._get_default_group()

        class FindUnusedParametersModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                hidden = self.relu(self.fc1(x))
                hidden = self.relu(self.fc2(hidden))
                # fc3 is handed back so the caller can invoke it outside of
                # forward. This is bad practice, but it lets the test trigger
                # a reducer error.
                return (F.softmax(hidden, dim=1), self.fc3)

        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        batch_size = 4
        loss_fn = nn.CrossEntropyLoss()
        batch = torch.rand([batch_size, 2], dtype=torch.float)
        labels = torch.LongTensor(
            [random.randrange(4) for _ in range(batch_size)]
        ).to(device_id)

        # Holds the most recently built DDP wrapper so the error branch below
        # can inspect the wrapped module.
        ddp_model = None

        def run_ddp_step(
            find_unused_parameters, test_default=False, gradient_as_bucket_view=False
        ):
            nonlocal ddp_model
            ddp_kwargs = {
                "device_ids": [device_id],
                "process_group": process_group,
                "gradient_as_bucket_view": gradient_as_bucket_view,
            }
            # With test_default the kwarg is omitted so DDP's default applies.
            if not test_default:
                ddp_kwargs["find_unused_parameters"] = find_unused_parameters
            ddp_model = DistributedDataParallel(
                FindUnusedParametersModule().float().to(device_id), **ddp_kwargs
            )

            probs, fc3 = ddp_model(batch)
            logits = fc3(probs)
            loss_fn(logits, labels).backward()

        # First, finding unused params under these conditions is expected to
        # trigger an error when `backward` is called (fc3 is an unused
        # parameter and will therefore be marked ready twice).
        try:
            run_ddp_step(True, gradient_as_bucket_view=gradient_as_bucket_view)
        except Exception as ex:
            message = str(ex)
            self.assertTrue(
                message.startswith(
                    "Expected to mark a variable ready only once.",
                )
            )
            unused_index = 2
            unused_index_str = f"Parameter at index {unused_index}"
            wrapped = ddp_model.module
            for module_name, submodule in wrapped.named_modules():
                if submodule != wrapped.fc3:
                    continue
                # fc3 has exactly one parameter (bias=False), so only the
                # first entry is needed.
                first_entry = next(
                    iter(submodule.named_parameters(recurse=False)), None
                )
                if first_entry is not None:
                    unused_fqn = f"{module_name}.{first_entry[0]}"

            if dist.get_debug_level() != dist.DebugLevel.OFF:
                unused_index_str += f" with name {unused_fqn}"

            self.assertTrue(unused_index_str in message)
        else:
            self.fail("Expected exception")

        dist.barrier(process_group)

        # Next, the behavior can be overridden by setting
        # `find_unused_parameters=False`.
        try:
            run_ddp_step(False, gradient_as_bucket_view=gradient_as_bucket_view)
        except Exception as ex:
            self.fail(f"Unexpected exception: {ex}")

        # Finally, find_unused_parameters defaults to False.
        try:
            run_ddp_step(
                True, test_default=True, gradient_as_bucket_view=gradient_as_bucket_view
            )
        except Exception as ex:
            self.fail(f"Unexpected exception: {ex}")
1264

1265
    # TODO: Combine the following tests once https://github.com/pytorch/pytorch/issues/55967
1266
    # is resolved.
1267
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_find_unused_parameters_kwarg_debug_detail(self):
        # Debug level DETAIL, gradient_as_bucket_view off.
        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=False)
1272

1273
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["INFO"])
    def test_find_unused_parameters_kwarg_debug_info(self):
        # Debug level INFO, gradient_as_bucket_view off.
        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=False)
1278

1279
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["OFF"])
    def test_find_unused_parameters_kwarg_debug_off(self):
        # Debug level OFF, gradient_as_bucket_view off.
        self._test_find_unused_parameters_kwarg(gradient_as_bucket_view=False)
1284

1285
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["DETAIL"])
    def test_find_unused_parameters_kwarg_grad_is_view_debug_detail(self):
        # Debug level DETAIL, gradient_as_bucket_view on.
        use_bucket_view = True
        self._test_find_unused_parameters_kwarg(
            gradient_as_bucket_view=use_bucket_view
        )
1290

1291
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["INFO"])
    def test_find_unused_parameters_kwarg_grad_is_view_debug_info(self):
        # Debug level INFO, gradient_as_bucket_view on.
        use_bucket_view = True
        self._test_find_unused_parameters_kwarg(
            gradient_as_bucket_view=use_bucket_view
        )
1296

1297
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @with_dist_debug_levels(levels=["OFF"])
    def test_find_unused_parameters_kwarg_grad_is_view_debug_off(self):
        # Debug level OFF, gradient_as_bucket_view on.
        use_bucket_view = True
        self._test_find_unused_parameters_kwarg(
            gradient_as_bucket_view=use_bucket_view
        )
1302

1303
    def _test_multiple_outputs_multiple_backward(self, gradient_as_bucket_view=False):
        """
        Run a DDP module that returns two outputs and call backward on a loss
        computed from each of them in turn.

        Note: this test can be sped up by only running it on a CPU module
        once DistributedDataParallel supports them.
        """
        nccl_pg = self._get_process_group()

        class MultipleOutputModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()

                def build_branch():
                    layers = (
                        nn.Linear(2, 10, bias=False),
                        nn.ReLU(),
                        nn.Linear(10, 4, bias=False),
                        nn.ReLU(),
                    )
                    return nn.Sequential(*layers)

                self.module0 = build_branch()
                self.module1 = build_branch()

            def forward(self, x):
                first = F.softmax(self.module0(x), dim=1)
                second = F.softmax(self.module1(x), dim=1)
                return (first, second)

        gpu = gpus_for_rank(self.world_size)[self.rank][0]
        ddp_net = DistributedDataParallel(
            MultipleOutputModule().float().to(gpu),
            device_ids=[gpu],
            process_group=nccl_pg,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        batch_size = 4
        loss_fn = nn.CrossEntropyLoss()
        batch = torch.rand([batch_size, 2], dtype=torch.float)
        labels = torch.LongTensor(
            [random.randrange(4) for _ in range(batch_size)]
        ).to(gpu)

        # One forward pass, then a separate loss and backward per output.
        first_out, second_out = ddp_net(batch)
        loss_fn(first_out, labels).backward()
        loss_fn(second_out, labels).backward()
1352

1353
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_multiple_outputs_multiple_backward(self):
        # Helper default: gradient_as_bucket_view off.
        self._test_multiple_outputs_multiple_backward(gradient_as_bucket_view=False)
1357

1358
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_multiple_outputs_multiple_backward_grad_is_view(self):
        # Same scenario, with gradient_as_bucket_view turned on.
        use_bucket_view = True
        self._test_multiple_outputs_multiple_backward(
            gradient_as_bucket_view=use_bucket_view
        )
1362

1363
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_no_grad(self):
        """
        Note: this test can be sped up by only running it on a CPU module
        once DistributedDataParallel supports them.
        """
        nccl_pg = self._get_process_group()

        class NoGradModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                hidden = self.relu(self.fc1(x))
                hidden = self.relu(self.fc2(hidden))
                return F.softmax(hidden, dim=1)

        gpu = gpus_for_rank(self.world_size)[self.rank][0]
        ddp_net = DistributedDataParallel(
            NoGradModule().float().to(gpu),
            device_ids=[gpu],
            process_group=nccl_pg,
        )

        batch_size = 4
        batch = torch.rand([batch_size, 2], dtype=torch.float)

        def assert_grads_unset():
            # Every parameter must still require grad and hold no gradient.
            for param in ddp_net.parameters():
                self.assertTrue(param.requires_grad)
                self.assertIsNone(param.grad)

        # Right after construction no parameter has a gradient.
        assert_grads_unset()

        # Forward pass under torch.no_grad().
        with torch.no_grad():
            result = ddp_net(batch)
            self.assertTrue(isinstance(result, torch.Tensor))

        # Still no parameter should have a gradient.
        assert_grads_unset()
1409

1410
    def _test_accumulate_gradients_module(self, gradient_as_bucket_view=False):
        """Alternate between stepping ``ddp_model.module`` and ``ddp_model``.

        This is NOT the recommended way to accumulate grads; the point is to
        make sure DDP does not mess up the underlying module.
        """
        gpu_indices = gpus_for_rank(self.world_size)[self.rank][:1]
        gpu_devices = [torch.device(f"cuda:{idx}") for idx in gpu_indices]
        nccl_pg = self._get_process_group()
        global_batch_size = self.world_size

        model, ddp_model, input, target = self._prepare_single_device_module(
            nccl_pg,
            gpu_devices,
            gpu_devices,
            global_batch_size,
            gradient_as_bucket_view,
        )

        def train_step(net, x, y):
            net.train()
            prediction = net(x)
            F.mse_loss(prediction, y.to(prediction.device)).backward()

        # Accumulating grads has to work together with no_grad.
        with torch.no_grad():
            ddp_model.train()
            ddp_model.module(input)

        local = slice(self.rank, self.rank + 1)

        # Compare the two models' parameters over 4 iterations. Four are used
        # because the loop alternates between reducing and not reducing and
        # must switch in both directions.
        for iteration in range(4):
            train_step(model, input, target)

            skip_sync = iteration % 2 == 0
            # On even iterations the wrapped module is stepped directly, which
            # skips gradient sync without calling prepare_for_backward.
            stepped = ddp_model.module if skip_sync else ddp_model
            train_step(stepped, input[local], target[local])

            for ref_param, ddp_param in zip(
                model.parameters(), ddp_model.parameters()
            ):
                if skip_sync:
                    self.assertNotEqual(ref_param.grad, ddp_param.grad)
                else:
                    self.assertEqual(
                        ref_param.grad, ddp_param.grad, rtol=1.3e-06, atol=5e-5
                    )

            # Shuffle the input so that the DDP input differs next time.
            torch.manual_seed(1337 + iteration)
            input = input[torch.randperm(global_batch_size)]
1461

1462
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_accumulate_gradients_module(self):
        """Run ``_test_accumulate_gradients_module`` with its default arguments."""
        self._test_accumulate_gradients_module()
1466

1467
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_accumulate_gradients_module_with_grad_is_view(self):
        """Run ``_test_accumulate_gradients_module`` with ``gradient_as_bucket_view=True``."""
        self._test_accumulate_gradients_module(gradient_as_bucket_view=True)
1471

1472
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_failure_recovery(self):
        """Train under DDP, tear the process group down, then train again.

        A small model is wrapped in DDP and run for six forward/backward
        passes. The DDP wrapper and its process group are then destroyed, a
        new default NCCL group is initialized from a second FileStore, and
        the same model is wrapped and trained for six more passes.
        """
        process_group = self._get_process_group()

        # need to create a separate file for the recovered FileStore, because
        # the original one will be deleted when destructing the first FileStore.
        recovery_filename = self.file_name + "_recovery"

        if self.rank == 0:
            # the file will be deleted by the recovered FileStore
            with open(recovery_filename, "w"):
                pass

        # not necessary to run barrier here, as DDP will synchronize

        class TestModel(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                hidden = self.relu(self.fc1(x))
                activated = self.relu(self.fc2(hidden))
                return F.softmax(activated, dim=1)

        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        model = TestModel().float().to(device_id)
        ddp = DistributedDataParallel(
            model, device_ids=[device_id], process_group=process_group
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        inputs = torch.rand([batch_size, 2], dtype=torch.float)
        labels = [random.randrange(4) for _ in range(batch_size)]
        target = torch.LongTensor(labels).to(device_id)

        for _ in range(6):
            prediction = ddp(inputs)
            loss_value = criterion(prediction, target)
            loss_value.backward()

        # Drop the wrapper before destroying the group it was built on.
        del ddp
        c10d.destroy_process_group(process_group)

        # Bring up a fresh default group backed by the recovery file.
        store = c10d.FileStore(recovery_filename, self.world_size)
        c10d.init_process_group(
            "nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        ddp = DistributedDataParallel(
            model, device_ids=[device_id], process_group=process_group
        )

        inputs = torch.rand([batch_size, 2], dtype=torch.float)
        labels = [random.randrange(4) for _ in range(batch_size)]
        target = torch.LongTensor(labels).to(device_id)
        for _ in range(6):
            prediction = ddp(inputs)
            loss_value = criterion(prediction, target)
            loss_value.backward()
1541

1542
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_pass_default_pg(self):
        """Destroy the default group by passing it explicitly.

        After ``destroy_process_group`` is called with the default group
        object, ``dist.is_initialized()`` is asserted to be False.
        """
        init_url = f"file://{self.file_name}"
        dist.init_process_group(
            "nccl",
            init_method=init_url,
            world_size=self.world_size,
            rank=self.rank,
        )

        dist.destroy_process_group(c10d.distributed_c10d._get_default_group())
        self.assertFalse(dist.is_initialized())
1555

1556
    def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
        """Compare gradient values and layouts of a ConvNet with and without DDP.

        For every combination of per-layer memory formats, per-layer dtypes
        and bucket size, a ``ConvNet`` and a DDP-wrapped deep copy are run
        for three iterations. After each backward pass every layer's weight
        gradient is asserted to be contiguous in the requested memory format
        (for both models), and each parameter gradient of the plain model is
        asserted equal to the DDP one within a dtype-dependent tolerance.

        Args:
            replica_devices: passed to DDP as ``device_ids``.
            layer_devs: a device, or a list of devices, passed to ``ConvNet``;
                for a list, the first entry hosts the input and the last
                entry hosts the target.
            local_batch_size: rows of the global batch fed to DDP per rank.
        """
        process_group = self._get_process_group()

        global_batch_size = local_batch_size * self.world_size

        # Carry out some trials with small buckets and some with big buckets.
        bucketsizes = (0.000001, 25)
        # Tuples of lists.  Each list describes per-layer characteristics for one trial.
        layer_formats = (
            [torch.contiguous_format] * 4,
            [torch.channels_last] * 2 + [torch.contiguous_format] * 2,
            [torch.channels_last] * 4,
        )
        layer_dtypes = (
            [torch.float] * 4,
            [torch.float] * 2 + [torch.half] * 2,
            [torch.half] * 4,
        )

        if isinstance(layer_devs, list):
            input_dev, target_dev = layer_devs[0], layer_devs[-1]
        else:
            input_dev = target_dev = layer_devs
        full_input = torch.randn(
            (global_batch_size, 8, 8, 8), device=input_dev, dtype=torch.float
        )
        full_target = torch.randn(
            (global_batch_size, 8, 4, 4), device=target_dev, dtype=torch.float
        )
        # Rows of the global batch that belong to this rank.
        local_rows = slice(
            self.rank * local_batch_size, (self.rank + 1) * local_batch_size
        )

        # Reducer.cpp sneakily creates one "initial bucket" that ignores the "bucket_cap_mb"
        # argument.  The following makes sure the initial bucket also complies.
        @contextmanager
        def first_bucket_size(ddp_bucket_mb):
            saved_first_bucket_bytes = dist._DEFAULT_FIRST_BUCKET_BYTES
            dist._DEFAULT_FIRST_BUCKET_BYTES = int(ddp_bucket_mb * 1.0e6)
            try:
                yield
            finally:
                dist._DEFAULT_FIRST_BUCKET_BYTES = saved_first_bucket_bytes

        trials = product(layer_formats, layer_dtypes, bucketsizes)
        with torch.backends.cudnn.flags(
            enabled=True, deterministic=True, benchmark=False
        ):
            for formats, dtypes, bucketsize in trials:
                with first_bucket_size(bucketsize):
                    model_msg = f"rank = {self.rank} formats = {formats} dtypes = {dtypes} bucketsize = {bucketsize} "
                    try:
                        m = ConvNet(layer_devs, formats, dtypes)
                        m_ddp = DistributedDataParallel(
                            copy.deepcopy(m),
                            device_ids=replica_devices,
                            process_group=process_group,
                            bucket_cap_mb=bucketsize,
                        )
                        opt = torch.optim.SGD(m.parameters(), lr=0.1)
                        opt_ddp = torch.optim.SGD(m_ddp.parameters(), lr=0.1)
                        # Looser tolerance as soon as any parameter is half precision.
                        if any(p.dtype is torch.half for p in m.parameters()):
                            tol = 1.0e-3
                        else:
                            tol = 1.0e-5
                    except BaseException:
                        # Prints case-specific debugging info to narrow down failing case.
                        print(
                            "Caught exception during model creation for " + model_msg,
                            flush=True,
                        )
                        raise
                    # 3 iters:  First iter creates grads, second iter retests after rebucketing,
                    # third iter tries zeroed grads.
                    for it in range(3):
                        iter_msg = f"iter = {it} " + model_msg
                        # Tracks the most specific location reached, for the
                        # diagnostic print in the except clause below.
                        named_msg = iter_msg
                        try:
                            F.mse_loss(m(full_input).float(), full_target).backward()
                            F.mse_loss(
                                m_ddp(full_input[local_rows]).float(),
                                full_target[local_rows],
                            ).backward()
                            child_pairs = zip(
                                m.named_children(), m_ddp.module.children()
                            )
                            for i, ((layer_name, m_child), m_ddp_child) in enumerate(
                                child_pairs
                            ):
                                named_msg = f"{layer_name}.weight {iter_msg}"
                                for child in (m_child, m_ddp_child):
                                    self.assertTrue(
                                        child.weight.grad.is_contiguous(
                                            memory_format=formats[i]
                                        ),
                                        named_msg,
                                    )
                                param_pairs = zip(
                                    m_child.named_parameters(),
                                    m_ddp_child.parameters(),
                                )
                                for (param_name, p), p_ddp in param_pairs:
                                    named_msg = f"{layer_name}.{param_name} {iter_msg}"
                                    self.assertEqual(
                                        p.grad, p_ddp.grad, rtol=tol, atol=tol
                                    )
                            opt.step()
                            opt_ddp.step()
                            if it == 0:
                                # First iteration: drop the grads entirely.
                                for p, p_ddp in zip(m.parameters(), m_ddp.parameters()):
                                    p.grad = None
                                    p_ddp.grad = None
                            else:
                                m.zero_grad()
                                m_ddp.zero_grad()
                        except BaseException:
                            # Makes sure we still get info if an error occurred somewhere other than the asserts.
                            print(
                                "Caught exception during iterations at " + named_msg,
                                flush=True,
                            )
                            raise
1679

1680
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_grad_layout_1devicemodule_1replicaperprocess(self):
        """Run ``_test_grad_layout`` with every layer on this rank's first GPU."""
        first_gpu = gpus_for_rank(self.world_size)[self.rank][0]
        dev0 = torch.device(f"cuda:{first_gpu}")
        self._test_grad_layout(
            # Tells DDP to use just one device.
            [dev0],
            # Tells _test_grad_layout to construct ConvNet with all layers on this process's first assigned device.
            dev0,
            8,
        )
1690

1691
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @skip_if_rocm
    def test_grad_layout_2devicemodule(self):
        """Run ``_test_grad_layout`` with the ConvNet split across two GPUs."""
        first_gpu, second_gpu = gpus_for_rank(self.world_size)[self.rank][:2]
        dev0 = torch.device(f"cuda:{first_gpu}")
        dev1 = torch.device(f"cuda:{second_gpu}")
        self._test_grad_layout(
            # DDP's default behavior for a multi-device module is "don't replicate."
            None,
            # Tells _test_grad_layout to constructs this process's ConvNet on 2 devices, with 2 layers on each device.
            [dev0, dev0, dev1, dev1],
            8,
        )
1704

1705
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_param_layout_mismatch_error(self):
        """DDP construction must fail on ranks whose param strides differ from rank 0.

        Rank 0 builds its ConvNet with contiguous layers and every other rank
        with channels_last layers; wrapping in DDP is asserted to raise a
        stride-mismatch ``RuntimeError`` on the non-zero ranks.
        """
        process_group = self._get_process_group()

        first_gpu = gpus_for_rank(self.world_size)[self.rank][0]
        dev0 = torch.device("cuda:" + str(first_gpu))
        on_rank0 = self.rank == 0
        layout = torch.contiguous_format if on_rank0 else torch.channels_last

        m = ConvNet(dev0, [layout] * 4, [torch.float] * 4)

        def wrap_in_ddp():
            return DistributedDataParallel(
                m, device_ids=[dev0], process_group=process_group
            )

        if on_rank0:
            m_ddp = wrap_in_ddp()
        else:
            with self.assertRaisesRegex(
                RuntimeError,
                ".* appears not to match strides of the same param in process 0",
            ):
                m_ddp = wrap_in_ddp()
1732

1733
    def _gpu_model_with_ddp_comm_hook(
        self,
        process_group,
        hook=None,
        gradient_as_bucket_view=False,
        state=None,
        static_graph=False,
    ):
        """Wrap a ``ModuleForDdpCommHook`` in DDP on this rank's first GPU.

        Args:
            process_group: process group handed to DDP.
            hook: optional communication hook; registered only when not None.
            gradient_as_bucket_view: forwarded to DDP.
            state: state object passed to ``register_comm_hook`` with ``hook``.
            static_graph: forwarded to DDP.

        Returns:
            The DDP-wrapped model.
        """
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        module = ModuleForDdpCommHook().to(device_id)
        gpu_model = DistributedDataParallel(
            module,
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )

        if hook is None:
            return gpu_model

        # Register a DDP communication hook if any.
        gpu_model.register_comm_hook(state, hook)
        return gpu_model
1755

1756
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_nccl(self):
        """
        Verify that a Future object is passed properly with the nccl backend:
        the hook callback function creates a Future object and sets a value to it.
        """
        # Get GPU model with simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(
            self._get_process_group(), self._simple_hook
        )

        # check whether the grads are equal to what simple_hook's then callback returns.
        # without the comm_hook, result would be 0.25 * torch.ones(2, 2).
        expected_grad = 2 * torch.ones(2, 2)
        self._run_and_verify_hook(gpu_model, 8, expected_grad)
1771

1772
    def _test_ddp_comm_hook_allreduce_hook_nccl(
1773
        self, gradient_as_bucket_view=False, static_graph=False
1774
    ):
1775
        """
1776
        This unit test verifies whether a DDP communication hook that just calls
1777
        allreduce gives the same result with the case of no hook registered.
1778
        Without the then callback, the future_value in reducer is no longer
1779
        a PyObject, and this unit test verifies future_value is properly checked.
1780
        """
1781
        process_group = self._get_process_group()
1782

1783
        def allreduce_hook(
1784
            state: object, bucket: dist.GradBucket
1785
        ) -> torch.futures.Future[torch.Tensor]:
1786
            tensors = [bucket.buffer() / self.world_size]
1787
            return (
1788
                process_group.allreduce(tensors)
1789
                .get_future()
1790
                .then(lambda fut: fut.value()[0])
1791
            )
1792

1793
        # Get GPU model with allreduce_hook registered.
1794
        gpu_model = self._gpu_model_with_ddp_comm_hook(
1795
            process_group, allreduce_hook, gradient_as_bucket_view, static_graph
1796
        )
1797

1798
        # check whether the grads are equal to what DDP without hook would return.
1799
        self._run_and_verify_hook(gpu_model, 8, 0.25 * torch.ones(2, 2))
1800

1801
    def _test_default_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
        """
        Verify that the default Python DDP communication hooks ALLREDUCE and
        FP16_COMPRESS (plus BF16_COMPRESS when the environment checks below
        pass) give the same result as when no hook is registered.

        Args:
            gradient_as_bucket_view: forwarded to the DDP model under test.
        """
        process_group = self._get_process_group()

        # For these default DDP comm hooks, the only state is process group.
        state = process_group

        bf16_hook_usable = (
            not TEST_WITH_ROCM
            and BFLOAT16_AVAILABLE
            and c10d.is_nccl_available()
            and torch.cuda.nccl.version() >= (2, 10)
        )
        hook_options = [default.allreduce_hook, default.fp16_compress_hook]
        if bf16_hook_usable:
            hook_options.append(default.bf16_compress_hook)

        for comm_hook in hook_options:
            # Get GPU model with the hook registered.
            # The first arg 'process_group' is used for initializing the test environment,
            # so it cannot be replaced by 'state', although they have the same value.
            gpu_model = self._gpu_model_with_ddp_comm_hook(
                process_group, comm_hook, gradient_as_bucket_view, state
            )

            # check whether the grads are equal to what DDP without hook would return.
            expected_grad = 0.25 * torch.ones(2, 2)
            self._run_and_verify_hook(gpu_model, 8, expected_grad)
1828

1829
    def _test_fp16_compress_wrapper(self, gradient_as_bucket_view=False):
        """
        Verify that wrapping the ALLREDUCE and POWER_SGD hooks with the
        FP16_WRAPPER gives the same result as when there is no hook registered.

        Args:
            gradient_as_bucket_view: forwarded to the DDP model under test.
        """
        process_group = self._get_process_group()
        powerSGD_state = powerSGD.PowerSGDState(process_group=process_group)

        # (hook to wrap, state handed to register_comm_hook)
        hooks_with_state = (
            (powerSGD.powerSGD_hook, powerSGD_state),
            (default.allreduce_hook, process_group),
        )

        for inner_hook, hook_state in hooks_with_state:
            wrapped_hook = default.fp16_compress_wrapper(inner_hook)
            gpu_model = self._gpu_model_with_ddp_comm_hook(
                process_group, wrapped_hook, gradient_as_bucket_view, hook_state
            )

            # check whether the grads are equal to what DDP without hook would return.
            expected_grad = 0.25 * torch.ones(2, 2)
            self._run_and_verify_hook(gpu_model, 8, expected_grad)
1852

1853
    def _test_bf16_compress_wrapper(self, gradient_as_bucket_view=False):
        """
        Verify that wrapping the ALLREDUCE and POWER_SGD hooks with the
        BF16_WRAPPER gives the same result as when there is no hook registered.

        Args:
            gradient_as_bucket_view: forwarded to the DDP model under test.
        """
        process_group = self._get_process_group()
        powerSGD_state = powerSGD.PowerSGDState(process_group=process_group)

        # (hook to wrap, state handed to register_comm_hook)
        hooks_with_state = (
            (powerSGD.powerSGD_hook, powerSGD_state),
            (default.allreduce_hook, process_group),
        )

        for inner_hook, hook_state in hooks_with_state:
            wrapped_hook = default.bf16_compress_wrapper(inner_hook)
            gpu_model = self._gpu_model_with_ddp_comm_hook(
                process_group, wrapped_hook, gradient_as_bucket_view, hook_state
            )

            # check whether the grads are equal to what DDP without hook would return.
            expected_grad = 0.25 * torch.ones(2, 2)
            self._run_and_verify_hook(gpu_model, 8, expected_grad)
1876

1877
    def _test_powerSGD_ddp_comm_hook_nccl(self, gradient_as_bucket_view=False):
        """
        Verify that the Python DDP communication hook POWER_SGD gives the
        same result as when no hook is registered, for every combination of
        ``use_error_feedback``, ``warm_start`` and
        ``batch_tensors_with_same_shape`` and for both the plain and the
        batched hook.

        Args:
            gradient_as_bucket_view: forwarded to the DDP model under test.
        """
        process_group = self._get_process_group()
        powerSGD_hooks = [powerSGD.powerSGD_hook, powerSGD.batched_powerSGD_hook]

        # Get GPU model with the hook registered.
        # Test the hook with different algorithmic configs.
        for use_error_feedback, warm_start, batch_tensors_with_same_shape in product(
            [True, False], repeat=3
        ):
            state = powerSGD.PowerSGDState(
                process_group=process_group,
                matrix_approximation_rank=1,
                use_error_feedback=use_error_feedback,
                warm_start=warm_start,
                batch_tensors_with_same_shape=batch_tensors_with_same_shape,
            )
            # The same state object is shared by both hook variants.
            for comm_hook in powerSGD_hooks:
                gpu_model = self._gpu_model_with_ddp_comm_hook(
                    process_group, comm_hook, gradient_as_bucket_view, state
                )

                # check whether the grads are equal to what DDP without hook would return.
                expected_grad = 0.25 * torch.ones(2, 2)
                self._run_and_verify_hook(gpu_model, 8, expected_grad)
1905

1906
    def _test_builtin_ddp_comm_hooks_nccl(self, gradient_as_bucket_view=False):
        """
        Verify that the built-in C++ DDP communication hooks ALLREDUCE and
        FP16_COMPRESS give the same result as when no hook is registered.

        Args:
            gradient_as_bucket_view: forwarded to the DDP model under test.
        """
        process_group = self._get_process_group()

        builtin_hook_types = (
            dist.BuiltinCommHookType.ALLREDUCE,
            dist.BuiltinCommHookType.FP16_COMPRESS,
        )
        for comm_hook_type in builtin_hook_types:
            # Get GPU model with the built-in communication hook.
            gpu_model = self._gpu_model_with_builtin_ddp_comm_hook(
                process_group, comm_hook_type, gradient_as_bucket_view
            )

            # check whether the grads are equal to what DDP without hook would return.
            expected_grad = 0.25 * torch.ones(2, 2)
            self._run_and_verify_hook(gpu_model, 8, expected_grad)
1924

1925
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_hook_nccl(self):
        """Run ``_test_ddp_comm_hook_allreduce_hook_nccl`` with its default arguments."""
        self._test_ddp_comm_hook_allreduce_hook_nccl()
1929

1930
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_default_ddp_comm_hooks_nccl(self):
        """Run ``_test_default_ddp_comm_hooks_nccl`` with its default arguments."""
        self._test_default_ddp_comm_hooks_nccl()
1934

1935
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_fp16_compress_wrapper_nccl(self):
        """Run ``_test_fp16_compress_wrapper`` with its default arguments."""
        self._test_fp16_compress_wrapper()
1939

1940
    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    def test_bf16_compress_wrapper_nccl(self):
        """Run ``_test_bf16_compress_wrapper`` with its default arguments."""
        self._test_bf16_compress_wrapper()
1949

1950
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_builtin_ddp_comm_hooks_nccl(self):
        """Run ``_test_builtin_ddp_comm_hooks_nccl`` with its default arguments."""
        self._test_builtin_ddp_comm_hooks_nccl()
1954

1955
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_powerSGD_ddp_comm_hook_nccl(self):
        """Run ``_test_powerSGD_ddp_comm_hook_nccl`` with its default arguments."""
        self._test_powerSGD_ddp_comm_hook_nccl()
1959

1960
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_hook_nccl_grad_is_view(self):
        """Run ``_test_ddp_comm_hook_allreduce_hook_nccl`` with ``gradient_as_bucket_view=True``."""
        self._test_ddp_comm_hook_allreduce_hook_nccl(gradient_as_bucket_view=True)
1964

1965
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_hook_nccl_static_graph(self):
        """Run ``_test_ddp_comm_hook_allreduce_hook_nccl`` with ``static_graph=True``."""
        self._test_ddp_comm_hook_allreduce_hook_nccl(static_graph=True)
1969

1970
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_default_ddp_comm_hooks_nccl_is_view(self):
        """Run ``_test_default_ddp_comm_hooks_nccl`` with ``gradient_as_bucket_view=True``."""
        self._test_default_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)
1974

1975
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_fp16_compress_wrapper_is_view(self):
        """Run ``_test_fp16_compress_wrapper`` with ``gradient_as_bucket_view=True``."""
        self._test_fp16_compress_wrapper(gradient_as_bucket_view=True)
1979

1980
    @requires_nccl()
    @requires_nccl_version((2, 10), "Need NCCL 2.10+ for BF16_COMPRESS")
    @skip_but_pass_in_sandcastle_if(
        not BFLOAT16_AVAILABLE,
        "BFloat16 is only supported by CUDA 11+",
    )
    @skip_if_lt_x_gpu(2)
    def test_bf16_compress_wrapper_is_view(self):
        """Run ``_test_bf16_compress_wrapper`` with ``gradient_as_bucket_view=True``."""
        self._test_bf16_compress_wrapper(gradient_as_bucket_view=True)
1989

1990
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_builtin_ddp_comm_hooks_nccl_grad_is_view(self):
        """Run ``_test_builtin_ddp_comm_hooks_nccl`` with ``gradient_as_bucket_view=True``."""
        self._test_builtin_ddp_comm_hooks_nccl(gradient_as_bucket_view=True)
1994

1995
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_powerSGD_ddp_comm_hook_nccl_grad_is_view(self):
        """Run ``_test_powerSGD_ddp_comm_hook_nccl`` with ``gradient_as_bucket_view=True``."""
        self._test_powerSGD_ddp_comm_hook_nccl(gradient_as_bucket_view=True)
1999

2000
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_allreduce_with_then_hook_nccl(self):
        """
        Verify that a DDP communication hook that calls allreduce and then
        chains two ``then`` callbacks (multiply by ten, then halve) gives the
        expected result.
        """
        process_group = self._get_process_group()

        def times_ten(fut):
            # Multiply the result by 10.
            return 10 * fut.value()[0]

        def halve(fut):
            # Divide the result by 2.
            return 0.5 * fut.value()

        def allreduce_with_then_hook(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            averaged = bucket.buffer() / self.world_size
            reduced = process_group.allreduce([averaged]).get_future()
            return reduced.then(times_ten).then(halve)

        # Get GPU model with allreduce_with_then_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(
            process_group, allreduce_with_then_hook
        )

        # check whether the grads are equal to what allreduce returns multiplied by 5.
        # without the comm_hook, result would be still 0.25 * torch.ones(2, 2).
        expected_grad = 1.25 * torch.ones(2, 2)
        self._run_and_verify_hook(gpu_model, 8, expected_grad)
2033

2034
    class AcceptsParam(torch.nn.Module):
        """Module that adds ``p * factor`` to its input.

        ``p`` is stored as attribute ``a`` and ``factor`` as attribute ``f``.
        """

        def __init__(self, p, factor):
            super().__init__()
            self.a, self.f = p, factor

        def forward(self, input):
            scaled = self.a * self.f
            return input + scaled
2042

2043
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_weight_sharing(self):
        """Check DDP gradients when one parameter is shared by two submodules.

        A single parameter is used by both ``AcceptsParam`` layers of a
        ``Sequential``. For every combination of ``set_to_none`` and
        ``gradient_as_bucket_view`` the model is run for three
        forward/backward passes, and each parameter gradient is compared
        against the analytically expected value.
        """
        process_group = self._get_process_group()

        size = 2048 * 2048
        dev = self.rank
        world = self.world_size

        p = torch.nn.Parameter(torch.randn(size, requires_grad=True))

        # Each param value is multiplied by "rank + 1" twice in forward, so the grad
        # values produced by a particular rank should be 2. * (rank + 1).
        # Summing these over ranks and dividing by world size gives the expected result:
        expected_grad_value = 2.0 * (world * (world + 1.0) / 2.0) / world

        for try_set_to_none, use_bucket_view in product((False, True), (False, True)):
            m = torch.nn.Sequential(
                self.AcceptsParam(p, dev + 1), self.AcceptsParam(p, dev + 1)
            ).cuda(dev)

            m = torch.nn.parallel.DistributedDataParallel(
                m,
                bucket_cap_mb=1,
                gradient_as_bucket_view=use_bucket_view,
                device_ids=[dev],
                process_group=process_group,
            )

            for _ in range(3):
                m.zero_grad(set_to_none=try_set_to_none)
                m(1).sum().backward()

                analytic = torch.full_like(p, expected_grad_value, device=dev)
                # Use a distinct loop name so the shared parameter `p`, which
                # later iterations still need, is never rebound.
                for name, param in m.named_parameters():
                    self.assertEqual(
                        param.grad,
                        analytic,
                        "mismatch at "
                        + name
                        + ".grad for "
                        + f"set_to_none = {try_set_to_none}, use_bucket_view = {use_bucket_view}",
                    )
2086

2087
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_packed_sequence(self):
        """
        Tests that DDP with ``device_ids`` specified can run a forward and
        backward pass with ``PackedSequence`` s with parity compared to a local
        version of the model.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        # init_process_group() returns None, so look the default group up
        # explicitly rather than binding its return value.
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        process_group = c10d.distributed_c10d._get_default_group()
        seqs = ["sequence_sequence", "seq", "sequence"]
        vocab = ["<pad>"] + sorted({ch for seq in seqs for ch in seq})
        vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
        # Set the seed to make the embedding and LSTM deterministic (even
        # across ranks since DDP broadcasts parameters from rank 0)
        torch.manual_seed(0)
        embed = nn.Embedding(len(vocab), 4)  # keep on CPU
        lstm = nn.LSTM(input_size=4, hidden_size=2, batch_first=True).to(self.rank)
        lstm_ddp = DistributedDataParallel(
            copy.deepcopy(lstm),
            device_ids=[self.rank],
            process_group=process_group,
        )
        for p1, p2 in zip(lstm.parameters(), lstm_ddp.module.parameters()):
            self.assertEqual(p1, p2)
        seq_lengths = torch.LongTensor([len(seq) for seq in vectorized_seqs])
        # Batch of token ids, one row per sequence, padded with zeros
        # (index 0 in ``vocab`` is "<pad>").
        seq_tensor = torch.zeros(
            (len(vectorized_seqs), int(seq_lengths.max())), dtype=torch.long
        )
        for i, seq in enumerate(vectorized_seqs):
            seq_tensor[i, : len(seq)] = torch.LongTensor(seq)
        # Reorder the batch by descending sequence length before packing.
        seq_lengths, permutation_idx = seq_lengths.sort(0, descending=True)
        seq_tensor = seq_tensor[permutation_idx]
        embedded_seq_tensor = embed(seq_tensor)
        packed_input = torch.nn.utils.rnn.pack_padded_sequence(
            embedded_seq_tensor,
            seq_lengths,
            batch_first=True,
        )
        packed_input_ddp = torch.nn.utils.rnn.pack_padded_sequence(
            embedded_seq_tensor.detach().clone(),
            seq_lengths,
            batch_first=True,
        )
        # Move the input to GPU explicitly for the local model
        packed_output, (ht, ct) = lstm(packed_input.to(self.rank))
        # Let DDP move the input to GPU internally
        packed_output_ddp, (ht_ddp, ct_ddp) = lstm_ddp(packed_input_ddp)
        self.assertEqual(packed_output.data, packed_output_ddp.data)
        self.assertEqual(ht, ht_ddp)
        self.assertEqual(ct, ct_ddp)
        packed_output.data.sum().backward()
        packed_output_ddp.data.sum().backward()
        for p1, p2 in zip(lstm.parameters(), lstm_ddp.parameters()):
            self.assertEqual(p1.grad, p2.grad)
2147

2148
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_channels_last_contig(self):
        """Broadcast a channels_last float32 tensor and wait for completion."""
        process_group = self._get_process_group()
        device = torch.device(f"cuda:{self.rank}")
        shape = (2, 16, 768, 1152)
        tensor = torch.ones(shape, dtype=torch.float32, device=device).to(
            memory_format=torch.channels_last
        )
        work = process_group.broadcast([tensor])
        work.wait()
2157

2158
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_ddp_complex_params(self):
        """Run one DDP training step on a model with a complex-valued parameter."""

        class FFTModel(nn.Module):
            def __init__(self, hin, win, n_features):
                super().__init__()
                self.hin = hin
                self.win = win
                weight_shape = (n_features, n_features, hin, win // 2 + 1)
                self.weight = nn.Parameter(torch.ones(weight_shape, dtype=torch.cfloat))

            def forward(self, x):
                spectrum = torch.fft.rfft2(
                    x, s=(self.hin, self.win), dim=(-2, -1), norm="ortho"
                )
                weighted = torch.einsum("nchw,cohw->nohw", spectrum, self.weight)
                return torch.fft.irfft2(weighted, dim=(-2, -1), norm="ortho")

        process_group = self._get_process_group()
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        N, C, H, W = 1, 16, 64, 64
        fft_model = FFTModel(hin=H, win=W, n_features=C).to(device_id)
        ddp_model = DistributedDataParallel(
            fft_model,
            device_ids=[device_id],
            process_group=process_group,
        )
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)

        inp = torch.ones((N, C, H, W), dtype=torch.float32)

        # train step
        out = ddp_model(inp)
        loss = torch.sum(out)
        loss.backward()
        optimizer.step()

        torch.cuda.synchronize(device=device_id)
2199

2200

2201
class WorkHookTest(MultiProcessTestCase):
    """Exercises the on-completion hook of the NCCL process group.

    Every test registers a hook with ``_register_on_completion_hook`` and then
    checks how often it fired and what the reported ``WorkInfo`` contained.
    """

    @property
    def world_size(self):
        """Number of ranks used by these tests."""
        return 2

    def setUp(self):
        super().setUp()
        # set TORCH_NCCL_ENABLE_TIMING to enable timing for CUDAEvents
        # in ProcessGroup Work
        os.environ["TORCH_NCCL_ENABLE_TIMING"] = "1"
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        # pop() rather than del: cleanup must not raise if the variable is
        # already gone.
        os.environ.pop("TORCH_NCCL_ENABLE_TIMING", None)
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: the store file may already have been removed.
            pass

    def _get_store(self):
        """Return a FileStore backed by this test's ``file_name``."""
        return dist.FileStore(self.file_name, self.world_size)

    def _get_process_group(self):
        """Initialize the default NCCL process group and return it."""
        store = self._get_store()
        c10d.init_process_group(
            "nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_on_completion_hook_broadcast(self):
        """The hook fires once per broadcast and reports a positive duration."""
        pg = self._get_process_group()
        num_hook_fired = 0
        durations: List[float] = []

        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
            nonlocal num_hook_fired
            num_hook_fired += 1
            durations.append(work_info.active_duration.total_seconds())

        pg._register_on_completion_hook(hook)
        tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
        pg.broadcast([tensor]).wait()
        pg.broadcast([tensor]).wait()

        # N.B.: destroy_process_group is necessary to wait for
        # all pending works to finish.
        c10d.destroy_process_group(pg)

        self.assertEqual(num_hook_fired, 2)
        self.assertEqual(len(durations), 2)
        for duration in durations:
            self.assertGreater(duration, 0)

        # Every rank is expected to end up with all zeros after the broadcast.
        self.assertEqual(tensor, torch.zeros([2, 3]).cuda(self.rank))

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_on_completion_hook_mixed_ops(self):
        """The hook fires for each of three async ops of different kinds."""
        pg = self._get_process_group()
        num_hook_fired = 0
        durations: List[float] = []

        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
            nonlocal num_hook_fired
            num_hook_fired += 1
            durations.append(work_info.active_duration.total_seconds())

        pg._register_on_completion_hook(hook)
        tensor = torch.ones([2, 3]).cuda(self.rank)
        tensor_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
        # intentionally using async ops.
        pg.allreduce(tensor)
        pg.allgather(tensor_list, tensor)
        pg.allreduce(tensor)

        # N.B.: destroy_process_group is necessary to wait for
        # all pending works to finish.
        c10d.destroy_process_group(pg)

        self.assertEqual(num_hook_fired, 3)
        self.assertEqual(len(durations), 3)
        for duration in durations:
            self.assertGreater(duration, 0)

        # Two allreduces of ones: world_size after the first, squared after
        # the second.
        self.assertEqual(
            tensor,
            torch.ones([2, 3]).cuda(self.rank) * self.world_size * self.world_size,
        )

        # The allgather ran between the allreduces, so it gathered the
        # once-reduced value from every rank.
        self.assertEqual(
            tensor_list,
            [
                torch.ones([2, 3]).cuda(self.rank) * self.world_size
                for _ in range(self.world_size)
            ],
        )

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_on_completion_hook_with_ddp(self):
        """The hook sees DDP's broadcast at construction and allreduce in backward."""
        pg = self._get_process_group()
        num_hook_fired: Dict[OpType, int] = {}
        durations: Dict[OpType, List[float]] = {}

        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
            op_type = work_info.op_type
            if op_type not in num_hook_fired:
                num_hook_fired[op_type] = 0
                durations[op_type] = []
            num_hook_fired[op_type] += 1
            durations[op_type].append(work_info.active_duration.total_seconds())

        pg._register_on_completion_hook(hook)

        nlayers = 10
        net = nn.Sequential(
            *[nn.Linear(1000, 1000, bias=False) for _ in range(nlayers)]
        ).to(self.rank)

        ddp = DistributedDataParallel(
            net,
            device_ids=[self.rank],
            process_group=pg,
            bucket_cap_mb=1,
        )

        pg._wait_for_pending_works()

        # DDP is expected to synchronize model parameter by broadcasting
        # from rank0 to other ranks. However, this is DDP's internal implementation,
        # which is subject to change in future versions.
        # .get() so a missing broadcast is reported as an assertion failure
        # rather than a KeyError.
        self.assertGreater(num_hook_fired.get(OpType.BROADCAST, 0), 0)
        # Allreduces issued while constructing DDP are not counted below.
        ctor_allreduce = num_hook_fired.get(OpType.ALLREDUCE, 0)

        x = torch.zeros(2, 1000).cuda(self.rank)
        ddp(x).sum().backward()

        c10d.destroy_process_group(pg)

        self.assertIn(OpType.ALLREDUCE, num_hook_fired)
        # The number of allreduce ops depend on DDP internal implementation, but
        # there should be at least one allreduce.
        self.assertGreater(num_hook_fired[OpType.ALLREDUCE] - ctor_allreduce, 0)
        self.assertTrue(all(duration > 0 for duration in chain(*(durations.values()))))

    # Not testing FSDP due to https://github.com/pytorch/pytorch/issues/90848.
    # We cannot disable workCleanupLoop() as hooks are fired in that thread.

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_on_completion_hook_all_gather_object(self):
        """all_gather_object triggers exactly two ALLGATHER hook calls."""
        torch.cuda.set_device(self.rank)

        pg = self._get_process_group()
        num_hook_fired: Dict[OpType, int] = {}
        durations: Dict[OpType, List[float]] = {}

        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
            op_type = work_info.op_type
            if op_type not in num_hook_fired:
                num_hook_fired[op_type] = 0
                durations[op_type] = []
            num_hook_fired[op_type] += 1
            durations[op_type].append(work_info.active_duration.total_seconds())

        pg._register_on_completion_hook(hook)

        obj = {"rank": self.rank, "world_size": self.world_size}
        obj_list = [None for _ in range(self.world_size)]

        c10d.all_gather_object(obj_list, obj, group=pg)

        for r, o in enumerate(obj_list):
            self.assertIsInstance(o, dict)
            # assertTrue(a, b) would treat b as the failure message and pass
            # for any non-empty set; compare the key sets explicitly.
            self.assertSetEqual(set(o), {"rank", "world_size"})
            self.assertEqual(o["rank"], r)
            self.assertEqual(o["world_size"], self.world_size)

        c10d.destroy_process_group(pg)

        self.assertIn(OpType.ALLGATHER, num_hook_fired)
        self.assertEqual(len(num_hook_fired), 1)
        # two allgathers, one for size and another for values
        self.assertEqual(num_hook_fired[OpType.ALLGATHER], 2)
        self.assertTrue(all(duration > 0 for duration in durations[OpType.ALLGATHER]))

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_on_completion_hook_seq(self):
        """``WorkInfo.seq`` of the last completed work equals the work count."""
        pg = self._get_process_group()
        num_hook_fired = 0
        seq: int = -1

        def hook(work_info: torch._C._distributed_c10d.WorkInfo):
            nonlocal num_hook_fired, seq
            num_hook_fired += 1
            seq = work_info.seq

        pg._register_on_completion_hook(hook)
        tensor = torch.ones([2, 3]).cuda(self.rank) * self.rank
        work_count = 3
        for _ in range(work_count):
            pg.broadcast([tensor]).wait()

        # N.B.: destroy_process_group is necessary to wait for
        # all pending works to finish.
        c10d.destroy_process_group(pg)

        self.assertEqual(num_hook_fired, work_count)
        self.assertEqual(work_count, seq)
2422

2423

2424
class NcclErrorHandlingTest(MultiProcessTestCase):
    """Tests ProcessGroupNCCL behavior when a collective cannot complete.

    Most tests have rank 0 start a collective that the other ranks never join
    and then check how the failure surfaces on rank 0.
    """

    def setUp(self):
        super().setUp()
        # Need to skip return code checking for these tests since the child
        # processes don't exit cleanly.
        self.skip_return_code_checks = [
            self.test_nccl_errors_blocking_abort.__wrapped__,
            self.test_nccl_errors_blocking_sigkill.__wrapped__,
            self.test_nccl_errors_blocking_sigterm.__wrapped__,
            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
        ]
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: the store file may already have been removed.
            pass

    @property
    def op_timeout_sec(self):
        """Timeout, in seconds, passed to ``wait()`` for ops expected to fail."""
        return 3

    @property
    def world_size(self):
        """Number of ranks used by these tests."""
        return 3

    @property
    def blocking_wait_error_msg(self):
        """Regex expected in the error raised when a blocking wait times out."""
        return "timeout"

    def _run_all_reduce(self, pg):
        """Issue one allreduce on ``pg``; used as a thread target."""
        pg.allreduce(torch.rand(10).cuda(self.rank))

    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    @skip_but_pass_in_sandcastle("Test does not pass when run locally")
    def test_nccl_errors_nonblocking(self):
        """Without async error handling, an unmatched allreduce hangs later work."""
        # Note: we disable and then restore TORCH_NCCL_ASYNC_ERROR_HANDLING for
        # this test since test_c10d_common runs with async error handling by
        # default, but this tests behavior when it is not enabled.
        env_name = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
        prev_nccl_async_error_handling = os.environ.get(env_name)
        os.environ[env_name] = "0"
        try:
            store = c10d.FileStore(self.file_name, self.world_size)
            process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
            process_group.allreduce(torch.rand(10).cuda(self.rank))
            if self.rank == 0:
                # This allreduce does not block Python thread as allreduce enqueues
                # the cuda operation, and then wait only blocks the current cuda
                # stream.
                work = process_group.allreduce(torch.rand(10).cuda(self.rank))
                work.wait()

                # Now the work scheduled next should hang forever since the previous
                # allreduce will never complete.
                t = threading.Thread(
                    target=self._run_all_reduce, args=(process_group,)
                )
                t.daemon = True
                t.start()
                t.join(int(get_timeout(self.id()) / 5))
                self.assertTrue(t.is_alive())
        finally:
            # Restore the previous state even if an assertion above failed,
            # including the case where the variable was not set at all.
            if prev_nccl_async_error_handling is None:
                os.environ.pop(env_name, None)
            else:
                os.environ[env_name] = prev_nccl_async_error_handling

    def _test_nccl_errors_blocking(self, func):
        """Check that rank 0 gets a DistBackendError when rank 1 goes down.

        Args:
            func: zero-argument callable invoked on rank 1 after it has
                dropped its process group (e.g. exits or kills the process).
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(
            store,
            self.rank,
            self.world_size,
            timeout=timedelta(seconds=10),
        )
        process_group.allreduce(torch.rand(10).cuda(self.rank))
        if self.rank == 0:
            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
            with self.assertRaises(dist.DistBackendError):
                # It seems the error message would be different depending on
                # whether the test is run on CI machine and devGPU.  Skipping
                # the error message check to make both sides happy.
                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
            # Run some GPU operations to make sure cuda has not gotten stuck.
            # It was observed cuda could get stuck if NCCL communicators were
            # not properly aborted before throwing RuntimeError.
            torch.rand(10).cuda(self.rank)
        elif self.rank == 1:
            # Clean up structures (ex: files for FileStore before going down)
            del process_group
            func()

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    def test_nccl_errors_blocking_clean_exit(self):
        self._test_nccl_errors_blocking(lambda: sys.exit(0))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    def test_nccl_errors_blocking_nonzero_exit(self):
        self._test_nccl_errors_blocking(lambda: sys.exit(1))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    @skip_but_pass_in_sandcastle(
        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
    )
    def test_nccl_errors_blocking_abort(self):
        self._test_nccl_errors_blocking(lambda: os.abort())

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    def test_nccl_errors_blocking_sigkill(self):
        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm
    def test_nccl_errors_blocking_sigterm(self):
        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    def test_nccl_blocking_wait_with_barrier(self):
        """A barrier only rank 0 enters raises DistBackendError on its wait."""
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(
            store,
            self.rank,
            self.world_size,
            timeout=timedelta(seconds=10),
        )
        process_group.barrier().wait()
        if self.rank == 0:
            with self.assertRaises(dist.DistBackendError):
                # It seems the error message would be different depending on
                # whether the test is run on CI machine and devGPU.  Skipping
                # the error message check to make both sides happy.
                process_group.barrier().wait(
                    timeout=timedelta(seconds=self.op_timeout_sec)
                )

    def _run_invalid_nccl_blocking_wait_env(self, val):
        """Expect ProcessGroupNCCL construction to fail for this env value.

        Args:
            val: string assigned to ``TORCH_NCCL_BLOCKING_WAIT`` before the
                process group is constructed.
        """
        os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
        store = c10d.FileStore(self.file_name, self.world_size)
        with self.assertRaises(RuntimeError):
            c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

    @requires_nccl()
    @skip_if_lt_x_gpu(3)
    def test_invalid_nccl_blocking_wait_env(self):
        self._run_invalid_nccl_blocking_wait_env("abc")
        self._run_invalid_nccl_blocking_wait_env("-1")
        self._run_invalid_nccl_blocking_wait_env("2147483647")
        self._run_invalid_nccl_blocking_wait_env("4294967295")

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_gloo()
    @skip_if_lt_x_gpu(3)
    def test_nccl_timeout(self):
        """A per-wait timeout shorter than the group timeout raises on rank 0."""
        store = c10d.FileStore(self.file_name, self.world_size)

        # Initialize process_group.
        process_group = c10d.ProcessGroupNCCL(
            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
        )
        # Control gloo pg used as go-ahead signal/barrier
        # to coordinate btwn ranks.
        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
        failed_collective_timeout = timedelta(milliseconds=100)
        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
            timeout=timedelta(seconds=5)
        )

        if self.rank == 0:
            # This should time out after failed_collective_timeout (100 ms).
            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
            with self.assertRaisesRegex(
                dist.DistBackendError, self.blocking_wait_error_msg
            ):
                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
                    timeout=failed_collective_timeout
                )
            # Now do a barrier to tell other rank to go ahead.
            pg_gloo.barrier().wait()
        else:
            # Wait on rank 0 to fail.
            try:
                pg_gloo.barrier().wait()
            except Exception as e:
                raise ValueError(
                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
                ) from e
2640

2641

2642
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
2643
    @property
    def device(self):
        """Return the CUDA device string that corresponds to this rank."""
        rank = self.rank
        return f"cuda:{rank}"
2646

2647
    def setUp(self):
        """Turn on NCCL async error handling, then call ``_spawn_processes()``."""
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT takes precedence over
        # TORCH_NCCL_ASYNC_ERROR_HANDLING, so tests that turn on blocking wait
        # still exercise it even though this variable is set here.
        os.environ.update({"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1"})
        self._spawn_processes()
2653

2654
    def tearDown(self):
        """Run the base teardown, then delete the store file if it exists."""
        super().tearDown()
        store_path = self.file_name
        try:
            os.unlink(store_path)
        except OSError:
            # Best effort: a missing or undeletable file is not a test failure.
            pass
2660

2661
    def _test_broadcast_coalesced(self, process_group, device, root_rank):
        """Broadcast a mixed-dtype tensor list from ``root_rank`` and verify it.

        Args:
            process_group: group handed to ``c10d._broadcast_coalesced``.
            device: device on which all tensors are created.
            root_rank: rank whose tensors are the source of the broadcast.
        """
        # No support for float16 for CPU tensors
        half = torch.float32 if device == torch.device("cpu") else torch.float16

        # Alternate dtypes so the coalesced broadcast has to handle a mix.
        dtype_sequence = (
            half,
            torch.float32,
            half,
            torch.float64,
            half,
            torch.float32,
        )
        target = tuple(
            piece
            for dtype in dtype_sequence
            for piece in torch.arange(60, dtype=dtype, device=device).chunk(5)
        )

        # The tensors to pass to broadcast are identical to the target
        # only on the process that is the root of the broadcast.
        is_root = self.rank == root_rank
        if is_root:
            tensors = [t.clone() for t in target]
        else:
            tensors = [torch.zeros_like(t) for t in target]
            self.assertNotEqual(tensors, target)

        c10d._broadcast_coalesced(
            process_group, tensors, buffer_size=256, src=root_rank
        )

        if not is_root:
            self.assertEqual(tensors, target)
2691

2692
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_broadcast_coalesced_nccl(self):
        """Run the coalesced-broadcast check with rank 0, then rank 1, as root."""
        file_store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl",
            store=file_store,
            rank=self.rank,
            world_size=self.world_size,
        )
        default_pg = c10d.distributed_c10d._get_default_group()
        rank_device = torch.device("cuda:%d" % self.rank)
        for root in (0, 1):
            self._test_broadcast_coalesced(default_pg, rank_device, root)
2704

2705
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_nccl(self):
        """all_reduce_coalesced sums five differently sized tensors across ranks."""
        file_store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl",
            store=file_store,
            rank=self.rank,
            world_size=self.world_size,
        )
        process_group = c10d.distributed_c10d._get_default_group()
        rank_device = torch.device("cuda:%d" % self.rank)

        # Tensor i has 60 + i elements, each filled with rank + 1 + i.
        tensors = []
        for i in range(5):
            tensors.append(
                torch.full(
                    (60 + i,), self.rank + 1 + i, device=rank_device, dtype=torch.float
                )
            )

        torch.distributed.all_reduce_coalesced(tensors, group=process_group)

        for i, reduced in enumerate(tensors):
            # Sum over ranks of (rank + 1 + i).
            expected_value = self.world_size * (i + (self.world_size + 1.0) / 2.0)
            self.assertEqual(reduced, torch.full_like(reduced, expected_value))
2726

2727
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_nccl_float8_errors(self):
        """all_reduce_coalesced on float8_e4m3fn tensors raises RuntimeError."""
        file_store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl",
            store=file_store,
            rank=self.rank,
            world_size=self.world_size,
        )
        process_group = c10d.distributed_c10d._get_default_group()
        rank_device = torch.device("cuda:%d" % self.rank)

        # Build the tensors in float32 first, then cast each one to float8.
        float8_tensors = []
        for i in range(5):
            full_precision = torch.full(
                (60 + i,), self.rank + 1 + i, device=rank_device, dtype=torch.float
            )
            float8_tensors.append(full_precision.to(torch.float8_e4m3fn))

        # NOTE(review): the spelling "currenlty" is kept as-is; presumably it
        # mirrors the backend's message -- confirm before correcting it.
        expected_message = (
            "Float8 dtypes are not currenlty supported for NCCL reductions"
        )
        with self.assertRaisesRegex(RuntimeError, expected_message):
            torch.distributed.all_reduce_coalesced(
                float8_tensors, group=process_group
            )
2747

2748
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_manager_nccl(self):
        """all_reduce calls inside ``_coalescing_manager`` yield a single work."""
        file_store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl",
            store=file_store,
            rank=self.rank,
            world_size=self.world_size,
        )
        process_group = c10d.distributed_c10d._get_default_group()
        rank_device = torch.device("cuda:%d" % self.rank)

        # Tensor i has 60 + i elements, each filled with rank + 1 + i.
        tensors = []
        for i in range(5):
            tensors.append(
                torch.full(
                    (60 + i,), self.rank + 1 + i, device=rank_device, dtype=torch.float
                )
            )

        manager = torch.distributed._coalescing_manager(
            group=process_group, device=rank_device, async_ops=True
        )
        with manager as cm:
            for pending in tensors:
                torch.distributed.all_reduce(pending)

        # All five all_reduce calls must have been folded into one work.
        self.assertEqual(len(cm.works), 1)
        cm.wait()

        for i, reduced in enumerate(tensors):
            # Sum over ranks of (rank + 1 + i).
            expected_value = self.world_size * (i + (self.world_size + 1.0) / 2.0)
            self.assertEqual(reduced, torch.full_like(reduced, expected_value))
2775

2776
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_intra_node_comm_all_reduce(self):
        """Check when all_reduce bumps the IntraNodeComm usage counter.

        The test is skipped unless every other rank's device is peer-accessible
        and ``SM80OrLater`` is true.
        """
        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
        from torch.testing._internal.common_cuda import SM80OrLater

        other_ranks = (r for r in range(self.world_size) if r != self.rank)
        if not all(
            torch._C._cuda_canDeviceAccessPeer(self.rank, peer)
            for peer in other_ranks
        ):
            raise SkipTest("Test requires p2p access")

        if not SM80OrLater:
            raise SkipTest("Test requires sm>=80")

        store = c10d.FileStore(self.file_name, self.world_size)
        os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
        os.environ["TEST_INTRA_NODE_COMM"] = "1"
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
        )
        # Sum of all ranks: 0 + 1 + ... + (world_size - 1).
        expect = self.world_size * (self.world_size - 1) // 2
        usage_counter = _get_intra_node_comm_usage_counter
        small_numel = 4 * 1024 // 2
        limit_numel = 10 * 1024**2 // 2

        # IntraNodeComm currently only supports sum and bf16.
        # Verify that it is not used in the next two configurations.
        default_dtype_input = torch.full((small_numel,), self.rank).cuda()
        c10d.all_reduce(default_dtype_input, c10d.ReduceOp.SUM)
        self.assertTrue(default_dtype_input.eq(expect).all())
        self.assertEqual(usage_counter(), 0)

        avg_input = torch.full(
            (small_numel,), self.rank, dtype=torch.bfloat16
        ).cuda()
        c10d.all_reduce(avg_input, c10d.ReduceOp.AVG)
        self.assertEqual(usage_counter(), 0)

        # Verify that IntraNodeComm is used up to 10MB
        used_sizes = (small_numel, 512 * 1024 // 2, limit_numel)
        for expected_count, numel in enumerate(used_sizes, start=1):
            bf16_input = torch.full(
                (numel,), self.rank, dtype=torch.bfloat16
            ).cuda()
            c10d.all_reduce(bf16_input, c10d.ReduceOp.SUM)
            self.assertTrue(bf16_input.eq(expect).all())
            self.assertEqual(usage_counter(), expected_count)

        # Verify that IntraNodeComm is not used beyond 10MB
        oversized_input = torch.full(
            (limit_numel + 1,), self.rank, dtype=torch.bfloat16
        ).cuda()
        c10d.all_reduce(oversized_input, c10d.ReduceOp.SUM)
        self.assertTrue(oversized_input.eq(expect).all())
        self.assertEqual(usage_counter(), 3)

        c10d.destroy_process_group()
2837

2838
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_nccl(self):
        """Run the shared default-group sequence-number check with NCCL."""
        backend_name = "nccl"
        # Select this rank's GPU before the shared helper runs.
        torch.cuda.set_device(self.rank)
        self._test_sequence_num_set_default_pg(backend=backend_name)
2843

2844
    @skip_if_lt_x_gpu(2)
    @requires_nccl()
    def test_sequence_num_incremented_nccl_default(self):
        """Run the shared sequence-number-increment check on the default group."""
        backend_name = "nccl"
        self._test_sequence_num_incremented_default_group(backend_name)
2848

2849
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sequence_num_incremented_nccl_subgroup(self):
        """Run the shared subgroup sequence-number check with NCCL.

        Raises:
            SkipTest: if ``world_size`` is smaller than 4.
        """
        if self.world_size < 4:
            # skip_but_pass_in_sandcastle() only builds a decorator; returning
            # it skipped nothing and the test was reported as passed without
            # running. Raise SkipTest so the skip is actually recorded.
            raise SkipTest("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("nccl")
2855

2856
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_nccl_new_group(self):
        """Run the shared new-group sequence-number check with NCCL."""
        backend_name = "nccl"
        # Select this rank's GPU before the shared helper runs.
        torch.cuda.set_device(self.rank)
        self._test_sequence_num_set_new_group(backend=backend_name)
2861

2862
    def _test_pass_nccl_options(self, pg_opts):
        """Create groups with ``pg_opts`` and check an allreduce on ranks 0 and 1.

        Args:
            pg_opts: options object forwarded as ``pg_options`` to both
                ``init_process_group`` and ``new_group``.
        """
        file_store = c10d.FileStore(self.file_name, self.world_size)
        # Test init_process_group accepts options
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=file_store,
            pg_options=pg_opts,
        )

        # Test with new_group
        subgroup = c10d.new_group([0, 1], pg_options=pg_opts)
        # test the process group works as expected: ranks 0 and 1 contribute
        # 1 and 2, so every element must sum to 3.
        payload = torch.tensor([self.rank + 1] * 10).cuda(self.rank)
        subgroup.allreduce(payload).wait()
        self.assertEqual(torch.tensor([3] * 10).cuda(self.rank), payload)
2880

2881
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_pass_nccl_options_high_priority_stream(self):
        """Process groups accept options with ``is_high_priority_stream`` set."""
        high_priority_opts = c10d.ProcessGroupNCCL.Options()
        high_priority_opts.is_high_priority_stream = True
        self._test_pass_nccl_options(high_priority_opts)
2887

2888
    @requires_nccl()
    @requires_nccl_version(
        (2, 18), "Need NCCL 2.18+ for configuring NCCL communicators"
    )
    @skip_if_lt_x_gpu(2)
    def test_pass_nccl_options_config(self):
        """NCCL config values set on the options show up in the NCCL debug log.

        NCCL_DEBUG output is redirected to a temporary file, the process groups
        are created with the configured options, and the log is then searched
        for each configured value.
        """
        pg_opts = c10d.ProcessGroupNCCL.Options()
        pg_opts.config.max_ctas = 4
        pg_opts.config.min_ctas = 2
        pg_opts.config.cga_cluster_size = 2
        pg_opts.config.net_name = "Socket"
        pg_opts.config.split_share = 1

        # The context manager closes the temporary log file even when the
        # helper below raises.
        with tempfile.NamedTemporaryFile() as nccl_debug_file:
            os.environ["NCCL_DEBUG"] = "INFO"
            os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name

            # Tests functionality when passing nccl config
            self._test_pass_nccl_options(pg_opts)

            nccl_debug_file_content = nccl_debug_file.read()

        def logged_value(pattern):
            """Return group 1 of ``pattern`` in the log, failing if absent."""
            # The trailing "|$" alternative always matches, so re.search never
            # returns None here; group(1) is None when the entry is missing.
            value = re.search(pattern, nccl_debug_file_content).group(1)
            self.assertIsNotNone(
                value, f"pattern {pattern!r} not found in NCCL debug output"
            )
            return value

        # Tests if comms were configured.
        # NOTE: ".*" is greedy, so (\d+) captures only the last digit on the
        # matching line; that suffices for the single-digit values set above.
        max_ctas = logged_value(rb"Max CTAs.*(\d+)|$")
        min_ctas = logged_value(rb"Min CTAs.*(\d+)|$")
        split_share = logged_value(rb"Split share.*(\d+)|$")
        cga_cluster_size = logged_value(rb"CGA cluster.*(\d+)|$")
        # [a-zA-Z], not [a-zA-z]: the latter also matches "[", "\", "]", "^",
        # "_" and "`".
        net_name = logged_value(rb"Using network.([a-zA-Z]+)|$")

        self.assertEqual(pg_opts.config.max_ctas, int(max_ctas))
        self.assertEqual(pg_opts.config.min_ctas, int(min_ctas))
        self.assertEqual(pg_opts.config.cga_cluster_size, int(cga_cluster_size))
        self.assertEqual(pg_opts.config.net_name, net_name.decode())
        self.assertEqual(pg_opts.config.split_share, int(split_share))
2925

2926
    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_nccl_barrier(self):
        """allreduce works on the default group, a pair group and solo groups.

        Each rank places its tensors on GPU ``2 * rank``.
        """
        file_store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=file_store,
        )

        gpu_index = 2 * self.rank
        own_value = self.rank + 1

        def on_gpu(fill):
            # Ten-element tensor filled with `fill` on this rank's GPU.
            return torch.tensor([fill] * 10).cuda(gpu_index)

        reduced = on_gpu(own_value)
        c10d.all_reduce(reduced)
        expected_sum = on_gpu(3)
        self.assertEqual(expected_sum, reduced)

        # Test with new_group
        pair_group = c10d.new_group([0, 1])
        reduced = on_gpu(own_value)
        pair_group.allreduce(reduced).wait()
        self.assertEqual(expected_sum, reduced)

        # Every rank creates both single-rank groups; only the member rank
        # runs the allreduce, which must leave its tensor unchanged.
        for solo_rank in (0, 1):
            solo_group = c10d.new_group([solo_rank])
            if self.rank != solo_rank:
                continue
            reduced = on_gpu(own_value)
            unchanged = on_gpu(own_value)
            solo_group.allreduce(reduced).wait()
            self.assertEqual(unchanged, reduced)
2958

2959
    @requires_nccl()
2960
    @skip_if_lt_x_gpu(2)
2961
    def test_nccl_barrier_device_ids(self):
        """Call ``barrier`` with ``device_ids`` given as a one-element list holding this rank."""
        c10d.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=c10d.FileStore(self.file_name, self.world_size),
        )

        own_device = [self.rank]
        c10d.barrier(device_ids=own_device)
2968

2969
    @requires_nccl()
2970
    @skip_if_lt_x_gpu(2)
2971
    def test_nccl_barrier_device_ids_function_argument(self):
        """Expect a ``TypeError`` when ``device_ids`` is a bare int instead of a list."""
        c10d.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
            store=c10d.FileStore(self.file_name, self.world_size),
        )

        expect_type_error = self.assertRaisesRegex(
            TypeError, "Invalid function argument"
        )
        with expect_type_error:
            c10d.barrier(device_ids=self.rank)
2979

2980
    @requires_nccl()
2981
    @skip_if_lt_x_gpu(2)
2982
    @with_dist_debug_levels(levels=["DETAIL"])
2983
    def test_nccl_warn_not_in_group_debug_detail(self):
2984
        self._test_warn_not_in_group(backend="nccl")
2985

2986
    @requires_nccl()
2987
    @skip_if_lt_x_gpu(2)
2988
    @with_dist_debug_levels(levels=["INFO"])
2989
    def test_nccl_warn_not_in_group_debug_info(self):
2990
        self._test_warn_not_in_group(backend="nccl")
2991

2992
    @requires_nccl()
2993
    @skip_if_lt_x_gpu(2)
2994
    @with_dist_debug_levels(levels=["OFF"])
2995
    def test_nccl_warn_not_in_group_debug_off(self):
2996
        self._test_warn_not_in_group(backend="nccl")
2997

2998
    @requires_nccl()
2999
    @skip_if_lt_x_gpu(2)
3000
    def test_nncl_rank_membership(self):
3001
        self._test_rank_membership(backend="nccl")
3002

3003
    @requires_nccl()
3004
    @skip_if_lt_x_gpu(2)
3005
    def test_tensor_dtype_mismatch(self):
3006
        self._test_tensor_dtype_mismatch(backend="nccl")
3007

3008
    @requires_nccl()
3009
    @skip_if_lt_x_gpu(2)
3010
    def test_tensor_dtype_complex(self):
3011
        self._test_tensor_dtype_complex(backend="nccl")
3012

3013
    @requires_nccl()
3014
    @skip_if_lt_x_gpu(2)
3015
    def test_reduce_scatter_base_k(self):
        """Reduce-scatter an int64 ``(world_size, 2)`` tensor into a 2-element output.

        Every rank feeds the same ``arange`` input, so the output on each rank
        is expected to be that rank's input row multiplied by ``world_size``.
        """
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        num_ranks = self.world_size
        result = torch.zeros(2, dtype=torch.int64).to(self.rank)
        flat_input = torch.arange(num_ranks * 2, dtype=torch.int64).to(self.rank)
        # One row of two elements per rank.
        per_rank_rows = torch.reshape(flat_input, (num_ranks, 2))

        dist.reduce_scatter_tensor(result, per_rank_rows)

        self.assertEqual(result, per_rank_rows[self.rank] * num_ranks)
3030

3031
    @requires_nccl()
3032
    @skip_if_lt_x_gpu(2)
3033
    def test_reduce_scatter_tensor_coalesced(self):
        """Issue one ``reduce_scatter_tensor`` per rank inside a coalescing manager.

        Each call writes into row ``i`` of a shared 2x2 output from a 2x2 input
        of ones; afterwards the output is compared with this rank's input
        scaled by ``world_size``.
        """
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        gathered_rows = torch.zeros(2, 2).to(self.rank)
        chunks = [torch.ones(2, 2).to(self.rank) for _ in range(self.world_size)]

        with dist._coalescing_manager():
            for row, chunk in enumerate(chunks):
                dist.reduce_scatter_tensor(gathered_rows[row], chunk)

        self.assertEqual(gathered_rows, chunks[self.rank] * self.world_size)
3047

3048
    @requires_nccl()
3049
    @skip_if_lt_x_gpu(2)
3050
    def test_reduce_scatter_base_k_float8_errors(self):
        """Expect ``reduce_scatter_tensor`` on ``float8_e4m3fn`` tensors to raise ``RuntimeError``."""
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        fp8 = torch.float8_e4m3fn
        num_ranks = self.world_size

        # Tensors are built in float32 first and then converted to float8.
        result = torch.zeros(2, dtype=torch.float32).to(fp8).to(self.rank)
        flat_input = (
            torch.arange(num_ranks * 2, dtype=torch.float32).to(fp8).to(self.rank)
        )
        per_rank_rows = torch.reshape(flat_input, (num_ranks, 2))

        # The misspelling in the pattern is kept verbatim so it still matches
        # the message being checked.
        expect_unsupported = self.assertRaisesRegex(
            RuntimeError,
            "Float8 dtypes are not currenlty supported for NCCL reductions",
        )
        with expect_unsupported:
            dist.reduce_scatter_tensor(result, per_rank_rows)
3072

3073
    @requires_nccl()
3074
    @skip_if_lt_x_gpu(2)
3075
    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
        """Expect coalesced ``reduce_scatter_tensor`` on ``float8_e5m2`` tensors to raise.

        One ``reduce_scatter_tensor`` call per rank is issued inside a
        coalescing manager; the ``RuntimeError`` may surface from a call or
        when the manager exits, so the whole manager sits inside the
        ``assertRaisesRegex`` context.
        """
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
        input_tensors = [
            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
            for _ in range(self.world_size)
        ]

        # Nothing follows the coalescing manager inside this context: code
        # placed there would only run when the expected error was NOT raised,
        # and would then hide the "RuntimeError not raised" failure behind an
        # unrelated one. The misspelling in the pattern is kept verbatim so it
        # still matches the message being checked.
        with self.assertRaisesRegex(
            RuntimeError,
            "Float8 dtypes are not currenlty supported for NCCL reductions",
        ):
            with dist._coalescing_manager():
                for i in range(self.world_size):
                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
3097

3098

3099
class SetDeviceMethod(Enum):
    """Ways a test can tell an object collective which CUDA device to use."""

    # torch.cuda.set_device
    TORCH_CUDA_SET = auto()
    # broadcast_object_list(device=)
    COLLECTIVE_ARGUMENT = auto()
3102

3103

3104
class NcclProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    """Dispatched-collective tests run with ``backend="nccl"``.

    Three tests delegate to ``self._test_*`` helpers that are not defined in
    this class; the two all-gather tests build their own process group from a
    ``FileStore``.
    """

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_collectives(self):
        """Delegate to ``self._test_collectives`` for the ``nccl`` backend."""
        shared_check = self._test_collectives
        shared_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_allreduce_coalesced(self):
        """Delegate to ``self._test_allreduce_coalesced`` for the ``nccl`` backend."""
        shared_check = self._test_allreduce_coalesced
        shared_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_all_to_all_single(self):
        """Delegate to ``self._test_all_to_all_single`` for the ``nccl`` backend."""
        shared_check = self._test_all_to_all_single
        shared_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
        """All-gather a 10x10 tensor of ones into an output of the same shape."""
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        cuda = torch.device("cuda")
        source = torch.ones(10, 10, device=cuda)
        gathered = torch.zeros(10, 10, device=cuda)
        # NOTE(review): the output has room for exactly one copy of the input,
        # so this presumably relies on world_size == 1 - confirm in the base class.
        dist.all_gather_into_tensor(gathered, source)
        self.assertEqual(gathered, source)

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    def test_allgather_float8(self, float8_dtype):
        """All-gather a 10x16 float8 tensor and compare input and output as float32 views."""
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        cuda = torch.device("cuda")
        # Both tensors are created in the default dtype and then converted.
        source = torch.ones(10, 16, device=cuda).to(float8_dtype)
        gathered = torch.zeros(10, 16, device=cuda).to(float8_dtype)
        dist.all_gather_into_tensor(gathered, source)
        # The comparison reinterprets the float8 storage as float32.
        self.assertEqual(gathered.view(torch.float32), source.view(torch.float32))
3156

3157

3158
# Generate the concrete test methods for the @parametrize-decorated tests of this class.
instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests)
3159

3160

3161
class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
    """Multi-process NCCL tests for ``new_group`` local sync and two-rank subgroups.

    The ``*_subgroup`` tests are written for exactly four ranks: ranks 0/1 and
    ranks 2/3 form two subgroups, and inside each pair the even rank and the
    odd rank take opposite roles in the collective under test.
    """

    def setUp(self):
        """Set the async error handling variable and spawn the worker processes."""
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        self._spawn_processes()

    def tearDown(self):
        """Run the base teardown, then remove the store file if it can be removed."""
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: failing to remove the file is not an error here.
            pass

    @property
    def device(self):
        """Device used by this rank: the rank itself."""
        return self.rank

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_new_group_local_sync(self):
        """Delegate to ``self._test_new_group_local_sync`` for the ``nccl`` backend."""
        shared_check = self._test_new_group_local_sync
        shared_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_new_group_local_sync_sanity_check(self):
        """Delegate to ``self._test_new_group_local_sync_sanity_check`` for ``nccl``."""
        shared_check = self._test_new_group_local_sync_sanity_check
        shared_check(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_new_group_local_sync_duplicated_pg(self):
        """Delegate to ``self._test_new_group_local_sync_duplicate_pg`` for ``nccl``."""
        shared_check = self._test_new_group_local_sync_duplicate_pg
        shared_check(backend="nccl")

    def _init_two_pg2_subgroups(self, world_size: int = 4):
        """Initialise the default group and return this rank's two-rank subgroup.

        Args:
            world_size: Number of ranks; only 4 is supported.

        Returns:
            The ``[0, 1]`` group when ``self.rank < 2``, otherwise the
            ``[2, 3]`` group.

        Raises:
            NotImplementedError: If ``world_size`` is not 4.
        """
        if world_size != 4:
            raise NotImplementedError(
                f"need world size of 4 to get 2 subgroup PGs, but got world size of {world_size}"
            )
        c10d.init_process_group(
            backend="nccl",
            store=c10d.FileStore(self.file_name, world_size),
            rank=self.rank,
            world_size=world_size,
        )
        # every rank creates the same sub groups
        # including unused sub groups in the current rank
        low_pair = c10d.new_group([0, 1])
        high_pair = c10d.new_group([2, 3])
        if self.rank < 2:
            return low_pair
        return high_pair

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_gather_subgroup(self):
        """Gather inside each subgroup onto its even rank and check every slot."""
        world_size = 4
        if self.rank >= world_size:
            # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
            return

        subgroup = self._init_two_pg2_subgroups(world_size)
        device = torch.device("cuda:%d" % self.rank)
        payload = torch.ones((10,), device=device) * self.rank

        if self.rank not in (0, 2):
            # Odd ranks only contribute; the destination is the even neighbour.
            torch.distributed.gather(
                payload,
                gather_list=None,
                dst=self.rank - 1,
                group=subgroup,
                async_op=False,
            )
            return

        slots = [torch.empty_like(payload) for _ in range(subgroup.size())]
        torch.distributed.gather(
            payload,
            gather_list=slots,
            dst=self.rank,
            group=subgroup,
            async_op=False,
        )
        # Slot ``offset`` is expected to hold the value (self.rank + offset).
        for offset, received in enumerate(slots):
            wanted = (torch.ones_like(payload) * self.rank) + offset
            self.assertEqual(received, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_gather_object_subgroup(self):
        """Gather ``{"rank": rank}`` dicts inside each subgroup onto its even rank."""
        world_size = 4
        if self.rank >= world_size:
            # just easier to write the test for exactly 4 gpus, even if this test class increased to 8gpu later
            return

        subgroup = self._init_two_pg2_subgroups(world_size)

        # discrepancy #1
        # have to set device or else gather_object gets wrong device from 'current_device = _get_pg_default_device(group)
        torch.cuda.set_device(self.rank)

        payload = {"rank": self.rank}
        if self.rank not in (0, 2):
            torch.distributed.gather_object(
                payload, object_gather_list=None, dst=self.rank - 1, group=subgroup
            )
            return

        # discrepancy #2
        # the destination has to pre-populate the list with placeholder objects;
        # per the original author an empty list should arguably be valid, but it
        # throws an error.
        slots = [{}, {}]
        torch.distributed.gather_object(
            payload, object_gather_list=slots, dst=self.rank, group=subgroup
        )
        for offset, received in enumerate(slots):
            self.assertEqual(received["rank"], self.rank + offset)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_reduce_subgroup(self):
        """Reduce inside each subgroup onto its even rank and check the result."""
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)
        device = torch.device("cuda:%d" % self.rank)
        contribution = torch.ones((10,), device=device) * self.rank

        if self.rank not in (0, 2):
            c10d.reduce(contribution, dst=self.rank - 1, group=subgroup, async_op=False)
            return

        # Must be computed before the reduce, which is checked against the
        # same tensor it was given.
        wanted = contribution + torch.ones((10,), device=device) * (self.rank + 1)
        c10d.reduce(contribution, dst=self.rank, group=subgroup, async_op=False)
        self.assertEqual(contribution, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @parametrize("async_op", [True, False])
    def test_send_recv_subgroup(self, async_op):
        """Send from the odd rank to the even rank of each subgroup, sync or async."""
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)
        device = torch.device("cuda:%d" % self.rank)

        if self.rank not in (0, 2):
            outgoing = torch.ones((10,), device=device) * self.rank
            receiver = self.rank - 1
            if async_op:
                c10d.isend(outgoing, dst=receiver, group=subgroup).wait()
            else:
                c10d.send(outgoing, dst=receiver, group=subgroup)
            return

        incoming = torch.empty((10,), device=device)
        sender = self.rank + 1
        if async_op:
            c10d.irecv(incoming, src=sender, group=subgroup).wait()
        else:
            c10d.recv(incoming, src=sender, group=subgroup)
        wanted = torch.ones((10,), device=device) * (self.rank + 1)
        self.assertEqual(incoming, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_broadcast_subgroup(self):
        """Broadcast from the odd rank of each subgroup and check the even rank's copy."""
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)
        device = torch.device("cuda:%d" % self.rank)

        if self.rank not in (0, 2):
            outgoing = torch.ones((10,), device=device) * self.rank
            c10d.broadcast(outgoing, src=self.rank, group=subgroup)
            return

        incoming = torch.empty((10,), device=device)
        c10d.broadcast(incoming, src=self.rank + 1, group=subgroup)
        wanted = torch.ones((10,), device=device) * (self.rank + 1)
        self.assertEqual(incoming, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @parametrize(
        "set_device",
        [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
    )
    def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
        """Send an object list from the odd to the even rank of each subgroup.

        The device is selected either through ``torch.cuda.set_device`` (the
        collective then receives ``device=None``) or through the collective's
        ``device`` argument.
        """
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)

        device = None
        if set_device == SetDeviceMethod.TORCH_CUDA_SET:
            torch.cuda.set_device(self.rank)
        else:
            device = torch.device("cuda:%d" % self.rank)

        if self.rank not in (0, 2):
            outgoing = [{"rank": self.rank}]
            c10d.send_object_list(
                outgoing, dst=self.rank - 1, group=subgroup, device=device
            )
            return

        incoming = [{}]
        c10d.recv_object_list(incoming, src=self.rank + 1, group=subgroup, device=device)
        wanted = [{"rank": self.rank + 1}]
        self.assertEqual(incoming, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    @parametrize(
        "set_device",
        [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
    )
    def test_broadcast_object_list_subgroup(self, set_device: SetDeviceMethod):
        """Broadcast an object list from the odd rank of each subgroup.

        The device is selected either through ``torch.cuda.set_device`` (the
        collective then receives ``device=None``) or through the collective's
        ``device`` argument.
        """
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)

        device = None
        if set_device == SetDeviceMethod.TORCH_CUDA_SET:
            torch.cuda.set_device(self.rank)
        else:
            device = torch.device("cuda:%d" % self.rank)

        if self.rank not in (0, 2):
            outgoing = [{"rank": self.rank}]
            c10d.broadcast_object_list(
                outgoing, src=self.rank, group=subgroup, device=device
            )
            return

        incoming = [{}]
        c10d.broadcast_object_list(
            incoming, src=self.rank + 1, group=subgroup, device=device
        )
        wanted = [{"rank": self.rank + 1}]
        self.assertEqual(incoming, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_scatter_subgroup(self):
        """Scatter from the odd rank of each subgroup; every rank expects its own rank value."""
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)
        device = torch.device("cuda:%d" % self.rank)
        received = torch.empty((10,), device=device)
        wanted = torch.ones((10,), device=device) * self.rank

        if self.rank in (0, 2):
            c10d.scatter(received, scatter_list=None, src=self.rank + 1, group=subgroup)
        else:
            # One chunk for the even neighbour, one for this rank.
            chunks = [
                torch.ones((10,), device=device) * value
                for value in (self.rank - 1, self.rank)
            ]
            c10d.scatter(received, scatter_list=chunks, src=self.rank, group=subgroup)
        self.assertEqual(received, wanted)

    @requires_nccl()
    @skip_if_lt_x_gpu(4)
    def test_scatter_object_list_subgroup(self):
        """Scatter ``{"rank": r}`` dicts from the odd rank of each subgroup."""
        world_size = 4
        if self.rank >= world_size:
            return
        subgroup = self._init_two_pg2_subgroups(world_size)
        torch.cuda.set_device(self.rank)
        received = [None]
        wanted = [{"rank": self.rank}]

        if self.rank in (0, 2):
            c10d.scatter_object_list(
                scatter_object_output_list=received,
                scatter_object_input_list=None,
                src=self.rank + 1,
                group=subgroup,
            )
        else:
            # One object for the even neighbour, one for this rank.
            outgoing = [{"rank": value} for value in (self.rank - 1, self.rank)]
            c10d.scatter_object_list(
                scatter_object_output_list=received,
                scatter_object_input_list=outgoing,
                src=self.rank,
                group=subgroup,
            )
        self.assertEqual(received, wanted)
3431

3432

3433
# Generate the concrete test methods for the @parametrize-decorated tests of LargeCommTest.
instantiate_parametrized_tests(LargeCommTest)
3434

3435

3436
class SparseCollective(MultiProcessTestCase):
    """Single-rank DDP test that sets sparse metadata on a sparse-embedding model."""

    @property
    def world_size(self):
        """This test class always runs with one rank."""
        return 1

    def setUp(self):
        """Set the async error handling variable and spawn the worker processes."""
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        # self.num_gpus = torch.cuda.device_count()
        self._spawn_processes()

    def tearDown(self):
        """Run the base teardown, then remove the store file if it can be removed."""
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: failing to remove the file is not an error here.
            pass

    class ToyModel(nn.Module):
        """Sparse embedding, mean over dim 1, then a linear layer with one output."""

        def __init__(self, rank, vocab_size, embedding_dim):
            """Create both layers and move them to device ``rank``."""
            super().__init__()
            sparse_embedding = nn.Embedding(vocab_size, embedding_dim, sparse=True)
            self.embedding = sparse_embedding.to(rank)
            self.linear = nn.Linear(embedding_dim, 1).to(rank)

        def forward(self, inputs):
            """Embed ``inputs``, average over dim 1 and apply the linear layer."""
            # (batch_size, sequence_length, embedding_dim)
            token_vectors = self.embedding(inputs)
            # (batch_size, embedding_dim)
            pooled = torch.mean(token_vectors, dim=1)
            # (batch_size, 1)
            return self.linear(pooled)

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_ddp_set_sparse_metadata(self):
        """Run forward/backward through DDP after ``_set_sparse_metadata``.

        A ``RuntimeError`` saying NCCL does not support sparse all_reduce is
        tolerated; any other ``RuntimeError`` is re-raised.
        """
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=dist.FileStore(self.file_name, self.world_size),
        )

        vocab_size = 5

        toy = SparseCollective.ToyModel(
            self.rank, vocab_size=vocab_size, embedding_dim=10
        )
        ddp_model = DistributedDataParallel(toy)
        inputs = torch.tensor([[1, 0, 0], [0, 0, 0], [0, 0, 0]]).to(self.rank)
        # set sparse metadata on the DDP model
        indices = torch.Tensor(list(range(vocab_size)))
        ddp_model._set_sparse_metadata({"embedding.weight": indices})

        try:
            # forward pass
            loss = ddp_model(inputs).sum()

            # backward pass
            loss.backward()
            # NOTE(review): ``grad.indices`` is a bound method and ``indices``
            # lands in assertTrue's ``msg`` slot, so this assertion always
            # passes. It looks like a comparison of the gradient's indices
            # with ``indices`` was intended - confirm the expected
            # shape/dtype/device before tightening it.
            self.assertTrue(ddp_model.module.embedding.weight.grad.indices, indices)
        except RuntimeError as err:
            if "NCCL does not support all_reduce with sparse tensors" not in str(err):
                # Rethrow the exception if it's a different error
                raise
3508

3509

3510
class NCCLTraceTestBase(MultiProcessTestCase):
    """Shared fixture for the NCCL trace tests.

    ``setUp`` configures the ``TORCH_NCCL_*`` trace/dump environment variables,
    points the dump files at a fresh temporary directory and spawns one process
    per rank, handing each process one end of a dedicated pipe.
    """

    def setUp(self):
        """Configure the trace environment, create the dump directory and spawn workers."""
        super().setUp()
        # see 'timing_enabled' parametrized tests
        os.environ["TORCH_NCCL_ENABLE_TIMING"] = "0"
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"
        os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"
        self.tempdir = tempfile.TemporaryDirectory()
        # Both variables are set to the same base path.
        os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
        os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
        self._spawn_processes()

    @classmethod
    def _run(
        cls,
        parent_conn,
        rank: int,
        test_name: str,
        file_name: str,
        parent_pipe,
        **kwargs,
    ) -> None:
        """Store ``parent_conn`` on the class, then defer to the base ``_run``.

        ``parent_conn`` is the extra leading argument injected by
        ``_spawn_processes``. ``**kwargs`` is accepted but not forwarded.
        """
        cls.parent = parent_conn
        super()._run(rank, test_name, file_name, parent_pipe)

    @property
    def local_device(self):
        """CUDA device built from the first GPU listed for this rank in ``rank_to_GPU``."""
        return torch.device("cuda", self.rank_to_GPU[self.rank][0])

    def _join_processes(self, fn):
        """Run ``fn`` in this process with ``sys.exit`` patched, then join as usual."""
        # We need to patch sys.exit() as skip_if will use sys.exit() and
        # the exit code from this process will not be caught.
        with mock.patch("sys.exit"):
            fn()
        super()._join_processes(fn)

    def _spawn_processes(self) -> None:
        """Start the worker processes, giving each one end of its own pipe.

        The other end of every pipe is kept in ``self.children_pipes``.
        """
        proc = torch.multiprocessing.get_context("spawn").Process
        self.children_pipes = []
        parent_pipes = []
        for _ in range(self.world_size):
            parent_conn, child_conn = torch.multiprocessing.Pipe()
            self.children_pipes.append(child_conn)
            parent_pipes.append(parent_conn)
        piter = iter(parent_pipes)

        def wrap(*positional, args, **kwargs):
            # Prepend the next pipe end so it arrives as ``_run``'s first argument.
            args = (next(piter), *args)
            return proc(*positional, args=args, **kwargs)

        self._start_processes(wrap)

    def _create_process_group_nccl(self):
        """Initialise the default NCCL process group from a ``FileStore`` and return it."""
        store = dist.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            "nccl", world_size=self.world_size, rank=self.rank, store=store
        )
        pg = c10d.distributed_c10d._get_default_group()
        return pg

    def tearDown(self):
        """Run the base teardown, then remove the store file and the dump directory."""
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass
        # setUp creates a TemporaryDirectory; release it explicitly instead of
        # leaving it to implicit finalisation. ``getattr`` keeps this safe if
        # the attribute was never set on this instance.
        tempdir = getattr(self, "tempdir", None)
        if tempdir is not None:
            try:
                tempdir.cleanup()
            except OSError:
                # Best effort, like the store file above.
                pass

    @property
    def world_size(self):
        """These tests always run with two ranks."""
        return 2

    @property
    def rank_to_GPU(self):
        """Rank-to-GPU mapping from ``init_multigpu_helper`` for the ``nccl`` backend."""
        # return rank to GPU map
        return init_multigpu_helper(self.world_size, "nccl")

    def _trace_basename(self):
        """Return the ``trace_`` path prefix inside the temporary directory."""
        # we pass the base to the env, and the dump util will append rank
        return os.path.join(self.tempdir.name, "trace_")

    def _trace_name(self, rank):
        """Return the trace base path with ``rank`` appended."""
        return self._trace_basename() + str(rank)

    def started_or_scheduled(self, timing_enabled):
        """Return ``"started"`` when ``timing_enabled`` is truthy, else ``"scheduled"``."""
        return "started" if timing_enabled else "scheduled"
3596

3597

3598
class NCCLTraceTest(NCCLTraceTestBase):
3599
    def _verify_trace(self, t, include_collectives, timing_enabled, is_json):
3600
        ver = t["version"]
3601
        self.assertEqual(ver, "2.3")
3602
        pg_config = t["pg_config"]
3603
        self.assertEqual(len(pg_config), 1)
3604
        default_pg_info = pg_config["0"]
3605
        self.assertIn("name", default_pg_info)
3606
        self.assertIn("desc", default_pg_info)
3607
        self.assertIn("ranks", default_pg_info)
3608
        pg_status = t["pg_status"]
3609
        self.assertEqual(len(pg_status), 1)
3610
        self.assertEqual(str(pg_status["0"]["last_enqueued_collective"]), "2")
3611
        self.assertEqual(str(pg_status["0"]["last_completed_collective"]), "2")
3612
        self.assertEqual(
3613
            str(pg_status["0"]["last_started_collective"]),
3614
            "2" if timing_enabled else "-1",
3615
        )
3616
        global_ranks = pg_config["0"]["ranks"]
3617
        self.assertEqual(len(json.loads(global_ranks)), self.world_size)
3618
        if include_collectives:
3619
            self.assertEqual(len(t["entries"]), 2)
3620
            t = t["entries"]
3621
            last = t[-1]
3622
            self.assertEqual(last["process_group"], ("0", "default_pg"))
3623
            self.assertEqual(last["state"], "completed")
3624
            s = last["time_discovered_started_ns"]
3625
            f = last["time_discovered_completed_ns"]
3626
            self.assertEqual(last["record_id"], 1)
3627
            self.assertIsNotNone(f)
3628
            if timing_enabled:
3629
                self.assertIsNotNone(s)
3630
                self.assertTrue(s <= f)
3631
            # we don't collect stack traces in JSON at the moment
3632
            if not is_json:
3633
                self.assertIn("test_c10d_nccl.py", str(last["frames"]))
3634
            self.assertEqual(last["input_sizes"], ((3, 4),))
3635
            self.assertEqual(last["input_dtypes"], ["Float"])
3636
            self.assertEqual(last["output_sizes"], ((3, 4),))
3637
            self.assertEqual(last["output_dtypes"], ["Float"])
3638
            self.assertEqual(last["collective_seq_id"], 2)
3639
            self.assertEqual(last["timeout_ms"], 600000)
3640
            now = datetime.now()
3641
            event_created_time = datetime.fromtimestamp(
3642
                last["time_created_ns"] / 1000000000
3643
            )
3644
            before_test = now - timedelta(minutes=1)
3645
            self.assertTrue(before_test < event_created_time < now)
3646
            if timing_enabled:
3647
                # very loose bounds, measured 0.036 ms on devgpu
3648
                self.assertTrue(0 < last["duration_ms"] < 100)
3649
            else:
3650
                self.assertTrue("duration_ms" not in last)
3651
        else:
3652
            self.assertTrue("entries" not in t)
3653

3654
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    @parametrize("include_collectives", [True, False])
    def test_short_json(self, timing_enabled, include_collectives):
        """Run two all-reduces, dump the trace as JSON and validate its contents."""
        if self.rank == self.MAIN_PROCESS_RANK:
            # Nothing to do in the main process for this test.
            return

        group = self._create_process_group_nccl()
        if timing_enabled:
            group._enable_collectives_timing()

        dev = self.local_device
        payload = torch.full((3, 4), float(self.rank), device=dev)
        for _ in range(2):
            work = group.allreduce(payload)
        work.wait()
        torch.cuda.synchronize(device=dev)

        # duration_ms is populated best-effort outside of the dump API, so
        # leave it some time before taking the dump.
        time.sleep(1)
        raw_json = torch._C._distributed_c10d._dump_nccl_trace_json(
            includeCollectives=include_collectives
        )
        trace = json.loads(raw_json)
        self._verify_trace(trace, include_collectives, timing_enabled, True)
        dist.destroy_process_group()
3679

3680
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    @parametrize("include_collectives", [True, False])
    def test_short_pickle(self, timing_enabled, include_collectives):
        """Run two all-reduces, dump the trace as a pickle and validate its contents."""
        if self.rank == self.MAIN_PROCESS_RANK:
            return
        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        a = torch.full((3, 4), float(self.rank), device=device)
        for _ in range(2):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        # duration_ms is populated best-effort since it can only happen
        # outside the "dump()" api, so give it some time.
        time.sleep(1)
        t = pickle.loads(
            torch._C._distributed_c10d._dump_nccl_trace(
                includeCollectives=include_collectives
            )
        )
        # This is the pickle dump, not the JSON one: pass is_json=False so the
        # stack-frame check in _verify_trace is exercised instead of skipped.
        self._verify_trace(
            t,
            include_collectives=include_collectives,
            timing_enabled=timing_enabled,
            is_json=False,
        )
        dist.destroy_process_group()
3710

3711
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_dump_pipe(self):
        """Request a dump by writing to rank 0's ``.pipe`` file and read the result.

        The main process waits for every rank to finish its collectives, writes
        ``"1\\n"`` to ``<trace>.pipe`` and then expects a pickled dump mentioning
        ``all_reduce`` at the trace path.
        """

        def open_file_with_timeout(file_path, mode, timeout=1.0):
            # Poll every 100 ms until the path exists, then open it.
            deadline_start = time.time()
            while True:
                if not (time.time() - deadline_start < timeout):
                    raise FileNotFoundError
                if os.path.exists(file_path):
                    return open(file_path, mode)
                time.sleep(0.1)

        if self.rank != self.MAIN_PROCESS_RANK:
            group = self._create_process_group_nccl()
            dev = self.local_device
            payload = torch.full((3, 4), float(self.rank), device=dev)
            for _ in range(2):
                work = group.allreduce(payload)
            work.wait()
            torch.cuda.synchronize(device=dev)
            # Report completion, then block until the main process has finished.
            self.parent.send("next")
            self.parent.recv()
            return

        for conn in self.children_pipes:
            self.assertEqual(conn.recv(), "next")

        dump_file = self._trace_name(rank=0)
        pipe_file = dump_file + ".pipe"
        with open_file_with_timeout(pipe_file, "w") as trigger:
            trigger.write("1\n")
        with open_file_with_timeout(dump_file, "rb", timeout=10.0) as dumped:
            self.assertTrue("all_reduce" in str(pickle.load(dumped)))

        for conn in self.children_pipes:
            conn.send("next")
        return
3746

3747
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_long(self):
        """Issue twelve collectives with the trace buffer size set to ten.

        The dump must hold exactly ten entries with consecutive sequence ids,
        and the last one must describe the final all-reduce.
        """
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        group = self._create_process_group_nccl()
        dev = self.local_device
        payload = torch.full((3, 4), float(self.rank), device=dev)
        for _ in range(2):
            # test some other primitives to make sure
            # their strings are valid
            inputs = [torch.ones(3, 4, device=dev)]
            group.broadcast(inputs).wait()
            group.allreduce(inputs).wait()
            group.reduce(inputs).wait()
            gathered = [
                [torch.empty(3, 4, device=dev) for _ in range(self.world_size)]
            ]
            group.allgather(gathered, inputs).wait()
            group.reduce_scatter(inputs, gathered).wait()
            final_work = group.allreduce(payload)
        final_work.wait()
        torch.cuda.synchronize(device=dev)

        dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        entries = dump["entries"]
        self.assertEqual(len(entries), 10)
        oldest, newest = entries[0], entries[-1]
        self.assertEqual(newest["profiling_name"], "nccl:all_reduce")
        self.assertEqual(newest["state"], "completed")
        self.assertIn("test_c10d_nccl.py", str(newest["frames"]))
        self.assertEqual(newest["input_sizes"], ((3, 4),))
        self.assertEqual(newest["input_dtypes"], ["Float"])
        self.assertEqual(newest["output_sizes"], ((3, 4),))
        self.assertEqual(newest["output_dtypes"], ["Float"])
        self.assertEqual(newest["timeout_ms"], 600000)
        self.assertEqual(
            newest["collective_seq_id"] - oldest["collective_seq_id"], 9
        )
        dist.destroy_process_group()
3784

3785
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_trace_while_all_works_retired(self):
        """Dump after every work has been retired and check the newest entry."""
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        group = self._create_process_group_nccl()
        dev = self.local_device
        # send more works than the buffer size to overwrite the previous entry
        for _ in range(12):
            tensors = [torch.ones(3, 4, device=dev)]
            group.broadcast(tensors).wait()
        torch.cuda.synchronize(device=dev)

        # wait for all works to be retired
        group._wait_for_pending_works()
        dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        entries = dump["entries"]
        self.assertEqual(len(entries), 10)
        newest = entries[-1]
        self.assertEqual(newest["retired"], True)
        self.assertEqual(newest["state"], "completed")
3807

3808
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    @parametrize("only_active", [True, False])
    def test_trace_while_active(self, timing_enabled, only_active):
        """Dump the trace while non-zero ranks still have an all-reduce in flight.

        Rank 0 joins the second all-reduce only after every rank has taken its
        dump and the main process has relayed "next" back.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            # Barrier: collect "next" from every rank, then release them all.
            for conn in self.children_pipes:
                self.assertEqual(conn.recv(), "next")
            for conn in self.children_pipes:
                conn.send("next")
            return

        group = self._create_process_group_nccl()
        if timing_enabled:
            group._enable_collectives_timing()
        dev = self.local_device
        with torch.cuda.device(dev):
            payload = torch.full((3, 4), float(self.rank), device=dev)

            group.allreduce(payload).wait()
            first_done = torch.cuda.Event()
            first_done.record()
            if self.rank != 0:
                group.allreduce(payload).wait()
            first_done.synchronize()
            dump = pickle.loads(
                torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active)
            )
            entries = dump["entries"]
            if only_active:
                # Only the unfinished second all-reduce counts as active.
                if self.rank == 0:
                    self.assertEqual(len(entries), 0)
                else:
                    self.assertEqual(len(entries), 1)
            else:
                newest = entries[-1]
                if self.rank == 0:
                    self.assertEqual(newest["profiling_name"], "nccl:all_reduce")
                    self.assertEqual(newest["collective_seq_id"], 1)
                    self.assertEqual(newest["state"], "completed")
                else:
                    self.assertEqual(newest["profiling_name"], "nccl:all_reduce")
                    self.assertEqual(newest["collective_seq_id"], 2)
                    self.assertEqual(
                        newest["state"], self.started_or_scheduled(timing_enabled)
                    )

            self.parent.send("next")
            self.assertEqual("next", self.parent.recv())
            if self.rank == 0:
                group.allreduce(payload).wait()
            torch.cuda.synchronize(device=dev)
3859

3860
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_trace_while_stuck(self, timing_enabled):
        """Dump the trace from a side thread while the main thread keeps queueing work.

        Non-zero ranks start a second all-reduce that rank 0 only joins after
        all dumps were taken, then keep issuing tensor additions while a helper
        thread collects and checks the trace.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            # Barrier: collect "next" from every rank, then release them all.
            for conn in self.children_pipes:
                self.assertEqual(conn.recv(), "next")
            for conn in self.children_pipes:
                conn.send("next")
            return

        group = self._create_process_group_nccl()
        if timing_enabled:
            group._enable_collectives_timing()

        dev = self.local_device
        with torch.cuda.device(dev):
            payload = torch.full((3, 4), float(self.rank), device=dev)

            group.allreduce(payload).wait()
            first_done = torch.cuda.Event()
            first_done.record()

            def gather_trace():
                first_done.synchronize()
                # give the other thread some time to fill the cuda buffer
                time.sleep(5)
                dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
                newest = dump["entries"][-1]
                self.assertEqual(newest["profiling_name"], "nccl:all_reduce")
                if self.rank == 0:
                    self.assertEqual(newest["collective_seq_id"], 1)
                    self.assertEqual(newest["state"], "completed")
                else:
                    self.assertEqual(newest["collective_seq_id"], 2)
                    self.assertEqual(
                        newest["state"], self.started_or_scheduled(timing_enabled)
                    )
                    self.assertIsNone(newest["time_discovered_completed_ns"])
                # this will eventually cause the missing rank 0
                # to continue which will unblock the non-zero ranks
                self.parent.send("next")

            if self.rank == 0:
                gather_trace()
            else:
                group.allreduce(payload).wait()
                dumper = threading.Thread(target=gather_trace)
                dumper.start()
                # fill the cuda buffer, at around 1024 events
                # this will stall
                for _ in range(2000):
                    payload = payload + payload
                dumper.join()

            self.assertEqual("next", self.parent.recv())
            if self.rank == 0:
                group.allreduce(payload).wait()
            torch.cuda.synchronize(device=dev)
3919

3920
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "op_sizes_per_coalesce",
        [
            [(2, 3)],
            [(2, 3), (5, 5), (1,)],
        ],
    )
    @parametrize("timing_enabled", [True, False])
    def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled):
        """
        Regression test: 'WorkEnqueue' was skipped for isendirecv, leading to a
        segfault on dump_entries when update_state tried to use a destructed
        Work obj's cuda events.

        Every batch must appear in the dump as one entry per p2p op followed by
        a single "nccl:coalesced" entry.
        """

        if self.rank == self.MAIN_PROCESS_RANK:
            return
        group = self._create_process_group_nccl()
        if timing_enabled:
            group._enable_collectives_timing()

        num_coalesced_ops = 20
        ops_per_coalesce = len(op_sizes_per_coalesce)
        for _ in range(num_coalesced_ops):
            batch = []
            for sizes in op_sizes_per_coalesce:
                tensor = torch.zeros(sizes).to(self.local_device)
                if self.rank == 0:
                    batch.append(dist.P2POp(dist.irecv, tensor, 1))
                elif self.rank == 1:
                    tensor *= 2
                    batch.append(dist.P2POp(dist.isend, tensor, 0))

            dist.batch_isend_irecv(batch).pop().wait()

        torch.cuda.synchronize(device=self.local_device)

        if timing_enabled:
            # wait for watchdog thread to process the queue of works
            time.sleep(1)

        dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        entries_per_batch = ops_per_coalesce + 1
        self.assertEqual(len(dump["entries"]), num_coalesced_ops * entries_per_batch)

        next_record_id = 0
        next_seq = 1
        next_op_id = 1
        for batch_idx in range(num_coalesced_ops):
            first_op = batch_idx * entries_per_batch
            coalesced_idx = first_op + ops_per_coalesce
            for op_idx, sizes in enumerate(op_sizes_per_coalesce, start=first_op):
                # the individual ops inside the coalescing group carry the
                # individual op metadata, but not the timing info coming from
                # the actual coalesced kernel
                if self.rank == 0:
                    expected_name = "nccl:recv 0<-1"
                else:
                    expected_name = "nccl:send 1->0"
                self.assertEqual(
                    dump["entries"][op_idx]["record_id"], next_record_id
                )
                next_record_id += 1
                self.assertEqual(
                    dump["entries"][op_idx]["profiling_name"], expected_name
                )
                self.assertEqual(
                    dump["entries"][op_idx]["collective_seq_id"], next_seq
                )
                self.assertEqual(dump["entries"][op_idx]["op_id"], next_op_id)
                next_op_id += 1
                self.assertEqual(dump["entries"][op_idx]["input_sizes"], [sizes])
                self.assertEqual(dump["entries"][op_idx]["output_sizes"], [sizes])
                # duration doesn't get tagged onto individual ops yet, nor is their state updated
                self.assertEqual(dump["entries"][op_idx]["state"], "scheduled")
                self.assertTrue("duration_ms" not in dump["entries"][op_idx])

            # the coalesced op has no metadata but indicates that coalescing was used,
            # and accurately reflects the timing and state info for the whole group
            self.assertEqual(
                dump["entries"][coalesced_idx]["record_id"], next_record_id
            )
            next_record_id += 1
            self.assertEqual(
                dump["entries"][coalesced_idx]["profiling_name"], "nccl:coalesced"
            )
            self.assertEqual(
                dump["entries"][coalesced_idx]["collective_seq_id"], next_seq
            )
            next_seq += 1
            self.assertEqual(dump["entries"][coalesced_idx]["state"], "completed")
            self.assertEqual(dump["entries"][coalesced_idx]["input_sizes"], [])
            self.assertEqual(dump["entries"][coalesced_idx]["output_sizes"], [])
            if timing_enabled:
                duration = dump["entries"][coalesced_idx]["duration_ms"]
                self.assertTrue(0.001 < duration < 10000, duration)
            else:
                self.assertTrue("duration_ms" not in dump["entries"][coalesced_idx])
            self.assertEqual(dump["entries"][coalesced_idx]["timeout_ms"], 600000)
4021

4022
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize(
        "op_sizes",
        [
            [(2, 3)],
            [(2, 3), (5, 5), (1,)],
        ],
    )
    @parametrize("timing_enabled", [True, False])
    def test_individual_send_recv(self, op_sizes, timing_enabled):
        """
        Regression test: 'WorkEnqueue' was skipped for isendirecv, leading to a
        segfault on dump_entries when update_state tried to use a destructed
        Work obj's cuda events.

        Each blocking send/recv must appear as its own completed entry.
        """

        if self.rank == self.MAIN_PROCESS_RANK:
            return
        group = self._create_process_group_nccl()
        if timing_enabled:
            group._enable_collectives_timing()
        num_repeats = 10
        ops_per_repeat = len(op_sizes)
        for _ in range(num_repeats):
            for sizes in op_sizes:
                tensor = torch.zeros(sizes).to(self.local_device)
                if self.rank == 0:
                    dist.recv(tensor, 1)
                elif self.rank == 1:
                    tensor *= 2
                    dist.send(tensor, 0)

        torch.cuda.synchronize(device=self.local_device)
        if timing_enabled:
            # wait for watchdog thread to process the queue of works
            time.sleep(1)

        dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(dump["entries"]), num_repeats * (ops_per_repeat))
        next_seq = 1
        next_op_id = 1
        for idx in range(num_repeats * ops_per_repeat):
            sizes = op_sizes[idx % ops_per_repeat]
            if self.rank == 0:
                expected_name = "nccl:recv 0<-1"
            else:
                expected_name = "nccl:send 1->0"
            self.assertEqual(dump["entries"][idx]["profiling_name"], expected_name)
            self.assertEqual(dump["entries"][idx]["p2p_seq_id"], next_seq)
            next_seq += 1
            self.assertEqual(dump["entries"][idx]["op_id"], next_op_id)
            next_op_id += 1
            self.assertEqual(dump["entries"][idx]["input_sizes"], [sizes])
            self.assertEqual(dump["entries"][idx]["output_sizes"], [sizes])
            self.assertEqual(dump["entries"][idx]["state"], "completed")

            if not timing_enabled:
                self.assertTrue("duration_ms" not in dump["entries"][idx])
            else:
                duration = dump["entries"][idx]["duration_ms"]
                self.assertTrue(0.001 < duration < 10000, duration)
4080

4081
    # TODO(whc) support and test coalesced collectives that use the c++ start/end group thingy instead of python
4082
    # coalescing manager
4083

4084
    # TODO(whc) test out other ops (And combinations of ops, if that's valid?)
4085
    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("timing_enabled", [True, False])
    def test_coalescing_manager_collective(self, timing_enabled):
        """
        The coalescing manager api works by accumulating operations in python via a contextmanager, and then making
        one call into c++ to an <op>_coalesced API.  It has limited support for ops and has been added recently to
        avoid overheads of making individual py-cpp calls.  This complicates flight recording..

        For now, flight recording of coalescing_manager collectives is less detailed than cpp coalesced collectives.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return
        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()

        output_tensors = torch.zeros(2, 2).to(self.rank)
        input_tensors = [torch.ones(2, 2).to(self.rank) for _ in range(self.world_size)]

        # TODO(whc) make this work with bigger world or something
        self.assertEqual(self.world_size, 2, self.world_size)

        with dist._coalescing_manager():
            for i in range(self.world_size):
                dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
        self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)

        torch.cuda.synchronize(device=self.rank)

        if timing_enabled:
            # wait for watchdog thread to process the queue of works
            time.sleep(1)

        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())

        # The whole coalesced group is recorded as a single entry, the
        # reduce_scatter_tensor_coalesced call.
        self.assertEqual(len(t["entries"]), 1)
        entry = t["entries"][0]
        self.assertEqual(
            entry["profiling_name"], "nccl:reduce_scatter_tensor_coalesced"
        )
        self.assertEqual(entry["collective_seq_id"], 1)
        self.assertEqual(entry["input_sizes"], [[2, 2], [2, 2]])
        self.assertEqual(entry["output_sizes"], [[2], [2]])
        self.assertEqual(entry["state"], "completed")
        if timing_enabled:
            duration = entry["duration_ms"]
            self.assertTrue(0.001 < duration < 10000, duration)
        else:
            self.assertNotIn("duration_ms", entry)
4146

4147

4148
def check_if_test_is_skipped(fn):
    """Decorate a ``_check_return_codes`` override so skipped tests keep default handling.

    If the first process exited with the exit code of any ``TEST_SKIPS`` entry,
    the call is forwarded to ``MultiProcessTestCase._check_return_codes``;
    otherwise the decorated ``fn`` runs.

    Args:
        fn: The method to wrap; called as ``fn(self, *args, **kwargs)``.

    Returns:
        The wrapping function, carrying ``fn``'s name and docstring.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        # The exit code does not change while scanning the skip table.
        first_exitcode = self.processes[0].exitcode
        if any(first_exitcode == skip.exit_code for skip in TEST_SKIPS.values()):
            return MultiProcessTestCase._check_return_codes(self, *args, **kwargs)
        return fn(self, *args, **kwargs)

    return wrapper
4156

4157

4158
class NCCLTraceTestDumpOnTimeoutBase(NCCLTraceTestBase):
    """Shared setup for tests that expect a trace dump after a collective timeout."""

    # Collective timeout, in seconds, passed to init_process_group below.
    timeout_sec = 1

    def _create_process_group_nccl(self):
        """Initialize the default NCCL process group with the short timeout and return it."""
        store = dist.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
            timeout=timedelta(seconds=NCCLTraceTestDumpOnTimeoutBase.timeout_sec),
        )
        pg = c10d.distributed_c10d._get_default_group()
        return pg

    @check_if_test_is_skipped
    def _check_return_codes(self, elapsed_time):
        # the base test infra assumes processes exit with matching return codes,
        # but we want rank0 to abort and rank1 to exit cleanly in this test
        self.assertEqual(self.processes[0].exitcode, -6)
        self.assertEqual(self.processes[1].exitcode, 0)

    def _wait_process(self, rank, timeout):
        """Join the process of ``rank`` for at most ``timeout`` seconds.

        Returns:
            The process exit code, or ``None`` if it is still running after
            the timeout. ``Process.join`` does not raise on timeout, so no
            exception handling is needed here.
        """
        process = self.processes[rank]
        process.join(timeout)
        return process.exitcode
4186

4187

4188
@skip_but_pass_in_sandcastle
class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
    """Checks that rank 0 writes its trace to disk before aborting on a timeout."""

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @parametrize("timing_enabled", [True, False])
    def test_timeout_dumps(self, timing_enabled):
        """Rank 0 issues an all-reduce nobody joins; its dump must list both collectives."""
        # dump on heartbeatmonitor thread
        os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"
        # need rank0 to crash before looking for its output file
        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "1"

        if self.rank != self.MAIN_PROCESS_RANK:
            group = self._create_process_group_nccl()
            if timing_enabled:
                # we force disabled timing in setup, since there is no 'disable' function
                group._enable_collectives_timing()

            dev = self.local_device
            with torch.cuda.device(dev):
                payload = torch.full((3, 4), float(self.rank), device=dev)

                group.allreduce(payload).wait()
                if self.rank == 0:
                    group.allreduce(payload).wait()

                # rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
                torch.cuda.synchronize(device=dev)
            return

        # wait for rank0 to crash before looking for its output file
        # we rely on rank0 holding off its abort long enough to dump the debug info
        self.assertEqual(self._wait_process(0, timeout=90), -6)
        with open(self._trace_name(rank=0), "rb") as dump_file:
            dump = pickle.load(dump_file)
            entries = dump["entries"]
            self.assertEqual(len(entries), 2)
            self.assertEqual(entries[0]["collective_seq_id"], 1)
            self.assertEqual(entries[0]["state"], "completed")
            self.assertEqual(entries[1]["collective_seq_id"], 2)
            self.assertEqual(
                entries[1]["state"], self.started_or_scheduled(timing_enabled)
            )

        # rank 1 did not time out, so no dump file is expected for it
        self.assertFalse(os.path.exists(self._trace_name(rank=1)))

        return
4233

4234

4235
# Instantiate the parametrized tests of each class, in the original order.
for _parametrized_cls in (
    ProcessGroupNCCLGroupTest,
    NCCLTraceTestDumpOnTimeout,
    NCCLTraceTest,
):
    instantiate_parametrized_tests(_parametrized_cls)
# Remove the loop variable so no extra module-level alias of a test class is left behind.
del _parametrized_cls
4238

4239

4240
@skip_but_pass_in_sandcastle
class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
    """Checks that both the timed-out rank and the idle rank dump their traces."""

    @check_if_test_is_skipped
    def _check_return_codes(self, elapsed_time):
        # the base test infra assumes processes exit with matching return codes
        # and would treat them as failures; in this test both rank0 and rank1
        # are expected to abort (exit code -6)
        self.assertEqual(self.processes[0].exitcode, -6)
        self.assertEqual(self.processes[1].exitcode, -6)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_timeout_dumps_on_stuck_ranks(self):
        """Rank 0 times out on an unmatched all-reduce while rank 1 idles; both must dump."""
        # need rank0 to crash quicker after detecting timeout
        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "1"
        # restore this env var to its prior default in case another test changed it
        os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"

        if self.rank == self.MAIN_PROCESS_RANK:
            # wait for both rank0 and 1 to crash before looking for both ranks' output
            # file, and we rely on rank1 to sleep long enough to dump the debug info.
            self.assertEqual(self._wait_process(0, timeout=90), -6)
            self.assertEqual(self._wait_process(1, timeout=90), -6)
            self.assertTrue(os.path.exists(self._trace_name(rank=1)))
            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
            # rank 0 recorded the shared all-reduce plus the one it issued alone
            with open(self._trace_name(rank=0), "rb") as f:
                t = pickle.load(f)
                t = t["entries"]
                self.assertEqual(len(t), 2)
            # rank 1 only took part in the first, completed all-reduce
            with open(self._trace_name(rank=1), "rb") as f:
                t = pickle.load(f)
                t = t["entries"]
                self.assertEqual(len(t), 1)
                self.assertEqual(t[0]["collective_seq_id"], 1)
                self.assertEqual(t[0]["state"], "completed")
            return

        pg = self._create_process_group_nccl()
        device = self.local_device
        with torch.cuda.device(device):
            a = torch.full((3, 4), float(self.rank), device=device)

            pg.allreduce(a).wait()
            if self.rank == 0:
                pg.allreduce(a).wait()

            # rank 0 will get stuck, timeout and then signal a timeout to all ranks.
            torch.cuda.synchronize(device=device)

            if self.rank == 1:
                # Force rank 1 to idle so that it will eventually timeout as well after
                # getting the global signal to dump the debugging info.
                time.sleep(600)
4292

4293

4294
# NOTE(review): the decorator is applied bare, without a reason argument --
# confirm against the helper's signature that this is the intended usage.
@skip_but_pass_in_sandcastle
class NcclErrorDumpTest(NCCLTraceTestBase):
    """Trace-dump test for a failing NCCL collective.

    After a first allreduce, rank 1 exits with status 1 while rank 0 issues a
    second allreduce and expects ``dist.DistBackendError``. The main process
    then checks both exit codes and that a trace file exists for rank 0.
    """

    def _wait_process(self, rank, timeout):
        """Join the child process of ``rank`` and return its exit code.

        Args:
            rank: index into ``self.processes``.
            timeout: value forwarded to ``join``.

        Returns:
            The process ``exitcode``, or ``None`` if joining raised
            ``TimeoutError``.
        """
        child = self.processes[rank]
        try:
            child.join(timeout)
        except TimeoutError:
            return None
        return child.exitcode

    @check_if_test_is_skipped
    def _check_return_codes(self, elapsed_time):
        """Assert the per-rank exit codes this test expects."""
        # The base test infra assumes all processes exit with matching return
        # codes, but here rank0 should abort with an exception (-6) and rank1
        # should exit with status 1.
        for rank, expected_code in ((0, -6), (1, 1)):
            self.assertEqual(self.processes[rank].exitcode, expected_code)

    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm
    def test_nccl_errors_dump(self):
        """Check that rank 0 leaves a trace file after its collective errors."""
        os.environ.update(
            {
                "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
                "TORCH_NCCL_TRACE_BUFFER_SIZE": "1000",
                "TORCH_NCCL_DUMP_ON_TIMEOUT": "1",
                # rank0 needs to dump before it aborts
                "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC": "5",
            }
        )

        if self.rank == self.MAIN_PROCESS_RANK:
            # Wait for both rank0 and rank1 to go down before looking for the dump.
            for rank, expected_code in ((0, -6), (1, 1)):
                self.assertEqual(self._wait_process(rank, timeout=90), expected_code)
            # The trace file for rank0 must have been written.
            self.assertTrue(os.path.exists(self._trace_name(rank=0)))
            return

        # The store stays bound to its own local so it remains referenced
        # independently of the process group object.
        file_store = c10d.FileStore(self.file_name, self.world_size)
        pg = c10d.ProcessGroupNCCL(
            file_store,
            self.rank,
            self.world_size,
            timeout=timedelta(seconds=10),
        )
        pg.allreduce(torch.rand(10).cuda(self.rank))

        if self.rank == 1:
            # Clean up structures (ex: files for FileStore) before going down.
            del pg
            sys.exit(1)
        elif self.rank == 0:
            pending = pg.allreduce(torch.rand(10).cuda(self.rank))
            # An error is expected to surface from the block below.
            with self.assertRaisesRegex(dist.DistBackendError, ""):
                # Block the current stream on the NCCL stream.
                pending.wait()
                # Run some GPU operations.
                scratch = torch.rand(10).cuda(self.rank)
4349

4350

4351
# tests that need to be run with a larger world size
4352
class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
    """NCCL process-group split tests that use a world size of 8."""

    def _create_process_group_nccl(self, store, opts, device_id=None):
        """Initialise the default "nccl" process group and return it.

        Args:
            store: store object handed to ``init_process_group``.
            opts: value passed as ``pg_options``.
            device_id: optional device forwarded to ``init_process_group``.
        """
        c10d.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
            pg_options=opts,
            device_id=device_id,
        )
        return c10d.distributed_c10d._get_default_group()

    def opts(self, high_priority_stream=False):
        """Build ``ProcessGroupNCCL.Options`` with the given stream priority flag."""
        nccl_options = c10d.ProcessGroupNCCL.Options()
        nccl_options.is_high_priority_stream = high_priority_stream
        return nccl_options

    def setUp(self):
        """Set the async-error-handling env var and spawn the worker processes."""
        super().setUp()
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING, so
        # tests relying on TORCH_NCCL_BLOCKING_WAIT still behave as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
        self._spawn_processes()

    def tearDown(self):
        """Run the base teardown, then remove the store file if it is still there."""
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            # Best effort: failure to remove the file is ignored.
            pass

    @property
    def world_size(self):
        """Number of ranks used by the tests in this class."""
        return 8

    @property
    def rank_to_GPU(self):
        """Rank-to-GPU map produced by ``init_multigpu_helper``."""
        return init_multigpu_helper(self.world_size, "nccl")

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_if_lt_x_gpu(8)
    def test_comm_split_group_larger_scale(self):
        """Split the default group into uneven subgroups and broadcast in each."""
        store = c10d.FileStore(self.file_name, self.world_size)
        cuda_dev = torch.device(f"cuda:{self.rank}")
        default_pg = self._create_process_group_nccl(
            store, self.opts(), device_id=cuda_dev
        )
        default_backend = default_pg._get_backend(cuda_dev)

        payload = torch.full((1,), self.rank).cuda(cuda_dev)
        first_split = c10d.split_group(default_pg, [[0, 1], [2, 3, 4, 5, 6, 7]])
        # The subgroup backend is looked up but not otherwise used in this test.
        first_split._get_backend(cuda_dev)

        # comm split happens eagerly since device_id is passed to init_process_group.
        self.assertEqual(default_backend.comm_split_count(), 1)
        # dist.broadcast takes the source rank on the global process group:
        # ranks 0-1 broadcast from 0, ranks 2-7 broadcast from 2.
        src_rank = 0 if self.rank < 2 else 2
        dist.broadcast(payload, src_rank, group=first_split)
        self.assertEqual(payload, torch.full((1,), src_rank))

        # Split with only one colored group; the other ranks get a no-color split.
        second_split = c10d.split_group(default_pg, [[5, 6, 7]])
        self.assertEqual(default_backend.comm_split_count(), 2)

        if self.rank < 5:
            self.assertEqual(second_split, None)
        else:
            payload2 = torch.full((1,), self.rank).cuda(cuda_dev)
            dist.broadcast(payload2, 7, group=second_split)
            self.assertEqual(payload2, torch.full((1,), 7))

        # a barrier and a cuda sync before destroying all pgs.
        dist.barrier(default_pg)
        torch.cuda.synchronize()
        dist.destroy_process_group()

    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
    @skip_if_lt_x_gpu(8)
    def test_comm_recursive_split_group(self):
        """Split the default group in two, then split each half again."""
        store = c10d.FileStore(self.file_name, self.world_size)
        cuda_dev = torch.device(f"cuda:{self.rank}")
        default_pg = self._create_process_group_nccl(
            store, self.opts(), device_id=cuda_dev
        )
        default_backend = default_pg._get_backend(cuda_dev)

        # First level: two subgroups of 4 ranks each.
        payload1 = torch.full((1,), self.rank).cuda(cuda_dev)
        half_pg = c10d.split_group(default_pg, [[0, 1, 2, 3], [4, 5, 6, 7]])
        half_backend = half_pg._get_backend(cuda_dev)
        first_src = 0 if self.rank < 4 else 4
        dist.broadcast(payload1, first_src, group=half_pg)
        self.assertEqual(payload1, torch.full((1,), first_src))

        # comm split happens eagerly since device_id is passed to init_process_group.
        self.assertEqual(default_backend.comm_split_count(), 1)
        self.assertEqual(half_backend.comm_split_count(), 0)

        # Second level: each 4-rank subgroup is split into two 2-rank subgroups.
        payload2 = torch.full((1,), self.rank).cuda(cuda_dev)
        pair_pg = c10d.split_group(half_pg, [[0, 1], [2, 3]])
        pair_backend = pair_pg._get_backend(cuda_dev)
        self.assertEqual(default_backend.comm_split_count(), 1)
        self.assertEqual(half_backend.comm_split_count(), 1)
        self.assertEqual(pair_backend.comm_split_count(), 0)

        # Run a broadcast inside each 2-rank pg; each entry is
        # (member ranks, source rank).
        pair_sources = (
            ((0, 1), 1),
            ((2, 3), 2),
            ((4, 5), 5),
            ((6, 7), 6),
        )
        for members, pair_src in pair_sources:
            if self.rank in members:
                dist.broadcast(payload2, pair_src, group=pair_pg)
                self.assertEqual(payload2, torch.full((1,), pair_src))

        # a barrier and a cuda sync before destroying all pgs.
        dist.barrier(default_pg)
        torch.cuda.synchronize()
        dist.destroy_process_group()
4483

4484

4485
if __name__ == "__main__":
    # Refuse to start if a CUDA context already exists in this (main) process.
    assert not torch.cuda._initialized, (
        "test_distributed must not have initialized CUDA context on main process"
    )

    run_tests()
4491

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.