pytorch
124 строки · 4.1 Кб
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import logging10import multiprocessing as mp11import signal12import time13
14import torch.distributed.elastic.timer as timer15import torch.multiprocessing as torch_mp16from torch.testing._internal.common_utils import (17IS_MACOS,18IS_WINDOWS,19run_tests,20skip_but_pass_in_sandcastle_if,21TEST_WITH_DEV_DBG_ASAN,22TestCase,23)
24
25
# Configure the root logger at INFO with a timestamped, module-tagged format.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format="[%(levelname)s] %(asctime)s %(module)s: %(message)s",
    level=logging.INFO,
)
29
30
31def _happy_function(rank, mp_queue):32timer.configure(timer.LocalTimerClient(mp_queue))33with timer.expires(after=1):34time.sleep(0.5)35
36
def _stuck_function(rank, mp_queue):
    """Worker whose guarded block deliberately overruns its timer budget.

    ``rank`` is unused; it is accepted because the callers in this file
    supply it (``torch_mp.spawn`` prepends it, ``Process`` passes it
    explicitly). The block sleeps 5 s under a 1 s expiry.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    budget_secs = 1  # expiry window for the guarded block
    with timer.expires(after=budget_secs):
        time.sleep(5)
42
# timer is not supported on macos or windows
if not (IS_WINDOWS or IS_MACOS):

    class LocalTimerExample(TestCase):
        """
        Demonstrates how to use LocalTimerServer and LocalTimerClient
        to enforce expiration of code-blocks.

        Since torch multiprocessing's ``start_process`` method currently
        does not take the multiprocessing context as an argument,
        there is no way to create the mp.Queue in the correct
        context BEFORE spawning child processes. Once the ``start_process``
        API is changed in torch, then re-enable the ``test_torch_mp_example``
        unittest. As of now this will SIGSEGV.

        NOTE(review): in this file ``test_torch_mp_example`` is only skipped
        under ``TEST_WITH_DEV_DBG_ASAN``; nothing here disables it otherwise,
        so the caveat above may be stale -- confirm.

        Every example stops its ``LocalTimerServer`` in a ``finally`` block so
        that a failing spawn or assertion does not leave the server running.
        """

        # number of worker processes launched by each example
        _WORLD_SIZE = 8

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_torch_mp_example(self):
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            mp_queue = mp.get_context("spawn").Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()
            try:
                # all processes should complete successfully
                # since start_process does NOT take context as parameter argument yet
                # this method WILL FAIL (hence the test is disabled)
                # NOTE(review): no disabling is visible in this file besides the
                # ASAN skip above -- confirm whether this caveat still applies.
                torch_mp.spawn(
                    fn=_happy_function,
                    args=(mp_queue,),
                    nprocs=self._WORLD_SIZE,
                    join=True,
                )

                with self.assertRaises(Exception):
                    # torch.multiprocessing.spawn kills all sub-procs
                    # if one of them gets killed
                    torch_mp.spawn(
                        fn=_stuck_function,
                        args=(mp_queue,),
                        nprocs=self._WORLD_SIZE,
                        join=True,
                    )
            finally:
                server.stop()

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_example_start_method_spawn(self):
            self._run_example_with(start_method="spawn")

        # @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
        # def test_example_start_method_forkserver(self):
        #     self._run_example_with(start_method="forkserver")

        def _run_example_with(self, start_method):
            """Run the timer example with workers created via ``start_method``.

            Even-ranked workers run ``_stuck_function`` and are expected to end
            with exit code ``-SIGKILL``; odd-ranked workers run
            ``_happy_function`` and are expected to exit with code 0.

            All workers are joined before any assertion, so a mismatch on one
            rank does not leave the others unreaped, and the failure message
            shows every exit code at once.
            """
            spawn_ctx = mp.get_context(start_method)
            mp_queue = spawn_ctx.Queue()
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()
            try:
                processes = []
                for rank in range(self._WORLD_SIZE):
                    # alternate stuck (even rank) and happy (odd rank) workers
                    target = _stuck_function if rank % 2 == 0 else _happy_function
                    p = spawn_ctx.Process(target=target, args=(rank, mp_queue))
                    p.start()
                    processes.append(p)

                for p in processes:
                    p.join()

                expected = [
                    -signal.SIGKILL if rank % 2 == 0 else 0
                    for rank, _ in enumerate(processes)
                ]
                self.assertEqual(expected, [p.exitcode for p in processes])
            finally:
                server.stop()
122
if __name__ == "__main__":
    # Executed as a script: hand control to PyTorch's unittest runner.
    run_tests()