9
from pathlib import Path
15
TEST_TENSORBOARD = True
17
import tensorboard.summary.writer.event_file_writer
18
from tensorboard.compat.proto.summary_pb2 import Summary
20
TEST_TENSORBOARD = False
26
HAS_TORCHVISION = False
27
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
32
if os.environ.get('DISPLAY', '') == '':
34
import matplotlib.pyplot as plt
36
TEST_MATPLOTLIB = False
37
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")
40
from torch.testing._internal.common_utils import (
41
instantiate_parametrized_tests,
52
def tensor_N(shape, dtype=float):
    """Return ``0, 1, ..., prod(shape) - 1`` as an array reshaped to *shape*.

    Args:
        shape: target shape of the returned array.
        dtype: element dtype handed to ``np.arange`` (defaults to ``float``).

    Returns:
        ``np.ndarray`` of the given shape whose flattened values count up from 0.
    """
    numel = np.prod(shape)
    x = np.arange(numel, dtype=dtype).reshape(shape)
    # Callers pass the result straight into summary.image/video/audio/histogram,
    # so the deterministic array must be handed back.
    return x
57
class BaseTestCase(TestCase):
58
""" Base class used for all TensorBoard tests """
61
if not TEST_TENSORBOARD:
62
return self.skipTest("Skip the test since TensorBoard is not installed")
63
if TEST_WITH_CROSSREF:
64
return self.skipTest("Don't run TensorBoard tests with crossref")
67
def createSummaryWriter(self):
70
temp_dir = tempfile.TemporaryDirectory(prefix="test_tensorboard").name
71
self.temp_dirs.append(temp_dir)
72
return SummaryWriter(temp_dir)
77
for temp_dir in self.temp_dirs:
78
if os.path.exists(temp_dir):
79
shutil.rmtree(temp_dir)
83
from google.protobuf import text_format
85
from tensorboard.compat.proto.graph_pb2 import GraphDef
86
from tensorboard.compat.proto.types_pb2 import DataType
88
from torch.utils.tensorboard import summary, SummaryWriter
89
from torch.utils.tensorboard._convert_np import make_np
90
from torch.utils.tensorboard._pytorch_graph import graph
91
from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
92
from torch.utils.tensorboard.summary import int_to_half, tensor_proto
94
class TestTensorBoardPyTorchNumpy(BaseTestCase):
95
def test_pytorch_np(self):
96
tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)]
97
for tensor in tensors:
99
self.assertIsInstance(make_np(tensor), np.ndarray)
102
if torch.cuda.is_available():
103
self.assertIsInstance(make_np(tensor.cuda()), np.ndarray)
106
self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray)
109
if torch.cuda.is_available():
110
self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray)
113
self.assertIsInstance(make_np(0), np.ndarray)
114
self.assertIsInstance(make_np(0.1), np.ndarray)
116
def test_pytorch_autograd_np(self):
117
x = torch.autograd.Variable(torch.empty(1))
118
self.assertIsInstance(make_np(x), np.ndarray)
120
def test_pytorch_write(self):
121
with self.createSummaryWriter() as w:
122
w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0)
124
def test_pytorch_histogram(self):
125
with self.createSummaryWriter() as w:
126
w.add_histogram('float histogram', torch.rand((50,)))
127
w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
128
w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))
130
def test_pytorch_histogram_raw(self):
131
with self.createSummaryWriter() as w:
133
floats = make_np(torch.rand((num,)))
134
bins = [0.0, 0.25, 0.5, 0.75, 1.0]
135
counts, limits = np.histogram(floats, bins)
136
sum_sq = floats.dot(floats).item()
137
w.add_histogram_raw('float histogram raw',
138
min=floats.min().item(),
139
max=floats.max().item(),
141
sum=floats.sum().item(),
143
bucket_limits=limits[1:].tolist(),
144
bucket_counts=counts.tolist())
146
ints = make_np(torch.randint(0, 100, (num,)))
147
bins = [0, 25, 50, 75, 100]
148
counts, limits = np.histogram(ints, bins)
149
sum_sq = ints.dot(ints).item()
150
w.add_histogram_raw('int histogram raw',
151
min=ints.min().item(),
152
max=ints.max().item(),
154
sum=ints.sum().item(),
156
bucket_limits=limits[1:].tolist(),
157
bucket_counts=counts.tolist())
159
ints = torch.tensor(range(0, 100)).float()
161
counts = torch.histc(ints, bins=nbins, min=0, max=99)
162
limits = torch.tensor(range(nbins))
163
sum_sq = ints.dot(ints).item()
164
w.add_histogram_raw('int histogram raw',
165
min=ints.min().item(),
166
max=ints.max().item(),
168
sum=ints.sum().item(),
170
bucket_limits=limits.tolist(),
171
bucket_counts=counts.tolist())
173
class TestTensorBoardUtils(BaseTestCase):
174
def test_to_HWC(self):
175
test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
176
converted = convert_to_HWC(test_image, 'chw')
177
self.assertEqual(converted.shape, (32, 32, 3))
178
test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8)
179
converted = convert_to_HWC(test_image, 'nchw')
180
self.assertEqual(converted.shape, (64, 256, 3))
181
test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8)
182
converted = convert_to_HWC(test_image, 'hw')
183
self.assertEqual(converted.shape, (32, 32, 3))
185
def test_convert_to_HWC_dtype_remains_same(self):
188
test_image = torch.tensor([[[[1, 2, 3], [4, 5, 6]]]], dtype=torch.uint8)
189
tensor = make_np(test_image)
190
tensor = convert_to_HWC(tensor, 'NCHW')
191
scale_factor = summary._calc_scale_factor(tensor)
192
self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')
195
def test_prepare_video(self):
205
V_input = np.random.random(s)
206
V_after = _prepare_video(np.copy(V_input))
208
V_input = np.swapaxes(V_input, 0, 1)
209
for f in range(total_frame):
210
x = np.reshape(V_input[f], newshape=(-1))
211
y = np.reshape(V_after[f], newshape=(-1))
212
np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))
214
def test_numpy_vid_uint8(self):
215
V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8)
216
V_after = _prepare_video(np.copy(V_input)) * 255
217
total_frame = V_input.shape[1]
218
V_input = np.swapaxes(V_input, 0, 1)
219
for f in range(total_frame):
220
x = np.reshape(V_input[f], newshape=(-1))
221
y = np.reshape(V_after[f], newshape=(-1))
222
np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))
224
# Module-level fixtures shared by the writer tests.

# Presumably tone frequencies used when building test_writer's add_audio
# signal (that code is only partly visible here) -- TODO confirm.
freqs = [
    262, 294, 330, 349, 392,
    440, 440, 440, 440, 440, 440,
]

# Raw PR-curve statistics with five entries each; the four count lists are
# passed to add_pr_curve_raw in TestTensorBoardWriter.test_writer.
# At every index TP + FN == 75 (positives) and FP + TN == 150 (negatives).
true_positive_counts = [75, 64, 21, 5, 0]
false_positive_counts = [150, 105, 18, 0, 0]
true_negative_counts = [0, 45, 132, 150, 150]
false_negative_counts = [0, 11, 54, 70, 75]

# Approximately TP / (TP + FP) (0.0 where TP + FP == 0) and TP / (TP + FN)
# computed from the counts above.
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]
233
class TestTensorBoardWriter(BaseTestCase):
234
def test_writer(self):
235
with self.createSummaryWriter() as writer:
240
{'lr': 0.1, 'bsize': 1},
241
{'hparam/accuracy': 10, 'hparam/loss': 10}
243
writer.add_scalar('data/scalar_systemtime', 0.1, n_iter)
244
writer.add_scalar('data/scalar_customtime', 0.2, n_iter, walltime=n_iter)
245
writer.add_scalar('data/new_style', 0.2, n_iter, new_style=True)
246
writer.add_scalars('data/scalar_group', {
247
"xsinx": n_iter * np.sin(n_iter),
248
"xcosx": n_iter * np.cos(n_iter),
249
"arctanx": np.arctan(n_iter)
251
x = np.zeros((32, 3, 64, 64))
252
writer.add_images('Image', x, n_iter)
253
writer.add_image_with_boxes('imagebox',
254
np.zeros((3, 64, 64)),
255
np.array([[10, 10, 40, 40], [40, 40, 60, 60]]),
257
x = np.zeros(sample_rate * 2)
259
writer.add_audio('myAudio', x, n_iter)
260
writer.add_video('myVideo', np.random.rand(16, 48, 1, 28, 28).astype(np.float32), n_iter)
261
writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)
262
writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter)
263
writer.add_histogram('hist', np.random.rand(100, 100), n_iter)
264
writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand(
266
writer.add_pr_curve_raw('prcurve with raw data', true_positive_counts,
267
false_positive_counts,
268
true_negative_counts,
269
false_negative_counts,
273
v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
274
c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
275
f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
276
writer.add_mesh('my_mesh', vertices=v, colors=c, faces=f)
278
class TestTensorBoardSummaryWriter(BaseTestCase):
279
def test_summary_writer_ctx(self):
281
with self.createSummaryWriter() as writer:
282
writer.add_scalar('test', 1)
283
self.assertIs(writer.file_writer, None)
285
def test_summary_writer_close(self):
290
writer = self.createSummaryWriter()
295
self.assertTrue(passed)
297
def test_pathlib(self):
298
with tempfile.TemporaryDirectory(prefix="test_tensorboard_pathlib") as d:
300
with SummaryWriter(p) as writer:
301
writer.add_scalar('test', 1)
303
class TestTensorBoardEmbedding(BaseTestCase):
304
def test_embedding(self):
305
w = self.createSummaryWriter()
306
all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
307
all_labels = torch.tensor([33., 44., 55.])
308
all_images = torch.zeros(3, 3, 5, 5)
310
w.add_embedding(all_features,
312
label_img=all_images,
315
dataset_label = ['test'] * 2 + ['train'] * 2
316
all_labels = list(zip(all_labels, dataset_label))
317
w.add_embedding(all_features,
319
label_img=all_images,
320
metadata_header=['digit', 'dataset'],
324
def test_embedding_64(self):
325
w = self.createSummaryWriter()
326
all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
327
all_labels = torch.tensor([33., 44., 55.])
328
all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64)
330
w.add_embedding(all_features,
332
label_img=all_images,
335
dataset_label = ['test'] * 2 + ['train'] * 2
336
all_labels = list(zip(all_labels, dataset_label))
337
w.add_embedding(all_features,
339
label_img=all_images,
340
metadata_header=['digit', 'dataset'],
343
class TestTensorBoardSummary(BaseTestCase):
344
def test_uint8_image(self):
346
Tests that uint8 image (pixel values in [0, 255]) is not changed
348
test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
349
scale_factor = summary._calc_scale_factor(test_image)
350
self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')
352
def test_float32_image(self):
354
Tests that float32 image (pixel values in [0, 1]) are scaled correctly
357
test_image = np.random.rand(3, 32, 32).astype(np.float32)
358
scale_factor = summary._calc_scale_factor(test_image)
359
self.assertEqual(scale_factor, 255, msg='Values are in [0, 1], scale factor should be 255')
361
def test_list_input(self):
362
with self.assertRaises(Exception) as e_info:
363
summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow')
365
def test_empty_input(self):
366
with self.assertRaises(Exception) as e_info:
367
summary.histogram('dummy', np.ndarray(0), 'tensorflow')
369
def test_image_with_boxes(self):
370
self.assertTrue(compare_image_proto(summary.image_boxes('dummy',
371
tensor_N(shape=(3, 32, 32)),
372
np.array([[10, 10, 40, 40]])),
375
def test_image_with_one_channel(self):
376
self.assertTrue(compare_image_proto(
377
summary.image('dummy',
378
tensor_N(shape=(1, 8, 8)),
382
def test_image_with_one_channel_batched(self):
383
self.assertTrue(compare_image_proto(
384
summary.image('dummy',
385
tensor_N(shape=(2, 1, 8, 8)),
389
def test_image_with_3_channel_batched(self):
390
self.assertTrue(compare_image_proto(
391
summary.image('dummy',
392
tensor_N(shape=(2, 3, 8, 8)),
396
def test_image_without_channel(self):
397
self.assertTrue(compare_image_proto(
398
summary.image('dummy',
399
tensor_N(shape=(8, 8)),
403
def test_video(self):
408
self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self))
409
summary.video('dummy', np.random.rand(16, 48, 1, 28, 28))
410
summary.video('dummy', np.random.rand(20, 7, 1, 8, 8))
412
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
413
def test_audio(self):
414
self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self))
416
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
418
self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self))
420
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
421
def test_histogram_auto(self):
422
self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self))
424
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
425
def test_histogram_fd(self):
426
self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self))
428
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
429
def test_histogram_doane(self):
430
self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self))
432
def test_custom_scalars(self):
435
'twse': ['Multiline', ['twse/0050', 'twse/2330']]
438
'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
439
'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]
442
summary.custom_scalars(layout)
445
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
447
v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
448
c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
449
f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
450
mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None)
451
self.assertTrue(compare_proto(mesh, self))
453
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
454
def test_scalar_new_style(self):
455
scalar = summary.scalar('test_scalar', 1.0, new_style=True)
456
self.assertTrue(compare_proto(scalar, self))
457
with self.assertRaises(AssertionError):
458
summary.scalar('test_scalar2', torch.Tensor([1, 2, 3]), new_style=True)
461
def remove_whitespace(string):
    """Return *string* with every space, tab and newline character stripped out."""
    for unwanted in (' ', '\t', '\n'):
        string = string.replace(unwanted, '')
    return string
464
# Build the path of the ".expect" golden file for a test.
# function_ptr: the TestCase instance (callers in this file pass `self`); its
# class's module locates the test source file, and the last dotted component of
# function_ptr.id() is used as the test method name in the file name.
def get_expected_file(function_ptr):
465
module_id = function_ptr.__class__.__module__
466
test_file = sys.modules[module_id].__file__
468
# Replace the final extension with ".py" (presumably because __file__ may point
# at a compiled file) -- TODO confirm.
test_file = ".".join(test_file.split('.')[:-1]) + '.py'
471
# realpath resolves symlinks so the directory of the real test file is used.
test_dir = os.path.dirname(os.path.realpath(test_file))
472
functionName = function_ptr.id().split('.')[-1]
473
# NOTE(review): the embedded original line numbers jump from 473 to 475 inside
# this os.path.join(...) call, so an argument (e.g. a subdirectory) may be
# missing from this view -- verify against the real file.
return os.path.join(test_dir,
475
'TestTensorBoard.' + functionName + ".expect")
477
def read_expected_content(function_ptr):
    """Return the text of the golden ``.expect`` file belonging to *function_ptr*.

    Args:
        function_ptr: the running TestCase instance (used to locate the file).

    Returns:
        The full contents of the expected file as a string.

    Raises:
        AssertionError: if the expected file does not exist (message is its path).
    """
    expected_file = get_expected_file(function_ptr)
    assert os.path.exists(expected_file), expected_file
    # Callers parse or compare the returned text, so the contents must be returned.
    with open(expected_file) as f:
        return f.read()
483
def compare_image_proto(actual_proto, function_ptr):
    """Compare an image ``Summary`` proto with the test's golden ``.expect`` file.

    When ``expecttest.ACCEPT`` is set, the golden file is first (re)written from
    *actual_proto*, so the comparison below then succeeds against it.

    Args:
        actual_proto: ``Summary`` proto produced by the code under test.
        function_ptr: the running TestCase instance (used to locate the file).

    Returns:
        True if the first summary value matches the golden one in tag, height,
        width, colorspace and decoded image content; False otherwise.
    """
    if expecttest.ACCEPT:
        expected_file = get_expected_file(function_ptr)
        with open(expected_file, 'w') as f:
            f.write(text_format.MessageToString(actual_proto))

    expected_str = read_expected_content(function_ptr)
    expected_proto = Summary()
    text_format.Parse(expected_str, expected_proto)

    actual, expected = actual_proto.value[0], expected_proto.value[0]
    # Decode both encoded images so pixel content is compared, not raw bytes.
    actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string))
    expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string))

    # The result feeds assertTrue in the callers, so it must be returned.
    return (
        actual.tag == expected.tag and
        actual.image.height == expected.image.height and
        actual.image.width == expected.image.width and
        actual.image.colorspace == expected.image.colorspace and
        actual_img == expected_img
    )
505
def compare_proto(str_to_compare, function_ptr):
    """Check ``str(str_to_compare)`` against the test's golden ``.expect`` file.

    Spaces, tabs and newlines are ignored on both sides. When
    ``expecttest.ACCEPT`` is set, the golden file is overwritten with the
    actual value before comparing.

    Returns:
        True when the whitespace-stripped texts are equal.
    """
    if expecttest.ACCEPT:
        write_proto(str_to_compare, function_ptr)
    golden_text = remove_whitespace(read_expected_content(function_ptr))
    actual_text = remove_whitespace(str(str_to_compare))
    return actual_text == golden_text
513
def write_proto(str_to_compare, function_ptr):
    """Overwrite the golden ``.expect`` file of *function_ptr* with ``str(str_to_compare)``."""
    golden_path = get_expected_file(function_ptr)
    with open(golden_path, 'w') as golden:
        golden.write(str(str_to_compare))
518
class TestTensorBoardPytorchGraph(BaseTestCase):
519
def test_pytorch_graph(self):
520
dummy_input = (torch.zeros(1, 3),)
522
class myLinear(torch.nn.Module):
523
def __init__(self) -> None:
525
self.l = torch.nn.Linear(3, 5)
527
def forward(self, x):
530
with self.createSummaryWriter() as w:
531
w.add_graph(myLinear(), dummy_input)
533
actual_proto, _ = graph(myLinear(), dummy_input)
535
expected_str = read_expected_content(self)
536
expected_proto = GraphDef()
537
text_format.Parse(expected_str, expected_proto)
539
self.assertEqual(len(expected_proto.node), len(actual_proto.node))
540
for i in range(len(expected_proto.node)):
541
expected_node = expected_proto.node[i]
542
actual_node = actual_proto.node[i]
543
self.assertEqual(expected_node.name, actual_node.name)
544
self.assertEqual(expected_node.op, actual_node.op)
545
self.assertEqual(expected_node.input, actual_node.input)
546
self.assertEqual(expected_node.device, actual_node.device)
548
sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
550
def test_nested_nn_squential(self):
552
dummy_input = torch.randn(2, 3)
554
class InnerNNSquential(torch.nn.Module):
555
def __init__(self, dim1, dim2):
557
self.inner_nn_squential = torch.nn.Sequential(
558
torch.nn.Linear(dim1, dim2),
559
torch.nn.Linear(dim2, dim1),
562
def forward(self, x):
563
x = self.inner_nn_squential(x)
566
class OuterNNSquential(torch.nn.Module):
567
def __init__(self, dim1=3, dim2=4, depth=2):
570
for _ in range(depth):
571
layers.append(InnerNNSquential(dim1, dim2))
572
self.outer_nn_squential = torch.nn.Sequential(*layers)
574
def forward(self, x):
575
x = self.outer_nn_squential(x)
578
with self.createSummaryWriter() as w:
579
w.add_graph(OuterNNSquential(), dummy_input)
581
actual_proto, _ = graph(OuterNNSquential(), dummy_input)
583
expected_str = read_expected_content(self)
584
expected_proto = GraphDef()
585
text_format.Parse(expected_str, expected_proto)
587
self.assertEqual(len(expected_proto.node), len(actual_proto.node))
588
for i in range(len(expected_proto.node)):
589
expected_node = expected_proto.node[i]
590
actual_node = actual_proto.node[i]
591
self.assertEqual(expected_node.name, actual_node.name)
592
self.assertEqual(expected_node.op, actual_node.op)
593
self.assertEqual(expected_node.input, actual_node.input)
594
self.assertEqual(expected_node.device, actual_node.device)
596
sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
598
def test_pytorch_graph_dict_input(self):
599
class Model(torch.nn.Module):
600
def __init__(self) -> None:
602
self.l = torch.nn.Linear(3, 5)
604
def forward(self, x):
607
class ModelDict(torch.nn.Module):
608
def __init__(self) -> None:
610
self.l = torch.nn.Linear(3, 5)
612
def forward(self, x):
613
return {"out": self.l(x)}
616
dummy_input = torch.zeros(1, 3)
618
with self.createSummaryWriter() as w:
619
w.add_graph(Model(), dummy_input)
621
with self.createSummaryWriter() as w:
622
w.add_graph(Model(), dummy_input, use_strict_trace=True)
625
with self.assertRaises(RuntimeError):
626
with self.createSummaryWriter() as w:
627
w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)
629
with self.createSummaryWriter() as w:
630
w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)
633
def test_mlp_graph(self):
634
dummy_input = (torch.zeros(2, 1, 28, 28),)
642
class myMLP(torch.nn.Module):
643
def __init__(self) -> None:
645
self.input_len = 1 * 28 * 28
646
self.fc1 = torch.nn.Linear(self.input_len, 1200)
647
self.fc2 = torch.nn.Linear(1200, 1200)
648
self.fc3 = torch.nn.Linear(1200, 10)
650
def forward(self, x, update_batch_stats=True):
651
h = torch.nn.functional.relu(
652
self.fc1(x.view(-1, self.input_len)))
654
h = torch.nn.functional.relu(h)
658
with self.createSummaryWriter() as w:
659
w.add_graph(myMLP(), dummy_input)
661
def test_wrong_input_size(self):
662
with self.assertRaises(RuntimeError) as e_info:
663
dummy_input = torch.rand(1, 9)
664
model = torch.nn.Linear(3, 5)
665
with self.createSummaryWriter() as w:
666
w.add_graph(model, dummy_input)
669
def test_torchvision_smoke(self):
670
model_input_shapes = {
671
'alexnet': (2, 3, 224, 224),
672
'resnet34': (2, 3, 224, 224),
673
'resnet152': (2, 3, 224, 224),
674
'densenet121': (2, 3, 224, 224),
675
'vgg16': (2, 3, 224, 224),
676
'vgg19': (2, 3, 224, 224),
677
'vgg16_bn': (2, 3, 224, 224),
678
'vgg19_bn': (2, 3, 224, 224),
679
'mobilenet_v2': (2, 3, 224, 224),
681
for model_name, input_shape in model_input_shapes.items():
682
with self.createSummaryWriter() as w:
683
model = getattr(torchvision.models, model_name)()
684
w.add_graph(model, torch.zeros(input_shape))
686
class TestTensorBoardFigure(BaseTestCase):
688
def test_figure(self):
689
writer = self.createSummaryWriter()
691
figure, axes = plt.figure(), plt.gca()
692
circle1 = plt.Circle((0.2, 0.5), 0.2, color='r')
693
circle2 = plt.Circle((0.8, 0.5), 0.2, color='g')
694
axes.add_patch(circle1)
695
axes.add_patch(circle2)
699
writer.add_figure("add_figure/figure", figure, 0, close=False)
700
self.assertTrue(plt.fignum_exists(figure.number))
702
writer.add_figure("add_figure/figure", figure, 1)
703
if matplotlib.__version__ != '3.3.0':
704
self.assertFalse(plt.fignum_exists(figure.number))
706
print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")
711
def test_figure_list(self):
712
writer = self.createSummaryWriter()
716
figure = plt.figure()
717
plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i))
722
figures.append(figure)
724
writer.add_figure("add_figure/figure_list", figures, 0, close=False)
725
self.assertTrue(all(plt.fignum_exists(figure.number) is True for figure in figures))
727
writer.add_figure("add_figure/figure_list", figures, 1)
728
if matplotlib.__version__ != '3.3.0':
729
self.assertTrue(all(plt.fignum_exists(figure.number) is False for figure in figures))
731
print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")
735
class TestTensorBoardNumpy(BaseTestCase):
736
@unittest.skipIf(IS_WINDOWS, "Skipping on windows, see https://github.com/pytorch/pytorch/pull/109349 ")
737
@unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
738
def test_scalar(self):
740
self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
741
res = make_np(1 << 64 - 1)
742
self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
743
res = make_np(np.float16(1.00000087))
744
self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
745
res = make_np(np.float128(1.00008 + 9))
746
self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
747
res = make_np(np.int64(100000000000))
748
self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,))
750
def test_pytorch_np_expect_fail(self):
751
with self.assertRaises(NotImplementedError):
752
res = make_np({'pytorch': 1.0})
756
class TestTensorProtoSummary(BaseTestCase):
758
"tensor_type,proto_type",
760
(torch.float16, DataType.DT_HALF),
761
(torch.bfloat16, DataType.DT_BFLOAT16),
764
@skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version")
765
def test_half_tensor_proto(self, tensor_type, proto_type):
766
float_values = [1.0, 2.0, 3.0]
767
actual_proto = tensor_proto(
769
torch.tensor(float_values, dtype=tensor_type),
771
self.assertSequenceEqual(
772
[int_to_half(x) for x in actual_proto.half_val],
775
self.assertTrue(actual_proto.dtype == proto_type)
777
def test_float_tensor_proto(self):
778
float_values = [1.0, 2.0, 3.0]
780
tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
782
self.assertEqual(actual_proto.float_val, float_values)
783
self.assertTrue(actual_proto.dtype == DataType.DT_FLOAT)
785
def test_int_tensor_proto(self):
786
int_values = [1, 2, 3]
788
tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
792
self.assertEqual(actual_proto.int_val, int_values)
793
self.assertTrue(actual_proto.dtype == DataType.DT_INT32)
795
def test_scalar_tensor_proto(self):
798
tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
800
self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)
802
def test_complex_tensor_proto(self):
803
real = torch.tensor([1.0, 2.0])
804
imag = torch.tensor([3.0, 4.0])
806
tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
808
self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])
810
def test_empty_tensor_proto(self):
811
actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
812
self.assertEqual(actual_proto.float_val, [])
814
# Presumably expands the parametrized methods of TestTensorProtoSummary (e.g.
# test_half_tensor_proto, which takes tensor_type/proto_type) into concrete
# test methods -- see torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestTensorProtoSummary)
816
if __name__ == '__main__':