pytorch

Форк
0
/
test_control_collectives.py 
214 строк · 7.2 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
from datetime import timedelta
4
from multiprocessing.pool import ThreadPool
5

6
import torch
7
import torch.distributed as dist
8
from torch.testing._internal.common_utils import run_tests, TestCase
9

10

11
# simple example of user code that takes the base class ControlCollectives
12
# and executes multiple different collectives
13
def simple_user_func(collectives: dist._ControlCollectives, rank: int) -> int:
14
    timeout = timedelta(seconds=10)
15
    # first a barrier
16
    collectives.barrier("1", timeout, True)
17
    # then an all_sum
18
    out = collectives.all_sum("2", rank, timeout)
19
    return out
20

21

22
class TestCollectives(TestCase):
23
    def test_barrier(self) -> None:
24
        store = dist.HashStore()
25

26
        world_size = 2
27

28
        def f(rank: int) -> None:
29
            collectives = dist._StoreCollectives(store, rank, world_size)
30
            collectives.barrier("foo", timedelta(seconds=10), True)
31

32
        with ThreadPool(world_size) as pool:
33
            pool.map(f, range(world_size))
34

35
    def test_broadcast(self) -> None:
36
        store = dist.HashStore()
37

38
        world_size = 4
39
        timeout = timedelta(seconds=10)
40

41
        def f(rank: int) -> None:
42
            collectives = dist._StoreCollectives(store, rank, world_size)
43
            if rank == 2:
44
                collectives.broadcast_send("foo", b"data", timeout)
45
            else:
46
                out = collectives.broadcast_recv("foo", timeout)
47
                self.assertEqual(out, b"data")
48

49
        with ThreadPool(world_size) as pool:
50
            pool.map(f, range(world_size))
51

52
    def test_gather(self) -> None:
53
        store = dist.HashStore()
54

55
        world_size = 4
56
        timeout = timedelta(seconds=10)
57

58
        def f(rank: int) -> None:
59
            collectives = dist._StoreCollectives(store, rank, world_size)
60
            if rank == 2:
61
                out = collectives.gather_recv("foo", str(rank), timeout)
62
                self.assertEqual(out, [b"0", b"1", b"2", b"3"])
63
            else:
64
                collectives.gather_send("foo", str(rank), timeout)
65

66
        with ThreadPool(world_size) as pool:
67
            pool.map(f, range(world_size))
68

69
    def test_scatter(self) -> None:
70
        store = dist.HashStore()
71

72
        world_size = 4
73
        timeout = timedelta(seconds=10)
74

75
        def f(rank: int) -> None:
76
            collectives = dist._StoreCollectives(store, rank, world_size)
77
            if rank == 2:
78
                out = collectives.scatter_send(
79
                    "foo", [str(i) for i in range(world_size)], timeout
80
                )
81
            else:
82
                out = collectives.scatter_recv("foo", timeout)
83
            self.assertEqual(out, str(rank).encode())
84

85
        with ThreadPool(world_size) as pool:
86
            pool.map(f, range(world_size))
87

88
    def test_all_sum(self) -> None:
89
        store = dist.HashStore()
90

91
        world_size = 4
92
        timeout = timedelta(seconds=10)
93

94
        def f(rank: int) -> None:
95
            collectives = dist._StoreCollectives(store, rank, world_size)
96
            out = collectives.all_sum("foo", rank, timeout)
97
            self.assertEqual(out, sum(range(world_size)))
98

99
        with ThreadPool(world_size) as pool:
100
            pool.map(f, range(world_size))
101

102
    def test_broadcast_timeout(self) -> None:
103
        store = dist.HashStore()
104

105
        world_size = 4
106
        timeout = timedelta(milliseconds=1)
107
        collectives = dist._StoreCollectives(store, 1, world_size)
108
        with self.assertRaisesRegex(Exception, "Wait timeout"):
109
            collectives.broadcast_recv("foo", timeout)
110

111
    def test_gather_timeout(self) -> None:
112
        store = dist.HashStore()
113

114
        world_size = 4
115
        timeout = timedelta(milliseconds=1)
116
        collectives = dist._StoreCollectives(store, 1, world_size)
117
        with self.assertRaisesRegex(
118
            Exception, "gather failed -- missing ranks: 0, 2, 3"
119
        ):
120
            collectives.gather_recv("foo", "data", timeout)
121

122
    def test_scatter_timeout(self) -> None:
123
        store = dist.HashStore()
124

125
        world_size = 4
126
        timeout = timedelta(milliseconds=1)
127
        collectives = dist._StoreCollectives(store, 1, world_size)
128
        with self.assertRaisesRegex(Exception, "Wait timeout"):
129
            collectives.scatter_recv("foo", timeout)
130

131
    def test_all_gather_timeout(self) -> None:
132
        store = dist.HashStore()
133

134
        world_size = 4
135
        timeout = timedelta(milliseconds=1)
136
        collectives = dist._StoreCollectives(store, 1, world_size)
137
        with self.assertRaisesRegex(
138
            Exception, "all_gather failed -- missing ranks: 0, 2, 3"
139
        ):
140
            collectives.all_gather("foo", "data", timeout)
141

142
    def test_barrier_timeout(self) -> None:
143
        store = dist.HashStore()
144

145
        world_size = 4
146
        timeout = timedelta(milliseconds=1)
147
        collectives = dist._StoreCollectives(store, 1, world_size)
148
        with self.assertRaisesRegex(
149
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
150
        ):
151
            collectives.barrier("foo", timeout, True)
152

153
    def test_all_sum_timeout(self) -> None:
154
        store = dist.HashStore()
155

156
        world_size = 4
157
        timeout = timedelta(milliseconds=1)
158
        collectives = dist._StoreCollectives(store, 1, world_size)
159
        with self.assertRaisesRegex(
160
            Exception, "barrier failed -- missing ranks: 0, 2, 3"
161
        ):
162
            collectives.all_sum("foo", 1, timeout)
163

164
    def test_unique(self) -> None:
165
        store = dist.HashStore()
166

167
        collectives = dist._StoreCollectives(store, 1, 1)
168
        collectives.broadcast_send("foo", "bar")
169

170
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
171
            collectives.broadcast_send("foo", "bar")
172

173
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
174
            collectives.broadcast_recv("foo")
175

176
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
177
            collectives.gather_send("foo", "bar")
178

179
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
180
            collectives.gather_recv("foo", "asdf")
181

182
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
183
            collectives.scatter_send("foo", ["asdf"])
184

185
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
186
            collectives.scatter_recv("foo")
187

188
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
189
            collectives.all_gather("foo", "bar")
190

191
        with self.assertRaisesRegex(Exception, "Key foo has already been used"):
192
            collectives.all_sum("foo", 2)
193

194
    def test_simple_user_func(self) -> None:
195
        store = dist.HashStore()
196
        world_size = 4
197

198
        def f(rank: int) -> None:
199
            # user need to create child collectives
200
            # but simple_user_func do not need to be changed for different child collectives
201
            store_collectives = dist._StoreCollectives(store, rank, world_size)
202
            out = simple_user_func(store_collectives, rank)
203
            self.assertEqual(out, sum(range(world_size)))
204

205
        with ThreadPool(world_size) as pool:
206
            pool.map(f, range(world_size))
207

208

209
if __name__ == "__main__":
    # Distributed tests must not touch CUDA on the main process: child
    # workers fork/spawn from here, and an initialized CUDA context does
    # not survive forking.
    assert (
        not torch.cuda._initialized
    ), "test_distributed must not have initialized CUDA context on main process"

    run_tests()
215

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.