1
# Owner(s): ["oncall: r2p"]
3
from torch.testing._internal.common_utils import (
4
TestCase, run_tests, skipIfTorchDynamo,
7
from datetime import timedelta, datetime
11
from torch.monitor import (
15
register_event_handler,
16
unregister_event_handler,
18
TensorboardEventHandler,
21
class TestMonitor(TestCase):
22
def test_interval_stat(self) -> None:
28
handle = register_event_handler(handler)
31
(Aggregation.SUM, Aggregation.COUNT),
32
timedelta(milliseconds=1),
34
self.assertEqual(s.name, "asdf")
38
# NOTE: different platforms sleep may be inaccurate so we loop
40
time.sleep(1 / 1000) # ms
44
self.assertGreaterEqual(len(events), 1)
45
unregister_event_handler(handle)
47
def test_fixed_count_stat(self) -> None:
50
(Aggregation.SUM, Aggregation.COUNT),
57
self.assertEqual(name, "asdf")
58
self.assertEqual(s.count, 2)
60
self.assertEqual(s.count, 0)
61
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
63
def test_log_event(self) -> None:
65
name="torch.monitor.TestEvent",
66
timestamp=datetime.now(),
73
self.assertEqual(e.name, "torch.monitor.TestEvent")
74
self.assertIsNotNone(e.timestamp)
75
self.assertIsNotNone(e.data)
78
@skipIfTorchDynamo("Really weird error")
79
def test_event_handler(self) -> None:
82
def handler(event: Event) -> None:
85
handle = register_event_handler(handler)
87
name="torch.monitor.TestEvent",
88
timestamp=datetime.now(),
92
self.assertEqual(len(events), 1)
93
self.assertEqual(events[0], e)
95
self.assertEqual(len(events), 2)
97
unregister_event_handler(handle)
99
self.assertEqual(len(events), 2)
102
@skipIfTorchDynamo("Really weird error")
103
class TestMonitorTensorboard(TestCase):
105
global SummaryWriter, event_multiplexer
107
from torch.utils.tensorboard import SummaryWriter
108
from tensorboard.backend.event_processing import (
109
plugin_event_multiplexer as event_multiplexer,
112
return self.skipTest("Skip the test since TensorBoard is not installed")
115
def create_summary_writer(self):
116
temp_dir = tempfile.TemporaryDirectory() # noqa: P201
117
self.temp_dirs.append(temp_dir)
118
return SummaryWriter(temp_dir.name)
121
# Remove directories created by SummaryWriter
122
for temp_dir in self.temp_dirs:
125
def test_event_handler(self):
126
with self.create_summary_writer() as w:
127
handle = register_event_handler(TensorboardEventHandler(w))
131
(Aggregation.SUM, Aggregation.COUNT),
137
self.assertEqual(s.count, 0)
139
unregister_event_handler(handle)
141
mul = event_multiplexer.EventMultiplexer()
142
mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
144
scalar_dict = mul.PluginRunToTagToContent("scalars")
146
tag: mul.Tensors(run, tag)
147
for run, run_dict in scalar_dict.items()
151
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
153
self.assertEqual(scalars, {
159
if __name__ == '__main__':