pytorch

Форк
0
/
test_monitor.py 
166 строк · 4.7 Кб
1
# Owner(s): ["oncall: r2p"]
2

3
import tempfile
4
import time
5

6
from datetime import datetime, timedelta
7

8
from torch.monitor import (
9
    Aggregation,
10
    Event,
11
    log_event,
12
    register_event_handler,
13
    Stat,
14
    TensorboardEventHandler,
15
    unregister_event_handler,
16
    _WaitCounter,
17
)
18
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
19

20
class TestMonitor(TestCase):
21
    def test_interval_stat(self) -> None:
22
        events = []
23

24
        def handler(event):
25
            events.append(event)
26

27
        handle = register_event_handler(handler)
28
        s = Stat(
29
            "asdf",
30
            (Aggregation.SUM, Aggregation.COUNT),
31
            timedelta(milliseconds=1),
32
        )
33
        self.assertEqual(s.name, "asdf")
34

35
        s.add(2)
36
        for _ in range(100):
37
            # NOTE: different platforms sleep may be inaccurate so we loop
38
            # instead (i.e. win)
39
            time.sleep(1 / 1000)  # ms
40
            s.add(3)
41
            if len(events) >= 1:
42
                break
43
        self.assertGreaterEqual(len(events), 1)
44
        unregister_event_handler(handle)
45

46
    def test_fixed_count_stat(self) -> None:
47
        s = Stat(
48
            "asdf",
49
            (Aggregation.SUM, Aggregation.COUNT),
50
            timedelta(hours=100),
51
            3,
52
        )
53
        s.add(1)
54
        s.add(2)
55
        name = s.name
56
        self.assertEqual(name, "asdf")
57
        self.assertEqual(s.count, 2)
58
        s.add(3)
59
        self.assertEqual(s.count, 0)
60
        self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
61

62
    def test_log_event(self) -> None:
63
        e = Event(
64
            name="torch.monitor.TestEvent",
65
            timestamp=datetime.now(),
66
            data={
67
                "str": "a string",
68
                "float": 1234.0,
69
                "int": 1234,
70
            },
71
        )
72
        self.assertEqual(e.name, "torch.monitor.TestEvent")
73
        self.assertIsNotNone(e.timestamp)
74
        self.assertIsNotNone(e.data)
75
        log_event(e)
76

77
    @skipIfTorchDynamo("Really weird error")
78
    def test_event_handler(self) -> None:
79
        events = []
80

81
        def handler(event: Event) -> None:
82
            events.append(event)
83

84
        handle = register_event_handler(handler)
85
        e = Event(
86
            name="torch.monitor.TestEvent",
87
            timestamp=datetime.now(),
88
            data={},
89
        )
90
        log_event(e)
91
        self.assertEqual(len(events), 1)
92
        self.assertEqual(events[0], e)
93
        log_event(e)
94
        self.assertEqual(len(events), 2)
95

96
        unregister_event_handler(handle)
97
        log_event(e)
98
        self.assertEqual(len(events), 2)
99

100
    def test_wait_counter(self) -> None:
101
        wait_counter = _WaitCounter(
102
            "test_wait_counter",
103
        )
104
        with wait_counter.guard() as wcg:
105
            pass
106

107

108
@skipIfTorchDynamo("Really weird error")
109
class TestMonitorTensorboard(TestCase):
110
    def setUp(self):
111
        global SummaryWriter, event_multiplexer
112
        try:
113
            from torch.utils.tensorboard import SummaryWriter
114
            from tensorboard.backend.event_processing import (
115
                plugin_event_multiplexer as event_multiplexer,
116
            )
117
        except ImportError:
118
            return self.skipTest("Skip the test since TensorBoard is not installed")
119
        self.temp_dirs = []
120

121
    def create_summary_writer(self):
122
        temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
123
        self.temp_dirs.append(temp_dir)
124
        return SummaryWriter(temp_dir.name)
125

126
    def tearDown(self):
127
        # Remove directories created by SummaryWriter
128
        for temp_dir in self.temp_dirs:
129
            temp_dir.cleanup()
130

131
    def test_event_handler(self):
132
        with self.create_summary_writer() as w:
133
            handle = register_event_handler(TensorboardEventHandler(w))
134

135
            s = Stat(
136
                "asdf",
137
                (Aggregation.SUM, Aggregation.COUNT),
138
                timedelta(hours=1),
139
                5,
140
            )
141
            for i in range(10):
142
                s.add(i)
143
            self.assertEqual(s.count, 0)
144

145
            unregister_event_handler(handle)
146

147
        mul = event_multiplexer.EventMultiplexer()
148
        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
149
        mul.Reload()
150
        scalar_dict = mul.PluginRunToTagToContent("scalars")
151
        raw_result = {
152
            tag: mul.Tensors(run, tag)
153
            for run, run_dict in scalar_dict.items()
154
            for tag in run_dict
155
        }
156
        scalars = {
157
            tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
158
        }
159
        self.assertEqual(scalars, {
160
            "asdf.sum": [10],
161
            "asdf.count": [5],
162
        })
163

164

165
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.