# Owner(s): ["oncall: distributed"]

import copy
import logging
import math
import operator
import os
import random
import sys
import tempfile
from datetime import timedelta
from functools import reduce
from itertools import groupby

import torch
import torch.distributed as c10d

if not c10d.is_available() or not c10d.is_gloo_available():
    print("c10d GLOO not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common
import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from test_c10d_common import (
    gpus_for_rank,
    LOOPBACK,
    ModuleForDdpCommHook,
    SparseGradientModule,
    Task,
)
from torch import nn
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
    ShardMetadata,
)
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
    create_device,
    MultiProcessTestCase,
    requires_gloo,
    simple_sparse_reduce_tests,
    skip_if_lt_x_gpu,
    skip_if_win32,
    verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle,
    TestCase,
)


def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size)]),
        ),
    ]

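    # For example, with world_size = 4 each rank contributes rank + 1.0, so the
    # SUM case expects 10.0 (= 4 * 5 / 2) and PRODUCT expects 24.0 (= 4!).
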
    # Generate tests for BAND.
    # The bit that is set changes in every iteration to check
    # that the output changes accordingly.
    for i in range(4):
        vin = rank | (1 << i)
        vout = 1 << i
        tests.append(
            (
                c10d.ReduceOp.BAND,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

    # Generate tests for BOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-OR'ed.
    for i in range(1, 5):
        vin = reduce(operator.or_, [rank * i + j for j in range(i)])
        vout = reduce(operator.or_, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

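    # For instance, with world_size = 2 and i = 2, rank r contributes
    # (2r) | (2r + 1), and the expected BOR result is 0 | 1 | 2 | 3 = 3.
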
    # Generate tests for XOR.
    # These emulate a larger world size per iteration by having every
    # rank contribute multiple values that are pre-XOR'ed.
    for i in range(1, 5):
        vin = reduce(operator.xor, [rank * i + j for j in range(i)])
        vout = reduce(operator.xor, range(world_size * i))
        tests.append(
            (
                c10d.ReduceOp.BXOR,
                torch.tensor([vin], dtype=torch.int32),
                torch.tensor([vout], dtype=torch.int32),
            ),
        )

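    # Every value in range(world_size * i) is contributed exactly once across
    # ranks, so the expected BXOR result is the XOR of that entire range.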
    return tests


def simple_coalesced_reduce_tests(rank, world_size):
    return [
        (
            c10d.ReduceOp.SUM,
            [torch.tensor([rank + 1.0]), torch.tensor([(rank + 1.0) ** 2])],
            [
                torch.tensor([float(world_size * (world_size + 1) / 2)]),
                torch.tensor(
                    [float(world_size * (world_size + 1) * (2 * world_size + 1) / 6)]
                ),
            ],
        ),
        (
            c10d.ReduceOp.PRODUCT,
            [torch.tensor([rank + 1.0]), torch.tensor([rank + 2.0])],
            [
                torch.tensor([float(math.factorial(world_size))]),
                torch.tensor([float(math.factorial(world_size + 1))]),
            ],
        ),
        (
            c10d.ReduceOp.MIN,
            [torch.tensor([rank + x]) for x in [0.0, 1.0]],
            [torch.tensor([0.0]), torch.tensor([1.0])],
        ),
        (
            c10d.ReduceOp.MAX,
            [torch.tensor([rank + x]) for x in [1.0, 2.0]],
            [torch.tensor([float(world_size)]), torch.tensor([world_size + 1.0])],
        ),
    ]

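# Note: the coalesced SUM expectations above use the closed forms
# 1 + 2 + ... + n = n(n + 1) / 2 and 1^2 + ... + n^2 = n(n + 1)(2n + 1) / 6,
# with n = world_size.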

def simple_multi_input_reduce_tests(rank, world_size):
    return [
        (
            c10d.ReduceOp.SUM,
            [torch.tensor([2 * rank + 0.0]), torch.tensor([2 * rank + 1.0])],
            torch.tensor([float(world_size * (2 * world_size - 1))]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
            torch.tensor([float(math.factorial(2 * world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            [torch.tensor([2 * rank + 1.0]), torch.tensor([2 * rank + 2.0])],
            torch.tensor([2.0 * world_size]),
        ),
    ]

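# Note: in the SUM case above, the 2 * world_size inputs cover the values
# 0 .. 2 * world_size - 1, whose total is world_size * (2 * world_size - 1).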

class RendezvousEnvTest(TestCase):
    @requires_gloo()
    @retry_on_connect_failures
    def test_logging_init(self):
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(common.find_free_port())
        os.environ["RANK"] = "0"

        previous_handlers = logging.root.handlers

        c10d.init_process_group(backend="gloo", init_method="env://")

        current_handlers = logging.root.handlers
        self.assertEqual(len(previous_handlers), len(current_handlers))
        for current, previous in zip(current_handlers, previous_handlers):
            self.assertEqual(current, previous)

        c10d.destroy_process_group()


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_gloo()
    @retry_on_connect_failures
    def test_default_store_timeout_gloo(self):
        self._test_default_store_timeout("gloo")


class ProcessGroupGlooTest(MultiProcessTestCase):
    def _create_process_group_gloo(self, store, rank, world_size, opts):
        pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts)
        dist.barrier(group=pg)
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def opts(self, threads=2):
        opts = c10d.ProcessGroupGloo._Options()
        opts._timeout = 50.0
        opts._devices = [create_device(interface=LOOPBACK)]
        opts._threads = threads
        return opts

    @requires_gloo()
    def test_multi_device_constructor(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        opts = c10d.ProcessGroupGloo._Options()
        opts._timeout = 5.0
        opts._devices = [
            create_device(interface=LOOPBACK),
            create_device(interface=LOOPBACK),
        ]
        pg = self._create_process_group_gloo(store, self.rank, self.world_size, opts)

        # Execute 2x the number of operations to ensure we use every device.
        for fut in [pg.allreduce(torch.ones(i + 1)).get_future() for i in range(4)]:
            fut.wait()

    @requires_gloo()
    def test_empty_tensors(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        xs = [torch.FloatTensor([])]
        fut = pg.broadcast(xs).get_future()
        fut.wait()
        output = fut.value()
        self.assertEqual(0, output[0].numel())
        self.assertEqual(xs[0], output[0])

    @requires_gloo()
    def test_broadcast_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = -1
            opts.rootTensor = 0
            pg.broadcast([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.world_size
            opts.rootTensor = 0
            pg.broadcast([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.rank
            opts.rootTensor = -1
            pg.broadcast([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 1
            pg.broadcast([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 0
            pg.broadcast([], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 0
            pg.broadcast([t1, t2], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            opts = c10d.BroadcastOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 0
            pg.broadcast([t1, t3], opts)

    def _test_broadcast_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        def broadcast(xs, rootRank, rootTensor):
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            fut = pg.broadcast(xs, opts).get_future()
            fut.wait()
            return fut.value()

        # Every rank is root once
        for i in range(self.world_size):
            # Run with 1 input tensor
            x = fn(torch.tensor([self.rank]))
            output = broadcast([x], i, 0)
            self.assertEqual(torch.tensor([i]), output[0])

            # Run with 2 input tensors
            num = 2
            for j in range(num):
                xs = [
                    fn(torch.tensor([self.rank * num + 0.0])),
                    fn(torch.tensor([self.rank * num + 1.0])),
                ]

                output = broadcast(xs, i, j)
                self.assertEqual(
                    torch.tensor([i * num + j], dtype=torch.float32), output[0]
                )
                self.assertEqual(
                    torch.tensor([i * num + j], dtype=torch.float32), output[1]
                )

        # Test overloaded convenience function
        x = torch.tensor([self.rank + 1.0])
        fut = pg.broadcast(x, root=0).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(torch.tensor([1.0]), result[0])

    @requires_gloo()
    def test_broadcast_basics(self):
        self._test_broadcast_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_broadcast_basics_cuda(self):
        self._test_broadcast_basics(lambda t: t.clone().cuda())

    def _test_broadcast_stress(self, inputs):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        work_handles = [
            pg.broadcast(inputs[i], root=(i % self.world_size))
            for i in range(len(inputs))
        ]
        for i, work_handle in enumerate(work_handles):
            work_handle.wait()
            self.assertEqual(
                torch.tensor([(i * self.world_size) + (i % self.world_size)]),
                inputs[i],
                msg=("Mismatch in iteration %d" % i),
            )

    @requires_gloo()
    def test_broadcast_stress(self):
        inputs = [torch.tensor([i * self.world_size + self.rank]) for i in range(1000)]
        self._test_broadcast_stress(inputs)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_broadcast_stress_cuda(self):
        inputs = [
            torch.tensor([i * self.world_size + self.rank]).cuda() for i in range(1000)
        ]
        self._test_broadcast_stress(inputs)

    @requires_gloo()
    def test_allreduce_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([t1, t2], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([t1, t3], opts)

    def _test_allreduce_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for (op, input, expected) in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input)
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0])

        # Multi input tests
        tests = simple_multi_input_reduce_tests(self.rank, self.world_size)
        for (op, inputs, output) in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensors = [fn(input) for input in inputs]
            fut = pg.allreduce(tensors, opts).get_future()
            fut.wait()
            result = fut.value()
            for tensor in result:
                self.assertEqual(output, tensor)

        # Test overloaded convenience function (defaults to using sum)
        x = fn(torch.tensor([self.rank + 1.0]))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_gloo()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allreduce_basics_cuda(self):
        self._test_allreduce_basics(lambda t: t.clone().cuda())

    def _test_allreduce_stress(self, inputs):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = [
            pg.allreduce(inputs[i]).get_future() for i in range(len(inputs))
        ]
        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            self.assertEqual(
                torch.tensor(
                    [
                        (i * self.world_size)
                        + (self.world_size * (self.world_size - 1) // 2)
                    ]
                ),
                future_handle.value()[0],
                msg=("Mismatch in iteration %d" % i),
            )

    @requires_gloo()
    def test_allreduce_stress(self):
        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
        self._test_allreduce_stress(inputs)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allreduce_stress_cuda(self):
        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
        self._test_allreduce_stress(inputs)

    @requires_gloo()
    def test_allreduce_coalesced_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros(1, dtype=torch.float32)
        t2 = torch.zeros(1, dtype=torch.float64)
        t3 = torch.sparse_coo_tensor([[0]], [1], size=(1,))

        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
            opts = c10d.AllreduceCoalescedOptions()
            pg.allreduce_coalesced([], opts)

        with self.assertRaisesRegex(RuntimeError, "tensors must all have the same type"):
            opts = c10d.AllreduceCoalescedOptions()
            pg.allreduce_coalesced([t1, t2], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor layout at index"):
            opts = c10d.AllreduceCoalescedOptions()
            pg.allreduce_coalesced([t1, t3], opts)

        with self.assertRaisesRegex(RuntimeError, "unsupported layout"):
            opts = c10d.AllreduceCoalescedOptions()
            pg.allreduce_coalesced([t3, t3.clone()], opts)

    @skip_if_lt_x_gpu(1)
    @requires_gloo()
    def test_allreduce_coalesced_checks_cuda(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros(1, dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "unsupported device type"):
            opts = c10d.AllreduceCoalescedOptions()
            pg.allreduce_coalesced([t1.cuda(), t1.cuda()], opts)

    def _test_allreduce_coalesced_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        test_cases = simple_coalesced_reduce_tests(self.rank, self.world_size)
        for op, inputs, outputs in test_cases:
            opts = c10d.AllreduceCoalescedOptions()
            opts.reduceOp = op
            tensors = [fn(x) for x in inputs]
            fut = pg.allreduce_coalesced(tensors, opts).get_future()
            fut.wait()
            result = fut.value()
            for result_tensor, expected in zip(result, outputs):
                self.assertEqual(result_tensor, expected)

    @requires_gloo()
    def test_allreduce_coalesced_basics(self):
        self._test_allreduce_coalesced_basics(lambda t: t.clone())

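    # The stress-test inputs below hold i + rank, so a SUM allreduce across
    # ranks yields i * world_size + world_size * (world_size - 1) / 2; this is
    # what _expected_output computes.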
    def _expected_output(self, i):
        ws = self.world_size
        return 2 * [torch.tensor([(i * ws) + (ws * (ws - 1) // 2)])]

    def _test_allreduce_coalesced_stress(self, inputs):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = [
            pg.allreduce_coalesced(input).get_future() for input in inputs
        ]
        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            result = future_handle.value()
            self.assertEqual(
                self._expected_output(i),
                result,
                msg=f"Mismatch in iteration {i}",
            )

    @requires_gloo()
    def test_allreduce_coalesced_stress(self):
        inputs = [2 * [torch.tensor([i + self.rank])] for i in range(1000)]
        self._test_allreduce_coalesced_stress(inputs)

    @requires_gloo()
    def test_allreduce_coalesced_async(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )

        xs = [2 * [torch.tensor([i + self.rank])] for i in range(2)]
        futs = [c10d.all_reduce_coalesced(x, async_op=True) for x in xs]
        torch.futures.wait_all(futs)
        for i, fut in enumerate(futs):
            self.assertEqual(
                self._expected_output(i),
                fut.wait(),
                msg=f"Mismatch in iteration {i}",
            )

    @requires_gloo()
    def test_sparse_allreduce_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1])
        t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,))
        t3 = torch.sparse_coo_tensor([[0]], [1], size=(4,))

        with self.assertRaisesRegex(RuntimeError, "requires non-empty tensor list"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor layout"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([t1, t2], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            opts = c10d.AllreduceOptions()
            pg.allreduce([t2, t3], opts)

        # Sparse allreduce only works with c10d.ReduceOp.SUM.
        for op in [c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX]:
            with self.assertRaisesRegex(
                RuntimeError, "unsupported reduction operation"
            ):
                opts = c10d.AllreduceOptions()
                opts.reduceOp = op
                pg.allreduce([t3], opts)

    def _test_sparse_allreduce_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        for num_inputs_per_rank in [1, 2]:
            tests = simple_sparse_reduce_tests(
                self.rank, self.world_size, num_inputs=num_inputs_per_rank
            )
            for (inputs, outputs) in tests:
                tensors = [fn(input) for input in inputs]
                fut = pg.allreduce(tensors).get_future()
                fut.wait()
                result = fut.value()
                self.assertEqual(tensors, outputs)
                self.assertEqual(result, outputs)

    @requires_gloo()
    def test_sparse_allreduce_basics(self):
        self._test_sparse_allreduce_basics(lambda t: t)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sparse_allreduce_basics_cuda(self):
        self._test_sparse_allreduce_basics(lambda t: t.clone().cuda())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sparse_allreduce_cuda_dispatched(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        tests = simple_sparse_reduce_tests(self.rank, self.world_size, num_inputs=1)
        for (inputs, outputs) in tests:
            tensors = inputs[-1].clone().cuda()
            work = dist.all_reduce(tensors, async_op=True)
            work.wait()
            self.assertEqual([tensors], outputs)

    @requires_gloo()
    def test_allgather_into_tensor_coalesced(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        torch.manual_seed(42)
        in_shapes = [(5, 5), (10, 10), (15, 15)]
        out_shapes = [(s[0] * self.world_size,) + s[1:] for s in in_shapes]

        outputs = [torch.empty(s) for s in out_shapes]
        inputs = [torch.rand(s) for s in in_shapes]
        work = dist.group.WORLD.allgather_into_tensor_coalesced(outputs, inputs)
        work.wait()

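        # Every rank seeds its generator identically, so all ranks draw the
        # same inputs and each gathered output is the local input repeated
        # world_size times along dim 0.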
        for output, input in zip(outputs, inputs):
            expect = torch.cat([input] * self.world_size)
            self.assertTrue(torch.allclose(output, expect))

    @requires_gloo()
    def test_reduce_scatter_tensor(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        torch.manual_seed(42)
        out_shape = (20, 20)
        in_shape = (out_shape[0] * self.world_size,) + out_shape[1:]

        output = torch.empty(out_shape)
        input = torch.rand(in_shape)
        work = dist.reduce_scatter_tensor(output, input, async_op=True)
        work.wait()

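        # With identical inputs on every rank (same seed), the SUM
        # reduce-scatter returns this rank's chunk of the input scaled by
        # world_size.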
        expect = (
            input.view(self.world_size, *out_shape).chunk(self.world_size)[self.rank]
            * self.world_size
        )
        self.assertTrue(torch.allclose(output, expect))

    @requires_gloo()
    def test_reduce_scatter_tensor_coalesced(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            store=store,
            rank=self.rank,
            world_size=self.world_size,
        )
        torch.manual_seed(42)
        out_shapes = [(5, 5), (10, 10), (15, 15)]
        in_shapes = [(s[0] * self.world_size,) + s[1:] for s in out_shapes]

        outputs = [torch.empty(s) for s in out_shapes]
        inputs = [torch.rand(s) for s in in_shapes]
        work = dist.group.WORLD.reduce_scatter_tensor_coalesced(outputs, inputs)
        work.wait()

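        # Same reasoning as above: the seeded inputs are identical across
        # ranks, so each output equals this rank's input chunk times
        # world_size.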
        for output, input in zip(outputs, inputs):
            chunks = input.view(self.world_size, *output.shape).chunk(self.world_size)
            expect = chunks[self.rank] * self.world_size
            self.assertTrue(torch.allclose(output, expect))

    @requires_gloo()
    def test_scatter_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = -1
            pg.scatter([t1], [], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.world_size
            pg.scatter([t1], [], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element output tensor list"
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = 0
            pg.scatter([], [], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element output tensor list"
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = 0
            pg.scatter([t1, t1], [], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element input list"
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element input list"
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [[t1] * self.world_size, [t1] * self.world_size], opts)

        desired_list_size = self.world_size
        incorrect_list_size = self.world_size - 1
        err_str = "Incorrect input list size {}. Input list size should be {}"
        with self.assertRaisesRegex(
            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [[t1] * incorrect_list_size], opts)

        incorrect_list_size = self.world_size + 1
        with self.assertRaisesRegex(
            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
        ):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [[t1] * incorrect_list_size], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [[t2] * self.world_size], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            opts = c10d.ScatterOptions()
            opts.rootRank = self.rank
            pg.scatter([t1], [[t3] * self.world_size], opts)

        with self.assertRaisesRegex(RuntimeError, "requires empty input on non-root"):
            opts = c10d.ScatterOptions()
            opts.rootRank = (self.rank + 1) % self.world_size
            pg.scatter([t1], [[t1] * self.world_size], opts)

    def _test_scatter_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Preallocate tensors for input/output
        input = [fn(torch.tensor([self.rank])) for _ in range(self.world_size)]
        outputs = [fn(torch.tensor([-1])) for _ in range(self.world_size)]

        # Take turns being the scatter root and accumulate work items
        futures = []
        for i in range(self.world_size):
            opts = c10d.ScatterOptions()
            opts.rootRank = i
            if i == self.rank:
                futures.append(pg.scatter([outputs[i]], [input], opts).get_future())
            else:
                futures.append(pg.scatter([outputs[i]], [], opts).get_future())

        # Wait for work to complete
        for i in range(self.world_size):
            futures[i].wait()
            result = futures[i].value()
            self.assertEqual(torch.tensor([i]), result[0])

    @requires_gloo()
    def test_scatter_basics(self):
        self._test_scatter_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_scatter_basics_cuda(self):
        self._test_scatter_basics(lambda t: t.clone().cuda())

    def _test_scatter_stress(self, inputs, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        outputs = [
            [fn(torch.tensor([-1])) for _ in range(self.world_size)]
            for _ in range(len(inputs))
        ]
        future_handles = []
        for i in range(len(inputs)):
            for root in range(self.world_size):
                opts = c10d.ScatterOptions()
                opts.rootRank = root
                if root == self.rank:
                    fut = pg.scatter(
                        [outputs[i][root]], [[fn(e) for e in inputs[i]]], opts
                    ).get_future()
                else:
                    fut = pg.scatter([outputs[i][root]], [], opts).get_future()
                future_handles.append(fut)

        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            iter = i // self.world_size
            root = i % self.world_size
            result = future_handle.value()

            self.assertEqual(
                torch.tensor([iter + root]),
                result[0],
                msg=("Mismatch in iteration %d for rank %d" % (iter, root)),
            )

    @requires_gloo()
    def test_set_gloo_pg_timeout(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        pg.allreduce(torch.rand(10))
        self.assertEqual(pg.options._timeout, timedelta(seconds=50))
        pg._set_default_timeout(timedelta(seconds=23))
        self.assertEqual(pg.options._timeout, timedelta(seconds=23))

    @requires_gloo()
    def test_scatter_stress(self):
        inputs = [
            [torch.tensor([i + self.rank]) for _ in range(self.world_size)]
            for i in range(1000)
        ]
        self._test_scatter_stress(inputs, lambda t: t.clone())

    @skip_but_pass_in_sandcastle(
        "Test is flaky, see https://github.com/pytorch/pytorch/issues/15963"
    )
    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_scatter_stress_cuda(self):
        inputs = [
            [torch.tensor([i + self.rank]) for _ in range(self.world_size)]
            for i in range(1000)
        ]
        self._test_scatter_stress(inputs, lambda t: t.clone().cuda())

    @requires_gloo()
    def test_gather_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = -1
            pg.gather([], [t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.GatherOptions()
            opts.rootRank = self.world_size
            pg.gather([], [t1], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element input tensor list"
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = 0
            pg.gather([], [], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element input tensor list"
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = 0
            pg.gather([], [t1, t1], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element output list"
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([], [t1], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element output list"
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([[t1] * self.world_size, [t1] * self.world_size], [t1], opts)

        desired_list_size = self.world_size
        incorrect_list_size = self.world_size - 1
        err_str = "Incorrect output list size {}. Output list size should be {}"
        with self.assertRaisesRegex(
            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([[t1] * incorrect_list_size], [t1], opts)

        incorrect_list_size = self.world_size + 1
        with self.assertRaisesRegex(
            RuntimeError, err_str.format(incorrect_list_size, desired_list_size)
        ):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([[t1] * incorrect_list_size], [t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([[t2] * self.world_size], [t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            opts = c10d.GatherOptions()
            opts.rootRank = self.rank
            pg.gather([[t3] * self.world_size], [t1], opts)

        with self.assertRaisesRegex(RuntimeError, "requires empty output on non-root"):
            opts = c10d.GatherOptions()
            opts.rootRank = (self.rank + 1) % self.world_size
            pg.gather([[t1] * self.world_size], [t1], opts)

    def _test_gather_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Preallocate tensors for input/output
        input = [fn(torch.tensor([self.rank]))]
        outputs = [fn(torch.tensor([-1])) for _ in range(self.world_size)]

        # Take turns being the gather root and accumulate work items
        futures = []
        for i in range(self.world_size):
            opts = c10d.GatherOptions()
            opts.rootRank = i
            if i == self.rank:
                futures.append(pg.gather([outputs], input, opts).get_future())
            else:
                futures.append(pg.gather([], input, opts).get_future())

        # Wait for work to complete
        expected = [fn(torch.tensor([rank])) for rank in range(self.world_size)]
        for i in range(self.world_size):
            futures[i].wait()
            result = futures[i].value()
            if i == self.rank:
                self.assertEqual(expected, result)

    @requires_gloo()
    def test_gather_basics(self):
        self._test_gather_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gather_basics_cuda(self):
        self._test_gather_basics(lambda t: t.clone().cuda())

    @requires_gloo()
    def test_gather_noncontiguous_input(self):
        # Take a column of a 2D tensor, so that memory is not dense
        self._test_gather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])

    def _test_gather_stress(self, inputs, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        outputs = [
            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
            for _ in range(len(inputs))
        ]
        expected_outputs = [
            [[torch.tensor([i + j]) for j in range(self.world_size)]]
            for i in range(len(inputs))
        ]
        for i in range(len(inputs)):
            for root in range(self.world_size):
                opts = c10d.GatherOptions()
                opts.rootRank = root
                if root == self.rank:
                    fut = pg.gather(outputs[i], [fn(inputs[i])], opts).get_future()
                else:
                    fut = pg.gather([], [fn(inputs[i])], opts).get_future()
                future_handles.append(fut)

        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            iter = i // self.world_size
            root = i % self.world_size
            if root == self.rank:
                result = future_handle.value()
                self.assertEqual(
                    expected_outputs[iter],
                    [result],
                    msg=("Mismatch in iteration %d for root %d" % (iter, root)),
                )

    @requires_gloo()
    def test_gather_stress(self):
        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
        self._test_gather_stress(inputs, lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gather_stress_cuda(self):
        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
        self._test_gather_stress(inputs, lambda t: t.clone().cuda())

    @requires_gloo()
    def test_allgather_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)
        t2 = torch.zeros([1], dtype=torch.float64)
        t3 = torch.zeros([2], dtype=torch.float32)

        with self.assertRaisesRegex(
            RuntimeError, "requires non-empty input tensor list"
        ):
            pg.allgather([], [])

        with self.assertRaisesRegex(
            RuntimeError, "requires input/output tensor lists to have the same length"
        ):
            pg.allgather([], [t1])

        with self.assertRaisesRegex(
            RuntimeError, "requires input/output tensor lists to have the same length"
        ):
            pg.allgather([[t1] * self.world_size, [t1] * self.world_size], [t1])

        with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"):
            pg.allgather([[t1] * (self.world_size - 1)], [t1])

        with self.assertRaisesRegex(RuntimeError, "invalid output tensor list"):
            pg.allgather([[t1] * (self.world_size + 1)], [t1])

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            pg.allgather(
                [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2]
            )

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            pg.allgather(
                [[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3]
            )

        with self.assertRaisesRegex(RuntimeError, "invalid tensor type"):
            pg.allgather([([t1, t2] * (self.world_size))[: self.world_size]], [t1])

        with self.assertRaisesRegex(RuntimeError, "invalid tensor size"):
            pg.allgather([([t1, t3] * (self.world_size))[: self.world_size]], [t1])

    def _test_allgather_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Run with N input tensors per rank
        for n in [1, 2, 3]:
            input = [fn(torch.tensor([n * self.rank + i])) for i in range(n)]
            output = [
                [fn(torch.tensor([-1])) for _ in range(n * self.world_size)]
                for _ in range(n)
            ]
            expected_output = [
                [fn(torch.tensor([i])) for i in range(n * self.world_size)]
                for _ in range(n)
            ]
            fut = pg.allgather(output, input).get_future()
            fut.wait()
            result = fut.value()
            if n == 1:
                result = [result]
            self.assertEqual(expected_output, result)

    @requires_gloo()
    def test_allgather_basics(self):
        self._test_allgather_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allgather_basics_cuda(self):
        self._test_allgather_basics(lambda t: t.clone().cuda())

    @requires_gloo()
    def test_allgather_noncontiguous_input(self):
        # Take a column of a 2D tensor, so that memory is not dense
        self._test_allgather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])

    def _test_allgather_stress(self, inputs, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        outputs = [
            [[fn(torch.tensor([-1])) for _ in range(self.world_size)]]
            for _ in range(len(inputs))
        ]
        expected_outputs = [
            [[torch.tensor([i + j]) for j in range(self.world_size)]]
            for i in range(len(inputs))
        ]
        input_holder = {}
        for i in range(len(inputs)):
            # Note that this works around the data race discussed in
            # https://github.com/pytorch/pytorch/issues/75529, but we should
            # actually be able to pass the list directly into allgather when
            # that race is fixed.
            input_holder[i] = [fn(inputs[i])]
            fut = pg.allgather(outputs[i], input_holder[i]).get_future()
            future_handles.append(fut)

        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            result = future_handle.value()
            self.assertEqual(
                expected_outputs[i],
                [result],
                msg=("Mismatch in iteration %d" % i),
            )

    @requires_gloo()
    def test_allgather_stress(self):
        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
        self._test_allgather_stress(inputs, lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_allgather_stress_cuda(self):
        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
        self._test_allgather_stress(inputs, lambda t: t.clone().cuda())

    @requires_gloo()
    def test_allgather_coalesced_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        dummy_input = [torch.zeros([1], dtype=torch.float32)]
        dummy_output_lists = [
            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
        ]

        # One of the output tensors does not match the input list.
        dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)]
        with self.assertRaisesRegex(
            RuntimeError, "invalid size of output tensor at index 0"
        ):
            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)

        # One of the output tensors does not match the input list.
        dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)]
        with self.assertRaisesRegex(RuntimeError, "invalid tensor type at index 0"):
            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)

        # Output lists have too many elements.
        dummy_output_lists = [
            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size + 1)
        ]
        with self.assertRaisesRegex(
            RuntimeError, "output lists should be equal to world size"
        ):
            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)

        # Output is not a list of lists.
        dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
        with self.assertRaisesRegex(
            TypeError, "Invalid function argument.*output_tensor_lists"
        ):
            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)

    @requires_gloo()
    def test_allgather_coalesced_async(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", rank=self.rank, world_size=self.world_size, store=store
        )

        xxs = [2 * [torch.tensor([i + self.rank])] for i in range(2)]
        yys = [
            [[torch.zeros_like(x) for x in xx] for _ in range(self.world_size)]
            for xx in xxs
        ]
        futs = [
            c10d.all_gather_coalesced(yy, xx, async_op=True) for xx, yy in zip(xxs, yys)
        ]

        # Expected outputs
        zzs = [
            [2 * [torch.tensor([i + r])] for r in range(self.world_size)]
            for i in range(2)
        ]

        torch.futures.wait_all(futs)
        for yy, zz in zip(yys, zzs):
            # one iteration
            for y_out, z_out in zip(yy, zz):
                # one output tensor list
                for y, z in zip(y_out, z_out):
                    # one tensor in output tensor list
                    self.assertEqual(y, z)

        # Added to address https://github.com/pytorch/pytorch/issues/65231
        # In the failed tests, all assertEqual calls passed on all processes.
        # However, one of the processes didn't call the ProcessGroupGloo
        # destructor before exiting the program. This is not surprising, as the
        # only guarantee Python makes is that garbage collection MAY happen
        # before the program exits. If GC didn't happen, the two threads in
        # ProcessGroup might be destructed before being joined.
        # FIXME: it's still unclear why only this test requires an explicit
        # destroy_process_group()
        c10d.destroy_process_group()

    @requires_gloo()
    def test_reduce_checks(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        t1 = torch.zeros([1], dtype=torch.float32)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.ReduceOptions()
            opts.rootRank = -1
            opts.rootTensor = 0
            pg.reduce([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
            opts = c10d.ReduceOptions()
            opts.rootRank = self.world_size
            opts.rootTensor = 0
            pg.reduce([t1], opts)

        with self.assertRaisesRegex(RuntimeError, "invalid root tensor"):
            opts = c10d.ReduceOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 1
            pg.reduce([t1], opts)

        with self.assertRaisesRegex(
            RuntimeError, "requires a single-element tensor list"
        ):
            opts = c10d.ReduceOptions()
            opts.rootRank = self.rank
            opts.rootTensor = 0
            pg.reduce([t1, t1], opts)

    def _test_reduce_basics(self, fn):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )
        for (op, input, output) in simple_reduce_tests(self.rank, self.world_size):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.reduceOp = op
                opts.rootRank = root
                tmp = fn(input)
                fut = pg.reduce([tmp], opts).get_future()
                fut.wait()
                result = fut.value()
                if root == self.rank:
                    self.assertEqual(output, result[0])

    @requires_gloo()
    def test_reduce_basics(self):
        self._test_reduce_basics(lambda t: t.clone())

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_reduce_basics_cuda(self):
        self._test_reduce_basics(lambda t: t.clone().cuda())

    def _test_reduce_stress(self, inputs):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts(threads=8)
        )
        future_handles = []
        outputs = []
        for i in range(len(inputs)):
            for root in range(self.world_size):
                opts = c10d.ReduceOptions()
                opts.rootRank = root
                tmp = inputs[i].clone()
                outputs.append(tmp)
                fut = pg.reduce([tmp], opts).get_future()
                future_handles.append(fut)

        for i, future_handle in enumerate(future_handles):
            future_handle.wait()
            result = future_handle.value()
            iter = i // self.world_size
            root = i % self.world_size
            if root == self.rank:
                self.assertEqual(
                    torch.tensor(
                        [
                            (iter * self.world_size)
                            + (self.world_size * (self.world_size - 1) // 2)
                        ]
                    ),
                    result[0],
                    msg=("Mismatch in iteration %d with root rank %d" % (iter, root)),
                )

    @requires_gloo()
    def test_reduce_stress(self):
        inputs = [torch.tensor([i + self.rank]) for i in range(1000)]
        self._test_reduce_stress(inputs)

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_reduce_stress_cuda(self):
        inputs = [torch.tensor([i + self.rank]).cuda() for i in range(1000)]
        self._test_reduce_stress(inputs)

    @requires_gloo()
    def test_send_recv_all_to_all(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Preallocate tensors for input/output
        inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
        outputs = [torch.tensor([-1]) for _ in range(self.world_size)]

        # Issue sends
        send_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            send_work.append(pg.send([inputs[i]], i, 0))

        # Issue recvs
        recv_work = []
        for i in range(self.world_size):
            if i == self.rank:
                continue
            recv_work.append(pg.recv([outputs[i]], i, 0))

        # Wait for sends to complete
        for work in send_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Wait for recvs to complete
        for work in recv_work:
            work.wait()
            self.assertTrue(work.is_completed())

        # Test that every output other than our own contains the respective rank
        for i in range(self.world_size):
            if i == self.rank:
                continue
            self.assertEqual(torch.tensor([i]), outputs[i])

    @requires_gloo()
    def test_barrier_implies_wait(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_gloo(
            store, self.rank, self.world_size, self.opts()
        )

        # Kick off allreduce operations
        size = (100, 100)
        num = 16
        tensors = [torch.full(size, float(i)) for i in range(num)]
        for tensor in tensors:
            # Note: leak the returned work handle
            pg.allreduce(tensor)

        # Barrier should ensure all previous work has completed
        pg.barrier().get_future().wait()

        for i, tensor in enumerate(tensors):
            self.assertEqual(torch.full(size, float(i * self.world_size)), tensor)

    @skip_if_win32()
    @requires_gloo()
    def test_round_robin(self):
        num_process_groups = 2
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        pg = c10d._round_robin_process_groups(
            [
                c10d.new_group(pg_options=self.opts())
                for i in range(num_process_groups)
            ]
        )

        # Run a few collectives so that we have called each process group
        for _ in range(num_process_groups + 1):
            tensor = torch.full([100, 100], float(self.rank))
            pg.broadcast(tensor, root=0).wait()
            self.assertEqual(torch.full([100, 100], 0.0), tensor)

    @skip_if_win32()
    @requires_gloo()
    def test_round_robin_create_destroy(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )

        def create(num, prefix):
            return c10d._round_robin_process_groups(
                [c10d.new_group(pg_options=self.opts()) for i in range(num)]
            )

        # Run create/use/destroy twice
        for i in range(2):
            num_process_groups = 2
            pg = create(num=num_process_groups, prefix=i)
            for _ in range(3):
                tensor = torch.ones([10, 10])
                pg.allreduce(tensor).wait()
                self.assertEqual(torch.full([10, 10], float(self.world_size)), tensor)
            del pg


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def _get_process_group(self):
        store = self._get_store()
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        return c10d.distributed_c10d._get_default_group()

    def _test_gloo_backend(
        self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
    ):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        device = devices[-1]
        backend = process_group._get_backend(device)
        backend.create_device(interface=LOOPBACK)
        self._test_ddp_with_process_group(
            process_group, devices, device_ids, multi_device, gradient_as_bucket_view
        )

    @requires_gloo()
    def test_gloo_backend_cpu_module(self):
        self._test_gloo_backend([torch.device("cpu")], None)

    @requires_gloo()
    def test_gloo_backend_cpu_module_grad_is_view(self):
        self._test_gloo_backend(
            [torch.device("cpu")], None, gradient_as_bucket_view=True
        )

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_gloo_backend_1gpu_module_device_ids_integer_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, int_devices)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_gloo_backend_1gpu_module_device_ids_torch_device_list(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:1]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, devices)

    @requires_gloo()
    @skip_if_lt_x_gpu(4)
    def test_gloo_backend_2gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:2]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, None, multi_device=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(8)
    def test_gloo_backend_4gpu_module(self):
        int_devices = gpus_for_rank(self.world_size)[self.rank][:4]
        devices = [torch.device("cuda:" + str(i)) for i in int_devices]
        self._test_gloo_backend(devices, None, multi_device=True)

    def _test_global_local_unused_params_grad(
        self, gradient_as_bucket_view=False, static_graph=False
    ):
        """
        By simulating multi-task training, this test makes sure that:
        1) DDP does not touch the grad of globally unused parameters.
        2) DDP does update the grad of locally unused parameters.
        """

        class GlobalLocalUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()
                self.task_unused = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p, self.task_unused.p)

            def forward(self, x, rank):
                return self.t0(x) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            t0_p, t1_p, task_unused_p = model.module.task_parameters()
            self.assertIsNone(t0_p.grad)
            self.assertIsNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

            # Run backward
            output.mean().backward()

            # Now the locally unused parameter should have its grad updated on
            # all ranks, while the globally unused parameter should still have
            # a None grad.
            self.assertIsNotNone(t0_p.grad)
            self.assertIsNotNone(t1_p.grad)
            self.assertIsNone(task_unused_p.grad)

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            GlobalLocalUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
            gradient_as_bucket_view=gradient_as_bucket_view,
            static_graph=static_graph,
        )
        run_and_verify_grad(gpu_model)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad(self):
        self._test_global_local_unused_params_grad()

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_grad_is_view(self):
        self._test_global_local_unused_params_grad(gradient_as_bucket_view=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_global_local_unused_params_grad_with_static_graph(self):
        self._test_global_local_unused_params_grad(static_graph=True)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_find_unused_parameters_when_unused_parameters_empty(self):
        """
        An empty unused_parameters array does not imply find_unused_parameters =
        false. This test makes sure that DDP allreduces unused parameters
        accordingly even when the forward pass in some process uses all
        parameters. This unit test creates a module that uses all parameters in
        rank = 0, and has unused parameters in other ranks.
        """

        class FindUnusedParamModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.t0 = Task()
                self.t1 = Task()

            def task_parameters(self):
                return (self.t0.p, self.t1.p)

            def forward(self, x, rank):
                return self.t1(self.t0(x)) if rank == 0 else self.t1(x)

        def run_and_verify_grad(model):
            # Run forward
            output = model(8, self.rank)

            # The grads of all parameters should be None at this point.
            [self.assertIsNone(t_p.grad) for t_p in model.module.task_parameters()]

            # Run backward
            output.mean().backward()

            # Now the locally unused parameter should have its grad updated on
            # all ranks.
            [self.assertIsNotNone(t_p.grad) for t_p in model.module.task_parameters()]

        process_group = self._get_process_group()

        # Test on CPU
        cpu_model = DistributedDataParallel(
            FindUnusedParamModule().cpu(),
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(cpu_model)

        # Test on GPU
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            FindUnusedParamModule().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            find_unused_parameters=True,
        )
        run_and_verify_grad(gpu_model)

    @requires_gloo()
    def test_ignored_output(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called.
        """
        process_group = self._get_process_group()

        class IgnoredOutput(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutput().float(),
            process_group=process_group,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_gloo()
    def test_ignored_output_with_unused_parameters(self):
        """
        Test that the output of a model can be ignored and that there is no
        implicit requirement that `backward` gets called, if not all model
        parameters participated in computing the model output.
        """
        process_group = self._get_process_group()

        class IgnoredOutputWithUnusedParameters(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.fc3 = nn.Linear(4, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        model = DistributedDataParallel(
            IgnoredOutputWithUnusedParameters().float(),
            process_group=process_group,
            find_unused_parameters=True,
        )

        batch_size = 4
        criterion = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.float)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

        # Run a few iterations where we ignore the output.
        for _ in range(4):
            output = model(input)
            del output

        # Run a few iterations where we use the output.
        for _ in range(4):
            output = model(input)
            loss = criterion(output, target)
            loss.backward()

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_ignored_sharded_tensor(self):
        class MyModule(nn.Module):
            def __init__(self, shard_tensor: ShardedTensor) -> None:
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.st = nn.Parameter(shard_tensor)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                return F.softmax(x, dim=1)

        # init_process_group returns None, so fetch the default group
        # explicitly to hand DDP a real ProcessGroup below.
        dist.init_process_group(
            "gloo",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )
        pg = dist.distributed_c10d._get_default_group()
        device = torch.device(f"cuda:{self.rank}")
        local_shard_metadata = ShardMetadata(
            shard_offsets=[(self.rank % 2) * 5, 0],
            shard_sizes=[5, 10],
            placement=f"rank:{self.rank}/cuda:{self.rank}",
        )
        local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)]
        st = init_from_local_shards(local_shards, [10, 10])
        m = MyModule(st)
        DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
            module=m,
            params_and_buffers_to_ignore={"st"},
        )
        # Test that the DDP constructor does not fail when the module includes
        # a ShardedTensor parameter, provided that parameter is ignored.
        DistributedDataParallel(
            m,
            device_ids=[device] if device.type == "cuda" else None,
            process_group=pg,
            gradient_as_bucket_view=True,
            broadcast_buffers=False,
            static_graph=True,
        )

    def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
        mult = 2
        batch_size = mult * self.world_size
        criterion = nn.CrossEntropyLoss()
        input = torch.randint(0, 10, [batch_size, 2])
        target = torch.randint(0, 10, [batch_size])

        # Run with the entire batch against the single-process version.
        criterion(vanilla_model(input), target).backward()

        # Run with a partial batch against the multi-process version.
        partial_input = input.split(mult)[self.rank]
        partial_target = target.split(mult)[self.rank]
        criterion(ddp_model(partial_input), partial_target).backward()

        # Check that the gradients are sparse and identical.
        vanilla_parameter = next(vanilla_model.parameters())
        ddp_parameter = next(ddp_model.parameters())
        self.assertEqual(
            vanilla_parameter.grad.coalesce(), ddp_parameter.grad.coalesce()
        )

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_save_load_checkpoint(self):
        dist.init_process_group(
            "gloo",
            init_method=f"file://{self.file_name}",
            world_size=self.world_size,
            rank=self.rank,
        )

        class TestModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(2, 10, bias=False)
                self.fc2 = nn.Linear(10, 4, bias=False)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.relu(self.fc1(x))
                x = self.relu(self.fc2(x))
                return F.softmax(x, dim=1)

        def train_loop(model, optimizer, iterations):
            for _ in range(iterations):
                optimizer.zero_grad()
                output = model(input)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
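
        # NOTE: `train_loop` closes over `input`, `criterion`, and `target`,
        # which are defined further down in this method; Python resolves those
        # names when the loop runs, not when the function is defined.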

        device_id = gpus_for_rank(self.world_size)[self.rank][0]

        model_withload = TestModel().float().to(device_id)
        model_withoutload = TestModel().float().to(device_id)

        ddp_withload = DistributedDataParallel(
            model_withload,
            device_ids=[device_id],
        )
        ddp_withoutload = DistributedDataParallel(
            model_withoutload,
            device_ids=[device_id],
        )

        # Ensure that all three models start from the same set of parameters;
        # by default they are randomly initialized on construction.
        for p in ddp_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in model_withload.parameters():
            with torch.no_grad():
                p.zero_()
        for p in ddp_withoutload.parameters():
            with torch.no_grad():
                p.zero_()

        batch_size = 4
        criterion = nn.CrossEntropyLoss()

        optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
        optimizer_non_ddp_withload = torch.optim.SGD(
            model_withload.parameters(), lr=0.001
        )
        optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)

        input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(
            device_id
        )

        # Run the model for 6 iterations, with a checkpoint in the middle.
        train_loop(ddp_withload, optimizer_withload, 3)

        # Zero out the parameters of both the DDP and non-DDP models and
        # reload them from the DDP state dict.
        checkpoint_path = tempfile.gettempdir() + "/model.checkpoint"
        if self.rank == 0:
            torch.save(ddp_withload.state_dict(), checkpoint_path)

        dist.barrier()
        map_location = {"cuda:%d" % 0: "cuda:%d" % self.rank}
        ddp_state_dict = torch.load(checkpoint_path, map_location=map_location)
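
        # The checkpoint was written by rank 0, so its tensors are tagged with
        # that rank's device; `map_location` remaps them onto this rank's GPU.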

        for model in [ddp_withload, model_withload]:
            for p in model.parameters():
                with torch.no_grad():
                    p.zero_()
        ddp_withload.load_state_dict(ddp_state_dict)
        # The non-DDP model first needs the "module." prefix stripped from the
        # DDP state dict, e.g. "module.fc1.weight" -> "fc1.weight".
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        model_withload.load_state_dict(ddp_state_dict)

        train_loop(ddp_withload, optimizer_withload, 3)
        train_loop(model_withload, optimizer_non_ddp_withload, 3)

        # Re-run the model with the same inputs for 6 iterations with no checkpoint.
        train_loop(ddp_withoutload, optimizer_withoutload, 6)

        for p_withload, p_withoutload, p_non_ddp_withload in zip(
            ddp_withload.parameters(),
            ddp_withoutload.parameters(),
            model_withload.parameters(),
        ):
            self.assertEqual(p_withload, p_withoutload)
            self.assertEqual(p_non_ddp_withload, p_withoutload)

    def _test_sparse_gradients(self, gradient_as_bucket_view=False):
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes.
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

    @requires_gloo()
    def test_sparse_gradients(self):
        self._test_sparse_gradients()

    @requires_gloo()
    def test_sparse_gradients_grad_is_view(self):
        self._test_sparse_gradients(gradient_as_bucket_view=True)

    @requires_gloo()
    def test_ddp_comm_hook_future_passing_cpu(self):
        """
        This unit test verifies whether the Future object is passed properly.
        The callback function creates a Future object and sets a value on it.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = self._get_process_group()

        # Test on CPU.
        cpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().cpu(), process_group=process_group
        )

        # Register a DDP communication hook.
        cpu_model.register_comm_hook(None, self._simple_hook)

        # Check whether the grads are equal to what the `then` callback returns.
        # Without the comm hook, the result would be 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        gpu_model = DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

        # Register a DDP communication hook, if one was given.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)

        return gpu_model

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_ddp_comm_hook_future_passing_gpu_gloo(self):
        """
        This unit test verifies whether the Future object is passed properly
        using the gloo backend. The hook callback function creates a Future
        object and sets a value on it.
        """
        process_group = self._get_process_group()

        # Get a GPU model with _simple_hook registered.
        gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, self._simple_hook)

        # Check whether the grads are equal to what _simple_hook's `then`
        # callback returns. Without the comm hook, the result would be
        # 0.25 * torch.ones(2, 2).
        self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))

    @requires_gloo()
    def test_ddp_invalid_comm_hook_init(self):
        """
        This unit test makes sure that register_comm_hook properly checks the
        format of the user-defined hook. The Python hook must be callable. This
        test also checks whether the bucket annotation, if defined, is checked
        properly.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        with self.assertRaisesRegex(TypeError, "Communication hook must be callable."):
            model.register_comm_hook(state=None, hook=1)

        with self.assertRaisesRegex(
            ValueError, "bucket annotation should be dist.GradBucket."
        ):

            def comm_hook(
                state: object, bucket: int
            ) -> torch.futures.Future[torch.Tensor]:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)

    @requires_gloo()
    def test_ddp_invalid_comm_hook_return_type(self):
        """
        This test checks whether the return annotation, if defined, is checked
        properly. It also checks whether an internal error is thrown if the
        return type is incorrect and the user hasn't specified any return type
        annotation.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )

        expected_err = (
            "Communication hook: return annotation should be torch.futures.Future"
        )
        with self.assertRaisesRegex(
            ValueError,
            expected_err,
        ):

            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                return torch.futures.Future()

            model.register_comm_hook(state=None, hook=comm_hook)
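
        # register_comm_hook is expected to log the failure before raising;
        # the helper below verifies that the error text was also captured by
        # DDP's error logging.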
        verify_ddp_error_logged(model, expected_err)

        with self.assertRaisesRegex(
            RuntimeError,
            "callback must return a torch.futures.Future object, but got",
        ):

            def comm_hook(state: object, bucket: dist.GradBucket):
                return 1

            model.register_comm_hook(state=None, hook=comm_hook)

            # Run forward
            output = model(8, self.rank)

            # Run backward
            output.mean().backward()

    @requires_gloo()
    def test_ddp_comm_hook_register_just_once(self):
        """
        DDP communication hook can only be registered once. This test validates
        whether the error is thrown properly when register_comm_hook is called
        more than once.
        """
        process_group = self._get_process_group()

        model = DistributedDataParallel(
            ModuleForDdpCommHook(), process_group=process_group
        )
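
        # A no-op hook with the expected signature: it takes `(state, bucket)`
        # and returns a Future resolved with a list containing the bucket's
        # flattened gradient tensor.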
        def dummy_hook(state, bucket):
            fut = torch.futures.Future()
            fut.set_result([bucket.buffer()])
            return fut

        model.register_comm_hook(None, dummy_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "register_comm_hook or register_builtin_comm_hook can only be called once.",
        ):
            model.register_comm_hook(None, dummy_hook)

    @requires_gloo()
    def test_ddp_comm_hook_sparse_gradients(self):
        """
        Runs the "test_sparse_gradients" unit test with a DDP communication
        hook. We define a simple hook that does allreduce and works with the
        gloo backend for this test.
        """
        process_group = self._get_process_group()

        # Ensure initialized weights and inputs are identical across processes.
        torch.manual_seed(1337)

        vanilla_model = SparseGradientModule()
        ddp_model = DistributedDataParallel(
            copy.deepcopy(vanilla_model),
            process_group=process_group,
        )

        def allreduce_hook_gloo(
            state: object, bucket: dist.GradBucket
        ) -> torch.futures.Future[torch.Tensor]:
            def div_by_world_size(fut):
                # Divide the summed result by world_size to get the average.
                return fut.wait()[0] / self.world_size

            # Prepare allreduced grad bucket tensors by kicking off async work.
            fut = process_group.allreduce([bucket.buffer()]).get_future()
            return fut.then(div_by_world_size)
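
        # The hook mirrors DDP's built-in allreduce: the process group sums the
        # bucket across ranks asynchronously, and the chained `then` callback
        # rescales the sum into an average before DDP writes it back into the
        # parameter gradients.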

        ddp_model.register_comm_hook(None, allreduce_hook_gloo)

        self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)

class ReducerModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 4, bias=False)
        self.fc3 = nn.Linear(4, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x, use_fc3=True):
        x = self.relu(self.fc1(x)).float()
        x = self.relu(self.fc2(x)).float()
        if use_fc3:
            x = self.fc3(x).float()
        return F.softmax(x, dim=1)
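
# NOTE: calling forward with use_fc3=False leaves fc3 out of the autograd
# graph, which the tests below use to exercise unused-parameter detection.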

class ReducerTest(TestCase):
    def setUp(self):
        self.file = tempfile.NamedTemporaryFile(delete=False)
        world_size = 1
        self.store = c10d.FileStore(self.file.name, world_size)
        c10d.init_process_group(
            backend="gloo", store=self.store, rank=0, world_size=world_size
        )
        self.process_group = c10d.distributed_c10d._get_default_group()

    def tearDown(self):
        c10d.destroy_process_group()
        try:
            os.remove(self.file.name)
        except OSError as e:
            print(str(e))

    @requires_gloo()
    def test_single_dtype_single_bucket(self):
        model = ReducerModule()
        parameters = list(model.parameters())
        buckets = [list(range(len(parameters)))]
        dist.Reducer(
            parameters, buckets, [dist._DEFAULT_FIRST_BUCKET_BYTES], self.process_group
        )

    def _create_mixed_precision_model(self):
        model = ReducerModule()
        model.float()
        model.fc1.double()
        return model

    @requires_gloo()
    def test_multi_dtype_single_bucket(self):
        model = self._create_mixed_precision_model()

        # Raise if there are multiple types per bucket.
        # In this case we create one bucket for all parameters.
        with self.assertRaises(RuntimeError):
            parameters = list(model.parameters())
            buckets = [list(range(len(parameters)))]
            dist.Reducer(
                parameters,
                buckets,
                [dist._DEFAULT_FIRST_BUCKET_BYTES],
                self.process_group,
            )
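
    # Reducer buckets must be homogeneous in dtype, so the multi-dtype tests
    # below group parameter indices by dtype. itertools.groupby only merges
    # adjacent equal keys, which suffices here because the mixed-precision
    # model lists its double parameter first, followed by the float ones.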
    @requires_gloo()
    def test_multi_dtype_multi_bucket(self):
        model = self._create_mixed_precision_model()
        parameters = list(model.parameters())
        group_by_dtype = groupby(
            range(len(parameters)), key=lambda i: parameters[i].dtype
        )
        buckets = [list(indices) for _, indices in group_by_dtype]
        dist.Reducer(
            parameters,
            buckets,
            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in buckets],
            self.process_group,
        )

    def _create_reducer_for_models(self, models, find_unused_parameters=False):
        self.assertEqual(len(models), 1)
        parameters = list(models[0].parameters())
        group_by_dtype = groupby(
            range(len(parameters)), key=lambda i: parameters[i].dtype
        )
        buckets = [list(indices) for _, indices in group_by_dtype]
        return dist.Reducer(
            parameters,
            buckets,
            [dist._DEFAULT_FIRST_BUCKET_BYTES for _ in range(len(buckets))],
            self.process_group,
            find_unused_parameters=find_unused_parameters,
        )

    @requires_gloo()
    def test_forward_backward(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model])
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.double)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
        output = loss(model(input), target)
        reducer.prepare_for_backward(output)
        output.backward()

    @requires_gloo()
    def test_forward_backward_unused_parameters(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        input = torch.rand([batch_size, 2], dtype=torch.double)
        target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])
        output = loss(model(input, use_fc3=False), target)

        # Check that the grad of fc3 is not set.
        self.assertEqual(None, model.fc3.weight.grad)

        # Compute and accumulate gradients.
        reducer.prepare_for_backward(output)
        output.backward()

        # The reducer will have marked the grad of fc3 as ready, because it
        # doesn't show up in the autograd graph of `output`. Since fc3.weight
        # is considered globally unused, it is left untouched as None.
        self.assertEqual(None, model.fc3.weight.grad)

    @requires_gloo()
    def test_forward_backward_optimizer(self):
        batch_size = 10
        model = self._create_mixed_precision_model()
        reducer = self._create_reducer_for_models([model], find_unused_parameters=True)
        reducer.prepare_for_forward()
        loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        for i in range(3):
            input = torch.rand([batch_size, 2], dtype=torch.double)
            target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

            # The `zero_grad` function calls `detach_` and `zero_` on the grad
            # tensors of model parameters. If we tried to set the grad tensors
            # to a view of the reducer's bucket tensors, this would blow up.
            optimizer.zero_grad()

            # Unused parameter only in the first iteration.
            output = loss(model(input, use_fc3=(i > 0)), target)
            reducer.prepare_for_backward(output)
            output.backward()
            optimizer.step()

class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
    @property
    def device(self):
        return "cpu"

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _test_broadcast_coalesced(self, process_group, device, root_rank):
        half = torch.float16

        # No support for float16 for CPU tensors.
        if device == torch.device("cpu"):
            half = torch.float32

        target = torch.arange(60, dtype=half, device=device).chunk(5)
        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)
        target += torch.arange(60, dtype=half, device=device).chunk(5)
        target += torch.arange(60, dtype=torch.float64, device=device).chunk(5)
        target += torch.arange(60, dtype=half, device=device).chunk(5)
        target += torch.arange(60, dtype=torch.float32, device=device).chunk(5)

        # The tensors to pass to broadcast are identical to the target
        # only on the process that is the root of the broadcast.
        if self.rank == root_rank:
            tensors = [tensor.clone() for tensor in target]
        else:
            tensors = [torch.zeros_like(tensor) for tensor in target]

        if self.rank != root_rank:
            self.assertNotEqual(tensors, target)
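
        # A small buffer_size should force _broadcast_coalesced to pack the
        # mixed-dtype chunks into several flat buffers rather than one,
        # exercising the coalescing path on both sides of the broadcast.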
        c10d._broadcast_coalesced(
            process_group, tensors, buffer_size=256, src=root_rank
        )

        if self.rank != root_rank:
            self.assertEqual(tensors, target)
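
    # The broadcast tests below bind the gloo backend to the loopback
    # interface so they can run on hosts without an external network device.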
    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_broadcast_coalesced_gloo_cuda(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        device = torch.device("cuda:%d" % self.rank)
        backend = process_group._get_backend(device)
        backend.create_device(interface=LOOPBACK)
        ranks = list(range(self.world_size))
        for root_rank in ranks:
            self._test_broadcast_coalesced(process_group, device, root_rank)

    @requires_gloo()
    def test_broadcast_coalesced_gloo_cpu(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="gloo", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        device = torch.device("cpu")
        backend = process_group._get_backend(device)
        backend.create_device(interface=LOOPBACK)
        ranks = list(range(self.world_size))
        for root_rank in ranks:
            self._test_broadcast_coalesced(process_group, device, root_rank)

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_default_pg_gloo(self):
        self._test_sequence_num_set_default_pg(backend="gloo")

    @requires_gloo()
    @skip_if_lt_x_gpu(2)
    def test_sequence_num_set_gloo_new_group(self):
        self._test_sequence_num_set_new_group(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_default(self):
        self._test_sequence_num_incremented_default_group("gloo")

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    def test_sequence_num_incremented_gloo_subgroup(self):
        if self.world_size < 4:
            return skip_but_pass_in_sandcastle("Test requires world_size of at least 4")
        self._test_sequence_num_incremented_subgroup("gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gloo_warn_not_in_group(self):
        self._test_warn_not_in_group(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_gloo_rank_membership(self):
        self._test_rank_membership(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_tensor_dtype_mismatch(self):
        self._test_tensor_dtype_mismatch(backend="gloo")

    @skip_if_lt_x_gpu(2)
    @requires_gloo()
    def test_tensor_dtype_complex(self):
        self._test_tensor_dtype_complex(backend="gloo")

    @requires_gloo()
    def test_bool_tensors(self):
        self._test_bool_tensors(backend="gloo")

class GlooProcessGroupWithDispatchedCollectivesTests(
    test_c10d_common.ProcessGroupWithDispatchedCollectivesTests
):
    @requires_gloo()
    def test_collectives(self):
        self._test_collectives(backend="gloo")

    @requires_gloo()
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="gloo")

    @requires_gloo()
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="gloo")

    @requires_gloo()
    def test_allgather_coalesced(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        input_tensor = torch.ones(10, 10, dtype=torch.float32)
        output_tensor_list = [torch.zeros_like(input_tensor)]
        dist.all_gather_coalesced([output_tensor_list], [input_tensor])
        self.assertEqual(output_tensor_list, [input_tensor])

    @requires_gloo()
    def test_monitored_barrier(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "gloo",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        dist.monitored_barrier()

class CompilerTest(test_c10d_common.CompilerTest):
    @property
    def world_size(self):
        return 2

    def _get_default_group(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="gloo",
            rank=self.rank,
            world_size=self.world_size,
            store=store,
        )
        return dist.distributed_c10d._get_default_group()

    def test_allreduce_work_wait_cpu(self):
        self._test_allreduce_work_wait(torch.ones(2, 2) * self.rank)
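
    # For the GPU variants below, passing an int as `device` is shorthand for
    # that CUDA device index, so each rank places its tensor on its own GPU.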
    @skip_if_lt_x_gpu(2)
    def test_allreduce_work_wait_gpu(self):
        self._test_allreduce_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_allgather_work_wait_cpu(self):
        self._test_allgather_work_wait(torch.ones(2, 2) * self.rank)

    @skip_if_lt_x_gpu(2)
    def test_allgather_work_wait_gpu(self):
        self._test_allgather_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_broadcast_work_wait_cpu(self):
        self._test_broadcast_work_wait(torch.ones(2, 2) * self.rank)

    @skip_if_lt_x_gpu(2)
    def test_broadcast_work_wait_gpu(self):
        self._test_broadcast_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_scatter_work_wait_cpu(self):
        self._test_scatter_work_wait(torch.ones(2, 2) * self.rank)

    @skip_if_lt_x_gpu(2)
    def test_scatter_work_wait_gpu(self):
        self._test_scatter_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

    def test_nested_comm_tensor_wrapping(self):
        self._test_nested_comm_tensor_wrapping(torch.ones(2, 2) * self.rank)

    def test_consecutive_comm_work_wait_cpu(self):
        self._test_consecutive_comm_work_wait(torch.ones(2, 2) * self.rank)

    @skip_if_lt_x_gpu(2)
    def test_consecutive_comm_work_wait_gpu(self):
        self._test_consecutive_comm_work_wait(
            torch.ones(2, 2, device=self.rank) * self.rank
        )

class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def device(self):
        return torch.device("cpu")

    @requires_gloo()
    def test_new_group_local_sync(self):
        self._test_new_group_local_sync(backend="gloo")

    @requires_gloo()
    def test_new_group_local_sync_sanity_check(self):
        self._test_new_group_local_sync_sanity_check(backend="gloo")

    @requires_gloo()
    def test_new_group_local_sync_duplicate_pg(self):
        self._test_new_group_local_sync_duplicate_pg(backend="gloo")

if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()