# Owner(s): ["oncall: distributed"]

from datetime import timedelta
from multiprocessing.pool import ThreadPool

import torch
import torch.distributed as dist
from torch.testing._internal.common_utils import run_tests, TestCase


# simple example of user code that takes the base class ControlCollectives
# and executes multiple different collectives
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
    timeout = timedelta(seconds=10)
    # first a barrier
    collectives.barrier("1", timeout, True)
    # then an all_sum
    out = collectives.all_sum("2", rank, timeout)
    return out
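

# The tests below exercise _StoreCollectives, a Store-backed implementation of
# _ControlCollectives: each test constructs one instance per rank over a shared
# in-memory HashStore and drives the ranks concurrently with a ThreadPool.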
class TestCollectives(TestCase):
    def test_barrier(self) -> None:
        store = dist.HashStore()

        world_size = 2

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            collectives.barrier("foo", timedelta(seconds=10), True)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_broadcast(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                collectives.broadcast_send("foo", b"data", timeout)
            else:
                out = collectives.broadcast_recv("foo", timeout)
                self.assertEqual(out, b"data")

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_gather(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.gather_recv("foo", str(rank), timeout)
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
            else:
                collectives.gather_send("foo", str(rank), timeout)

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_scatter(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            if rank == 2:
                out = collectives.scatter_send(
                    "foo", [str(i) for i in range(world_size)], timeout
                )
            else:
                out = collectives.scatter_recv("foo", timeout)
            # both the sender and the receivers get back their own shard
            self.assertEqual(out, str(rank).encode())

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))

    def test_all_sum(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(seconds=10)

        def f(rank: int) -> None:
            collectives = dist._StoreCollectives(store, rank, world_size)
            out = collectives.all_sum("foo", rank, timeout)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))
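
    # Each timeout test below constructs only a single participant (rank 1 of a
    # 4-rank world), so the collective call times out waiting for ranks that
    # never arrive.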

    def test_broadcast_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.broadcast_recv("foo", timeout)

    def test_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.gather_recv("foo", "data", timeout)

    def test_scatter_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(Exception, "Wait timeout"):
            collectives.scatter_recv("foo", timeout)

    def test_all_gather_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_gather("foo", "data", timeout)

    def test_barrier_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.barrier("foo", timeout, True)

    def test_all_sum_timeout(self) -> None:
        store = dist.HashStore()

        world_size = 4
        timeout = timedelta(milliseconds=1)
        collectives = dist._StoreCollectives(store, 1, world_size)
        # all_sum synchronizes through an internal barrier, so the timeout
        # surfaces as a barrier failure rather than an all_sum failure
        with self.assertRaisesRegex(
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
        ):
            collectives.all_sum("foo", 1, timeout)

    def test_unique(self) -> None:
        store = dist.HashStore()

        collectives = dist._StoreCollectives(store, 1, 1)
        collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.broadcast_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_send("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.gather_recv("foo", "asdf")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_send("foo", ["asdf"])

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.scatter_recv("foo")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_gather("foo", "bar")

        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
            collectives.all_sum("foo", 2)

    def test_simple_user_func(self) -> None:
        store = dist.HashStore()
        world_size = 4

        def f(rank: int) -> None:
            # the user needs to construct a concrete (child) collectives class,
            # but simple_user_func does not need to change across implementations
            store_collectives = dist._StoreCollectives(store, rank, world_size)
            out = simple_user_func(store_collectives, rank)
            self.assertEqual(out, sum(range(world_size)))

        with ThreadPool(world_size) as pool:
            pool.map(f, range(world_size))


if __name__ == "__main__":
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()