6
from datetime import datetime, timedelta
8
from torch.monitor import (
12
register_event_handler,
14
TensorboardEventHandler,
15
unregister_event_handler,
18
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
20
class TestMonitor(TestCase):
21
def test_interval_stat(self) -> None:
27
handle = register_event_handler(handler)
30
(Aggregation.SUM, Aggregation.COUNT),
31
timedelta(milliseconds=1),
33
self.assertEqual(s.name, "asdf")
43
self.assertGreaterEqual(len(events), 1)
44
unregister_event_handler(handle)
46
def test_fixed_count_stat(self) -> None:
49
(Aggregation.SUM, Aggregation.COUNT),
56
self.assertEqual(name, "asdf")
57
self.assertEqual(s.count, 2)
59
self.assertEqual(s.count, 0)
60
self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3})
62
def test_log_event(self) -> None:
64
name="torch.monitor.TestEvent",
65
timestamp=datetime.now(),
72
self.assertEqual(e.name, "torch.monitor.TestEvent")
73
self.assertIsNotNone(e.timestamp)
74
self.assertIsNotNone(e.data)
77
@skipIfTorchDynamo("Really weird error")
78
def test_event_handler(self) -> None:
81
def handler(event: Event) -> None:
84
handle = register_event_handler(handler)
86
name="torch.monitor.TestEvent",
87
timestamp=datetime.now(),
91
self.assertEqual(len(events), 1)
92
self.assertEqual(events[0], e)
94
self.assertEqual(len(events), 2)
96
unregister_event_handler(handle)
98
self.assertEqual(len(events), 2)
100
def test_wait_counter(self) -> None:
101
wait_counter = _WaitCounter(
104
with wait_counter.guard() as wcg:
108
@skipIfTorchDynamo("Really weird error")
109
class TestMonitorTensorboard(TestCase):
111
global SummaryWriter, event_multiplexer
113
from torch.utils.tensorboard import SummaryWriter
114
from tensorboard.backend.event_processing import (
115
plugin_event_multiplexer as event_multiplexer,
118
return self.skipTest("Skip the test since TensorBoard is not installed")
121
def create_summary_writer(self):
122
temp_dir = tempfile.TemporaryDirectory()
123
self.temp_dirs.append(temp_dir)
124
return SummaryWriter(temp_dir.name)
128
for temp_dir in self.temp_dirs:
131
def test_event_handler(self):
132
with self.create_summary_writer() as w:
133
handle = register_event_handler(TensorboardEventHandler(w))
137
(Aggregation.SUM, Aggregation.COUNT),
143
self.assertEqual(s.count, 0)
145
unregister_event_handler(handle)
147
mul = event_multiplexer.EventMultiplexer()
148
mul.AddRunsFromDirectory(self.temp_dirs[-1].name)
150
scalar_dict = mul.PluginRunToTagToContent("scalars")
152
tag: mul.Tensors(run, tag)
153
for run, run_dict in scalar_dict.items()
157
tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items()
159
self.assertEqual(scalars, {
165
if __name__ == '__main__':