pytorch

Форк
0
/
test_monitor.py 
160 строк · 4.5 Кб
1
# Owner(s): ["oncall: r2p"]
2

3
from torch.testing._internal.common_utils import (
4
    TestCase, run_tests, skipIfTorchDynamo,
5
)
6

7
from datetime import timedelta, datetime
8
import tempfile
9
import time
10

11
from torch.monitor import (
12
    Aggregation,
13
    Event,
14
    log_event,
15
    register_event_handler,
16
    unregister_event_handler,
17
    Stat,
18
    TensorboardEventHandler,
19
)
20

21
class TestMonitor(TestCase):
    """Tests for ``torch.monitor`` stats, events and event handlers."""

    def test_interval_stat(self) -> None:
        """A Stat built with a 1 ms timedelta delivers an event to a handler."""
        events = []

        def handler(event):
            events.append(event)

        handle = register_event_handler(handler)
        # Always unregister, even when an assertion fails, so the handler
        # does not leak into tests that run afterwards.
        try:
            s = Stat(
                "asdf",
                (Aggregation.SUM, Aggregation.COUNT),
                timedelta(milliseconds=1),
            )
            self.assertEqual(s.name, "asdf")

            s.add(2)
            for _ in range(100):
                # NOTE: sleep may be inaccurate on some platforms (e.g.
                # Windows), so poll in a loop instead of sleeping once.
                time.sleep(1 / 1000)  # ms
                s.add(3)
                if events:
                    break
            self.assertGreaterEqual(len(events), 1)
        finally:
            unregister_event_handler(handle)

    def test_fixed_count_stat(self) -> None:
        """Check ``count`` and ``get()`` of a Stat built with a trailing argument of 3."""
        s = Stat(
            "asdf",
            (Aggregation.SUM, Aggregation.COUNT),
            timedelta(hours=100),
            3,
        )
        s.add(1)
        s.add(2)
        name = s.name
        self.assertEqual(name, "asdf")
        self.assertEqual(s.count, 2)
        # After the third value the live count reads 0 while get() reports
        # the aggregates over all three values.
        s.add(3)
        self.assertEqual(s.count, 0)
        self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})

    def test_log_event(self) -> None:
        """An Event with str/float/int data can be built and passed to log_event."""
        e = Event(
            name="torch.monitor.TestEvent",
            timestamp=datetime.now(),
            data={
                "str": "a string",
                "float": 1234.0,
                "int": 1234,
            },
        )
        self.assertEqual(e.name, "torch.monitor.TestEvent")
        self.assertIsNotNone(e.timestamp)
        self.assertIsNotNone(e.data)
        log_event(e)

    @skipIfTorchDynamo("Really weird error")
    def test_event_handler(self) -> None:
        """A handler receives logged events only while it is registered."""
        events = []

        def handler(event: Event) -> None:
            events.append(event)

        e = Event(
            name="torch.monitor.TestEvent",
            timestamp=datetime.now(),
            data={},
        )

        handle = register_event_handler(handler)
        # Always unregister, even when an assertion fails, so the handler
        # does not leak into tests that run afterwards.
        try:
            log_event(e)
            self.assertEqual(len(events), 1)
            self.assertEqual(events[0], e)
            log_event(e)
            self.assertEqual(len(events), 2)
        finally:
            unregister_event_handler(handle)

        # Once unregistered, further events must not reach the handler.
        log_event(e)
        self.assertEqual(len(events), 2)
100

101

102
@skipIfTorchDynamo("Really weird error")
class TestMonitorTensorboard(TestCase):
    """Tests for ``TensorboardEventHandler``; skipped when TensorBoard is missing."""

    def setUp(self):
        """Import the TensorBoard helpers lazily, skipping if they are unavailable."""
        # The imported names are published as module globals because
        # create_summary_writer and the test body refer to them by bare name.
        global SummaryWriter, event_multiplexer
        try:
            from torch.utils.tensorboard import SummaryWriter
            from tensorboard.backend.event_processing import (
                plugin_event_multiplexer as event_multiplexer,
            )
        except ImportError:
            # skipTest raises, so nothing below runs when the import fails.
            self.skipTest("Skip the test since TensorBoard is not installed")
        self.temp_dirs = []

    def create_summary_writer(self):
        """Return a SummaryWriter on a fresh temp directory tracked in ``self.temp_dirs``."""
        temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
        self.temp_dirs.append(temp_dir)
        return SummaryWriter(temp_dir.name)

    def tearDown(self):
        """Clean up every temporary directory created by create_summary_writer."""
        # Remove the temporary directories handed to SummaryWriter
        for temp_dir in self.temp_dirs:
            temp_dir.cleanup()

    def test_event_handler(self):
        """Stat events routed through TensorboardEventHandler show up as scalars."""
        with self.create_summary_writer() as w:
            handle = register_event_handler(TensorboardEventHandler(w))
            # Always unregister, even when an assertion fails, so the handler
            # does not leak into tests that run afterwards.
            try:
                s = Stat(
                    "asdf",
                    (Aggregation.SUM, Aggregation.COUNT),
                    timedelta(hours=1),
                    5,
                )
                for i in range(10):
                    s.add(i)
                self.assertEqual(s.count, 0)
            finally:
                unregister_event_handler(handle)

        # Read back what the writer stored in the most recent temp directory.
        mul = event_multiplexer.EventMultiplexer()
        mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
        mul.Reload()
        scalar_dict = mul.PluginRunToTagToContent("scalars")
        raw_result = {
            tag: mul.Tensors(run, tag)
            for run, run_dict in scalar_dict.items()
            for tag in run_dict
        }
        scalars = {
            tag: [e.tensor_proto.float_val[0] for e in events]
            for tag, events in raw_result.items()
        }
        self.assertEqual(scalars, {
            "asdf.sum": [10],
            "asdf.count": [5],
        })
157

158

159
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.