test_c10d_object_collectives.py

# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import partial, wraps

import torch
import torch.distributed as dist


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

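# Pick the backend from the hardware that is available: NCCL when CUDA devices
# are present, Gloo otherwise.  Run with between 2 and 4 ranks, depending on
# the number of visible GPUs.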
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


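# Test decorator: skip when an NCCL run does not have enough GPUs for the
# world size, otherwise initialize the process group, run the test body, and
# tear the process group down afterwards.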
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


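# Covers the object-based collectives (all_gather_object, gather_object,
# broadcast_object_list, scatter_object_list, send/recv_object_list) on the
# default process group and on two-rank subgroups.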
class TestObjectCollectives(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

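    # Initialize the default process group with a file:// rendezvous on the
    # per-test temporary file provided by MultiProcessTestCase.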
    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

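    # Every rank contributes its own rank number; afterwards each rank should
    # hold [0, 1, ..., world_size - 1].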
    @with_comms()
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(object_list=output, obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")

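    # Only the destination rank (0, the default) passes a gather list; it
    # should end up holding every rank's contribution in rank order.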
    @with_comms()
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(obj=self.rank, object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")

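    # Rank 0 sends the object list to rank 1; ranks 2 and above never touch
    # their copy, so its entries stay None.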
    @with_comms()
    def test_send_recv_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        if self.rank == 0:
            dist.send_object_list(object_list, 1)
        if self.rank == 1:
            dist.recv_object_list(object_list, 0)

        if self.rank < 2:
            self.assertEqual(99, object_list[0])
        else:
            self.assertEqual(None, object_list[0])

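    # Rank 0 broadcasts 99; every rank should see it in slot 0 afterwards.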
    @with_comms()
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

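    # The source rank (0, the default) scatters one object per rank; each rank
    # should receive its own rank number.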
    @with_comms()
    def test_scatter_object_list(self):
        input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
        output_list = [None]
        dist.scatter_object_list(
            scatter_object_output_list=output_list, scatter_object_input_list=input_list
        )

        self.assertEqual(self.rank, output_list[0])

    # Test Object Collectives With Sub Pg

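    # Pair ranks into two-rank subgroups ({0, 1}, {2, 3}, ...); each rank joins
    # the group containing itself and its neighbour.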
    def setup_sub_pg(self):
        rank = dist.get_rank()
        base_rank = rank - (rank % 2)
        ranks = [base_rank, base_rank + 1]
        my_pg = dist.new_group(ranks, use_local_synchronization=True)
        return rank, ranks, my_pg

    @with_comms()
    def test_subpg_scatter_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
        self.assertEqual(rank, out_list[0])

    @with_comms()
    def test_subpg_all_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks)
        dist.all_gather_object(out_list, rank, group=my_pg)
        self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks) if rank == ranks[0] else None
        dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
        if rank == ranks[0]:
            self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_broadcast_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        if rank == ranks[0]:
            out_list[0] = rank
        dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg)
        self.assertEqual(ranks[0], out_list[0])


if __name__ == "__main__":
    run_tests()
