# Owner(s): ["oncall: distributed"]

import json
import logging
import os
import re
import sys
import time
from functools import partial, wraps

import torch
import torch.distributed as dist

from torch.distributed.c10d_logger import _c10d_logger, _exception_logger, _time_logger

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


def with_comms(func=None):
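    # Test decorator: skip when fewer GPUs are available than the world size,
    # otherwise initialize the process group before the test body and tear it
    # down afterwards.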
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        # Forward any extra arguments instead of silently dropping them.
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class C10dErrorLoggerTest(MultiProcessTestCase):
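    # Exercises the structured logging helpers in torch.distributed.c10d_logger:
    # _c10d_logger itself plus the _exception_logger and _time_logger decorators.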
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def test_get_or_create_logger(self):
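        # The module-level logger should exist and carry exactly one NullHandler.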
        self.assertIsNotNone(_c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
        self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)

    @_exception_logger
    def _failed_broadcast_raise_exception(self):
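        # Broadcasting from rank world_size + 1 is invalid and must raise.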
        tensor = torch.arange(2, dtype=torch.int64)
        dist.broadcast(tensor, self.world_size + 1)

    @_exception_logger
    def _failed_broadcast_not_raise_exception(self):
        try:
            tensor = torch.arange(2, dtype=torch.int64)
            dist.broadcast(tensor, self.world_size + 1)
        except Exception:
            pass

    @with_comms
    def test_exception_logger(self) -> None:
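        # _exception_logger should emit a structured record describing the failed
        # collective; parse it back out of the captured log line and check its fields.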
111
        with self.assertRaises(Exception):
112
            self._failed_broadcast_raise_exception()
113

114
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
115
            self._failed_broadcast_not_raise_exception()
116
            error_msg_dict = json.loads(
117
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
118
            )
119

120
            self.assertEqual(len(error_msg_dict), 10)
121

122
            self.assertIn("pg_name", error_msg_dict.keys())
123
            self.assertEqual("None", error_msg_dict["pg_name"])
124

125
            self.assertIn("func_name", error_msg_dict.keys())
126
            self.assertEqual("broadcast", error_msg_dict["func_name"])
127

128
            self.assertIn("args", error_msg_dict.keys())
129

130
            self.assertIn("backend", error_msg_dict.keys())
131
            self.assertEqual("nccl", error_msg_dict["backend"])
132

133
            self.assertIn("nccl_version", error_msg_dict.keys())
134
            nccl_ver = torch.cuda.nccl.version()
135
            self.assertEqual(
136
                ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
137
            )
138

139
            # In this test case, group_size = world_size, since we don't have multiple processes on one node.
140
            self.assertIn("group_size", error_msg_dict.keys())
141
            self.assertEqual(str(self.world_size), error_msg_dict["group_size"])
142

143
            self.assertIn("world_size", error_msg_dict.keys())
144
            self.assertEqual(str(self.world_size), error_msg_dict["world_size"])
145

146
            self.assertIn("global_rank", error_msg_dict.keys())
147
            self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])
148

149
            # In this test case, local_rank = global_rank, since we don't have multiple processes on one node.
150
            self.assertIn("local_rank", error_msg_dict.keys())
151
            self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])
152

153
    @_time_logger
154
    def _dummy_sleep(self):
155
        time.sleep(5)
156

157
    @with_comms
158
    def test_time_logger(self) -> None:
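        # _time_logger should emit the same structured metadata plus a "time_spent"
        # field recording how long the decorated call took.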
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._dummy_sleep()
            msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )
            self.assertEqual(len(msg_dict), 10)

            self.assertIn("pg_name", msg_dict.keys())
            self.assertEqual("None", msg_dict["pg_name"])

            self.assertIn("func_name", msg_dict.keys())
            self.assertEqual("_dummy_sleep", msg_dict["func_name"])

            self.assertIn("args", msg_dict.keys())

            self.assertIn("backend", msg_dict.keys())
            self.assertEqual("nccl", msg_dict["backend"])

            self.assertIn("nccl_version", msg_dict.keys())
            nccl_ver = torch.cuda.nccl.version()
            self.assertEqual(
                ".".join(str(v) for v in nccl_ver), msg_dict["nccl_version"]
            )

            # In this test case, group_size = world_size, since we don't have multiple processes on one node.
            self.assertIn("group_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["group_size"])

            self.assertIn("world_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["world_size"])

            self.assertIn("global_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["global_rank"])

            # In this test case, local_rank = global_rank, since we don't have multiple processes on one node.
            self.assertIn("local_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["local_rank"])
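
            # time_spent is reported in nanoseconds; a 5-second sleep should land on 5
            # after truncating the nanosecond value to whole seconds.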
            self.assertIn("time_spent", msg_dict.keys())
            time_ns = re.findall(r"\d+", msg_dict["time_spent"])[0]
            self.assertEqual(5, int(float(time_ns) / pow(10, 9)))


if __name__ == "__main__":
    run_tests()