pytorch

Форк
0
/
local_timer_example.py 
124 строки · 4.1 Кб
1
#!/usr/bin/env python3
2
# Owner(s): ["oncall: r2p"]
3

4
# Copyright (c) Facebook, Inc. and its affiliates.
5
# All rights reserved.
6
#
7
# This source code is licensed under the BSD-style license found in the
8
# LICENSE file in the root directory of this source tree.
9
import logging
10
import multiprocessing as mp
11
import signal
12
import time
13

14
import torch.distributed.elastic.timer as timer
15
import torch.multiprocessing as torch_mp
16
from torch.testing._internal.common_utils import (
17
    IS_MACOS,
18
    IS_WINDOWS,
19
    run_tests,
20
    skip_but_pass_in_sandcastle_if,
21
    TEST_WITH_DEV_DBG_ASAN,
22
    TestCase,
23
)
24

25

26
# Configure root logging for this example: INFO level, with every record
# prefixed by its level name, timestamp and originating module.
logging.basicConfig(
    level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
)
29

30

31
def _happy_function(rank, mp_queue):
    """Worker body that finishes inside its timer window.

    Registers a ``LocalTimerClient`` bound to ``mp_queue`` and then sleeps
    for 0.5 seconds inside a ``timer.expires`` scope of 1 second.

    Args:
        rank: index of the worker process; not used by the body.
        mp_queue: queue handed to ``timer.LocalTimerClient``.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    with timer.expires(after=1):
        time.sleep(0.5)
35

36

37
def _stuck_function(rank, mp_queue):
    """Worker body that overruns its timer window.

    Registers a ``LocalTimerClient`` bound to ``mp_queue`` and then sleeps
    for 5 seconds inside a ``timer.expires`` scope of only 1 second.

    Args:
        rank: index of the worker process; not used by the body.
        mp_queue: queue handed to ``timer.LocalTimerClient``.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    with timer.expires(after=1):
        time.sleep(5)
41

42

43
# The timer is not supported on macOS or Windows, so the example below is
# only defined on other platforms.
44
if not (IS_WINDOWS or IS_MACOS):

    class LocalTimerExample(TestCase):
        """
        Demonstrates how to use LocalTimerServer and LocalTimerClient
        to enforce expiration of code-blocks.

        Every test starts a ``LocalTimerServer`` on a multiprocessing queue,
        launches worker processes that configure a ``LocalTimerClient`` on
        the same queue, and stops the server in a ``finally`` clause so it
        is shut down even when a spawn call or an assertion raises.

        Since torch multiprocessing's ``start_process`` method currently
        does not take the multiprocessing context as parameter argument
        there is no way to create the mp.Queue in the correct
        context BEFORE spawning child processes. Once the ``start_process``
        API is changed in torch, then re-enable ``test_torch_mp_example``
        unittest. As of now this will SIGSEGV.

        NOTE(review): the decorator on ``test_torch_mp_example`` only skips
        the test when ``TEST_WITH_DEV_DBG_ASAN`` is set; confirm whether the
        "disabled" remark above is still accurate.
        """

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_torch_mp_example(self):
            """Run the happy and stuck workers through ``torch_mp.spawn``."""
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            mp_queue = mp.get_context("spawn").Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()

            world_size = 8

            try:
                # all processes should complete successfully
                # since start_process does NOT take context as parameter argument yet
                # this method WILL FAIL (hence the test is disabled)
                torch_mp.spawn(
                    fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
                )

                with self.assertRaises(Exception):
                    # torch.multiprocessing.spawn kills all sub-procs
                    # if one of them gets killed
                    torch_mp.spawn(
                        fn=_stuck_function,
                        args=(mp_queue,),
                        nprocs=world_size,
                        join=True,
                    )
            finally:
                # Stop the server even if a spawn call or the assertion fails,
                # so a failing test does not leave it running.
                server.stop()

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_example_start_method_spawn(self):
            """Run the example with the ``spawn`` multiprocessing start method."""
            self._run_example_with(start_method="spawn")

        # @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
        # def test_example_start_method_forkserver(self):
        #     self._run_example_with(start_method="forkserver")

        def _run_example_with(self, start_method):
            """Start alternating stuck/happy workers and check their exit codes.

            Even ranks run ``_stuck_function`` and must exit with
            ``-signal.SIGKILL``; odd ranks run ``_happy_function`` and must
            exit with ``0``.

            Args:
                start_method: name passed to ``mp.get_context`` to pick the
                    multiprocessing context for the queue and the workers.
            """
            spawn_ctx = mp.get_context(start_method)
            mp_queue = spawn_ctx.Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()

            world_size = 8
            processes = []

            try:
                for rank in range(world_size):
                    target = _stuck_function if rank % 2 == 0 else _happy_function
                    p = spawn_ctx.Process(target=target, args=(rank, mp_queue))
                    p.start()
                    processes.append(p)

                for rank, p in enumerate(processes):
                    p.join()
                    expected_exitcode = -signal.SIGKILL if rank % 2 == 0 else 0
                    self.assertEqual(expected_exitcode, p.exitcode)
            finally:
                # Stop the server even if starting a worker or an exit-code
                # assertion fails.
                server.stop()
121

122

123
# When executed as a script, hand control to ``run_tests`` (imported from
# ``torch.testing._internal.common_utils``).
if __name__ == "__main__":
    run_tests()
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.