# Owner(s): ["oncall: distributed"]

import json
import logging
import os
import re
import sys
import time
from functools import partial, wraps

import torch
import torch.distributed as dist

from torch.distributed.c10d_logger import _c10d_logger, _exception_logger, _time_logger

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

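# Run with between 2 and 4 ranks; with_comms (below) skips a test when the
# machine has fewer GPUs than the chosen world size.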
BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


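# Bring up the process group before the test body and tear it down afterwards,
# skipping when there are not enough GPUs for the NCCL backend.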
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


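# MultiProcessTestCase spawns world_size processes and re-executes each test
# method once per rank.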
class C10dErrorLoggerTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

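    # _c10d_logger is configured with a single NullHandler, so importing the
    # module does not alter the global logging setup.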
    def test_get_or_create_logger(self):
        self.assertIsNotNone(_c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
        self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)

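    # Broadcasting from src=world_size + 1 targets a rank that does not exist,
    # so the call raises; @_exception_logger logs the error before re-raising.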
    @_exception_logger
    def _failed_broadcast_raise_exception(self):
        tensor = torch.arange(2, dtype=torch.int64)
        dist.broadcast(tensor, self.world_size + 1)

    @_exception_logger
    def _failed_broadcast_not_raise_exception(self):
        try:
            tensor = torch.arange(2, dtype=torch.int64)
            dist.broadcast(tensor, self.world_size + 1)
        except Exception:
            pass

    @with_comms
    def test_exception_logger(self) -> None:
        with self.assertRaises(Exception):
            self._failed_broadcast_raise_exception()

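        # _exception_logger emits the metadata as a Python dict repr inside the
        # log line; extract it with a regex and swap quotes so json.loads
        # accepts it.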
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._failed_broadcast_not_raise_exception()
            error_msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )

        self.assertEqual(len(error_msg_dict), 10)

        self.assertIn("pg_name", error_msg_dict.keys())
        self.assertEqual("None", error_msg_dict["pg_name"])

        self.assertIn("func_name", error_msg_dict.keys())
        self.assertEqual("broadcast", error_msg_dict["func_name"])

        self.assertIn("args", error_msg_dict.keys())

        self.assertIn("backend", error_msg_dict.keys())
        self.assertEqual("nccl", error_msg_dict["backend"])

        self.assertIn("nccl_version", error_msg_dict.keys())
        nccl_ver = torch.cuda.nccl.version()
        self.assertEqual(
            ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
        )

        # In this test case, group_size = world_size, since we don't have
        # multiple processes on one node.
        self.assertIn("group_size", error_msg_dict.keys())
        self.assertEqual(str(self.world_size), error_msg_dict["group_size"])

        self.assertIn("world_size", error_msg_dict.keys())
        self.assertEqual(str(self.world_size), error_msg_dict["world_size"])

        self.assertIn("global_rank", error_msg_dict.keys())
        self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])

        # In this test case, local_rank = global_rank, since we don't have
        # multiple processes on one node.
        self.assertIn("local_rank", error_msg_dict.keys())
        self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])

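    # @_time_logger records how long the wrapped call took, in nanoseconds,
    # under the "time_spent" key of the same metadata dict.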
    @_time_logger
    def _dummy_sleep(self):
        time.sleep(5)

    @with_comms
    def test_time_logger(self) -> None:
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._dummy_sleep()
            msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )
        self.assertEqual(len(msg_dict), 10)

        self.assertIn("pg_name", msg_dict.keys())
        self.assertEqual("None", msg_dict["pg_name"])

        self.assertIn("func_name", msg_dict.keys())
        self.assertEqual("_dummy_sleep", msg_dict["func_name"])

        self.assertIn("args", msg_dict.keys())

        self.assertIn("backend", msg_dict.keys())
        self.assertEqual("nccl", msg_dict["backend"])

        self.assertIn("nccl_version", msg_dict.keys())
        nccl_ver = torch.cuda.nccl.version()
        self.assertEqual(
            ".".join(str(v) for v in nccl_ver), msg_dict["nccl_version"]
        )

        # In this test case, group_size = world_size, since we don't have
        # multiple processes on one node.
        self.assertIn("group_size", msg_dict.keys())
        self.assertEqual(str(self.world_size), msg_dict["group_size"])

        self.assertIn("world_size", msg_dict.keys())
        self.assertEqual(str(self.world_size), msg_dict["world_size"])

        self.assertIn("global_rank", msg_dict.keys())
        self.assertIn(str(dist.get_rank()), msg_dict["global_rank"])

        # In this test case, local_rank = global_rank, since we don't have
        # multiple processes on one node.
        self.assertIn("local_rank", msg_dict.keys())
        self.assertIn(str(dist.get_rank()), msg_dict["local_rank"])

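        # time_spent is reported in nanoseconds; a 5 second sleep should
        # truncate to 5 once converted to seconds.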
        self.assertIn("time_spent", msg_dict.keys())
        time_ns = re.findall(r"\d+", msg_dict["time_spent"])[0]
        self.assertEqual(5, int(float(time_ns) / pow(10, 9)))


if __name__ == "__main__":
    run_tests()