pytorch

Форк
0
/
test_tensorboard.py 
817 строк · 32.3 Кб
1
# Owner(s): ["module: unknown"]
2

3
import io
4
import os
5
import shutil
6
import sys
7
import tempfile
8
import unittest
9
from pathlib import Path
10

11
import expecttest
12
import numpy as np
13

14

15
# Optional dependencies. Each probe records whether a package is importable so
# tests can skip (via the flags / skip decorators below) instead of erroring.
TEST_TENSORBOARD = True
try:
    # Imported only to check that tensorboard is installed.
    import tensorboard.summary.writer.event_file_writer  # noqa: F401
    # Summary is used to parse expected image protos in compare_image_proto.
    from tensorboard.compat.proto.summary_pb2 import Summary
except ImportError:
    TEST_TENSORBOARD = False

HAS_TORCHVISION = True
try:
    import torchvision
except ImportError:
    HAS_TORCHVISION = False
# Decorator for tests that need torchvision models.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

TEST_MATPLOTLIB = True
try:
    import matplotlib
    # With no DISPLAY set, select the non-interactive 'Agg' backend; this is
    # done before matplotlib.pyplot is imported on the next line.
    if os.environ.get('DISPLAY', '') == '':
        matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except ImportError:
    TEST_MATPLOTLIB = False
# Decorator for tests that create matplotlib figures.
skipIfNoMatplotlib = unittest.skipIf(not TEST_MATPLOTLIB, "no matplotlib")
38

39
import torch
40
from torch.testing._internal.common_utils import (
41
    instantiate_parametrized_tests,
42
    IS_MACOS,
43
    IS_WINDOWS,
44
    parametrize,
45
    run_tests,
46
    TEST_WITH_CROSSREF,
47
    TestCase,
48
    skipIfTorchDynamo,
49
)
50

51

52
def tensor_N(shape, dtype=float):
    """Build a deterministic array of ``shape`` holding 0, 1, 2, ... in C order.

    Args:
        shape: shape of the returned array.
        dtype: element type forwarded to ``np.arange`` (Python ``float`` by default).
    """
    element_count = np.prod(shape)
    return np.arange(element_count, dtype=dtype).reshape(shape)
56

57
class BaseTestCase(TestCase):
    """ Base class used for all TensorBoard tests """

    def setUp(self):
        super().setUp()
        # skipTest() raises unittest.SkipTest, so nothing below runs when skipping.
        if not TEST_TENSORBOARD:
            self.skipTest("Skip the test since TensorBoard is not installed")
        if TEST_WITH_CROSSREF:
            self.skipTest("Don't run TensorBoard tests with crossref")
        # Log directories handed out by createSummaryWriter(); removed in tearDown().
        self.temp_dirs = []

    def createSummaryWriter(self):
        """Return a SummaryWriter logging into a fresh temporary directory.

        The directory is created with ``tempfile.mkdtemp`` so it exists and stays
        reserved until tearDown() removes it. (``TemporaryDirectory(...).name``
        would delete the directory as soon as the unreferenced object is
        finalized and emit a ResourceWarning.)
        """
        temp_dir = tempfile.mkdtemp(prefix="test_tensorboard")
        self.temp_dirs.append(temp_dir)
        return SummaryWriter(temp_dir)

    def tearDown(self):
        super().tearDown()
        # Remove directories created for SummaryWriter instances.
        for temp_dir in self.temp_dirs:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
80

81

82
if TEST_TENSORBOARD:
83
    from google.protobuf import text_format
84
    from PIL import Image
85
    from tensorboard.compat.proto.graph_pb2 import GraphDef
86
    from tensorboard.compat.proto.types_pb2 import DataType
87

88
    from torch.utils.tensorboard import summary, SummaryWriter
89
    from torch.utils.tensorboard._convert_np import make_np
90
    from torch.utils.tensorboard._pytorch_graph import graph
91
    from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC
92
    from torch.utils.tensorboard.summary import int_to_half, tensor_proto
93

94
class TestTensorBoardPyTorchNumpy(BaseTestCase):
    """Tests for tensor -> numpy conversion and tensor-valued summaries."""

    def test_pytorch_np(self):
        tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)]
        for tensor in tensors:
            # regular tensor
            self.assertIsInstance(make_np(tensor), np.ndarray)

            # CUDA tensor
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(tensor.cuda()), np.ndarray)

            # regular variable (legacy autograd API, kept on purpose)
            self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray)

            # CUDA variable
            if torch.cuda.is_available():
                self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray)

        # python primitive type
        self.assertIsInstance(make_np(0), np.ndarray)
        self.assertIsInstance(make_np(0.1), np.ndarray)

    def test_pytorch_autograd_np(self):
        x = torch.autograd.Variable(torch.empty(1))
        self.assertIsInstance(make_np(x), np.ndarray)

    def test_pytorch_write(self):
        with self.createSummaryWriter() as w:
            w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0)

    def test_pytorch_histogram(self):
        with self.createSummaryWriter() as w:
            w.add_histogram('float histogram', torch.rand((50,)))
            w.add_histogram('int histogram', torch.randint(0, 100, (50,)))
            w.add_histogram('bfloat16 histogram', torch.rand(50, dtype=torch.bfloat16))

    def test_pytorch_histogram_raw(self):
        with self.createSummaryWriter() as w:
            num = 50
            floats = make_np(torch.rand((num,)))
            bins = [0.0, 0.25, 0.5, 0.75, 1.0]
            counts, limits = np.histogram(floats, bins)
            sum_sq = floats.dot(floats).item()
            w.add_histogram_raw('float histogram raw',
                                min=floats.min().item(),
                                max=floats.max().item(),
                                num=num,
                                sum=floats.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            ints = make_np(torch.randint(0, 100, (num,)))
            bins = [0, 25, 50, 75, 100]
            counts, limits = np.histogram(ints, bins)
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                num=num,
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits[1:].tolist(),
                                bucket_counts=counts.tolist())

            ints = torch.tensor(range(0, 100)).float()
            nbins = 100
            counts = torch.histc(ints, bins=nbins, min=0, max=99)
            limits = torch.tensor(range(nbins))
            sum_sq = ints.dot(ints).item()
            w.add_histogram_raw('int histogram raw',
                                min=ints.min().item(),
                                max=ints.max().item(),
                                # This histogram has 100 samples, not the 50 used
                                # above; derive the count from the data itself.
                                num=ints.numel(),
                                sum=ints.sum().item(),
                                sum_squares=sum_sq,
                                bucket_limits=limits.tolist(),
                                bucket_counts=counts.tolist())
172

173
class TestTensorBoardUtils(BaseTestCase):
    """Tests for the image/video layout helpers used by the summary writers."""

    def test_to_HWC(self):
        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'chw')
        self.assertEqual(converted.shape, (32, 32, 3))
        test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'nchw')
        self.assertEqual(converted.shape, (64, 256, 3))
        test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8)
        converted = convert_to_HWC(test_image, 'hw')
        self.assertEqual(converted.shape, (32, 32, 3))

    def test_convert_to_HWC_dtype_remains_same(self):
        # test to ensure convert_to_HWC restores the dtype of input np array and
        # thus the scale_factor calculated for the image is 1
        test_image = torch.tensor([[[[1, 2, 3], [4, 5, 6]]]], dtype=torch.uint8)
        tensor = make_np(test_image)
        tensor = convert_to_HWC(tensor, 'NCHW')
        scale_factor = summary._calc_scale_factor(tensor)
        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')

    def _assert_frame_sums_match(self, V_input, V_after):
        """Check that every time step has the same total before and after _prepare_video.

        In ``V_input`` time is axis 1 (hence the swap below); in ``V_after``
        time is axis 0.
        """
        frames_in = np.swapaxes(V_input, 0, 1)
        for f in range(V_input.shape[1]):
            # ravel() replaces np.reshape(..., newshape=-1); the `newshape`
            # keyword is deprecated as of NumPy 2.1.
            x = frames_in[f].ravel()
            y = V_after[f].ravel()
            np.testing.assert_array_almost_equal(np.sum(x), np.sum(y))

    def test_prepare_video(self):
        # At each timeframe, the sum over all other
        # dimensions of the video should be the same.
        shapes = [
            (16, 30, 3, 28, 28),
            (36, 30, 3, 28, 28),
            (19, 29, 3, 23, 19),
            (3, 3, 3, 3, 3)
        ]
        for s in shapes:
            V_input = np.random.random(s)
            V_after = _prepare_video(np.copy(V_input))
            self._assert_frame_sums_match(V_input, V_after)

    def test_numpy_vid_uint8(self):
        V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8)
        # Multiply back by 255 so the totals are comparable with the uint8 input.
        V_after = _prepare_video(np.copy(V_input)) * 255
        self._assert_frame_sums_match(V_input, V_after)
223

224
freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]
225

226
true_positive_counts = [75, 64, 21, 5, 0]
227
false_positive_counts = [150, 105, 18, 0, 0]
228
true_negative_counts = [0, 45, 132, 150, 150]
229
false_negative_counts = [0, 11, 54, 70, 75]
230
precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0]
231
recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0]
232

233
class TestTensorBoardWriter(BaseTestCase):
    """Smoke test that drives most SummaryWriter.add_* methods at a single step."""

    def test_writer(self):
        step = 0
        audio_sample_rate = 44100
        with self.createSummaryWriter() as writer:
            writer.add_hparams(
                {'lr': 0.1, 'bsize': 1},
                {'hparam/accuracy': 10, 'hparam/loss': 10},
            )

            # Scalars: system wall time, explicit wall time, and new-style protos.
            writer.add_scalar('data/scalar_systemtime', 0.1, step)
            writer.add_scalar('data/scalar_customtime', 0.2, step, walltime=step)
            writer.add_scalar('data/new_style', 0.2, step, new_style=True)
            grouped_scalars = {
                "xsinx": step * np.sin(step),
                "xcosx": step * np.cos(step),
                "arctanx": np.arctan(step),
            }
            writer.add_scalars('data/scalar_group', grouped_scalars, step)

            # Images.
            image_batch = np.zeros((32, 3, 64, 64))  # output from network
            writer.add_images('Image', image_batch, step)  # Tensor
            blank_image = np.zeros((3, 64, 64))
            boxes = np.array([[10, 10, 40, 40], [40, 40, 60, 60]])
            writer.add_image_with_boxes('imagebox', blank_image, boxes, step)

            # Two seconds of silence, then a random float32 video.
            silence = np.zeros(audio_sample_rate * 2)
            writer.add_audio('myAudio', silence, step)
            clip = np.random.rand(16, 48, 1, 28, 28).astype(np.float32)
            writer.add_video('myVideo', clip, step)

            # Text, including a markdown table.
            writer.add_text('Text', 'text logged at step:' + str(step), step)
            writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', step)

            hist_values = np.random.rand(100, 100)
            writer.add_histogram('hist', hist_values, step)

            # PR curves; labels are drawn before predictions to keep RNG order.
            labels = np.random.randint(2, size=100)
            predictions = np.random.rand(100)
            writer.add_pr_curve('xoxo', labels, predictions, step)  # needs tensorboard 0.4RC or later
            writer.add_pr_curve_raw('prcurve with raw data',
                                    true_positive_counts,
                                    false_positive_counts,
                                    true_negative_counts,
                                    false_negative_counts,
                                    precision,
                                    recall,
                                    step)

            # A single tetrahedron-like mesh: 4 vertices, 4 coloured faces.
            vertices = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
            colors = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
            faces = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
            writer.add_mesh('my_mesh', vertices=vertices, colors=colors, faces=faces)
277

278
class TestTensorBoardSummaryWriter(BaseTestCase):
    """Lifecycle tests for SummaryWriter (context manager, close, path types)."""

    def test_summary_writer_ctx(self):
        # after using a SummaryWriter as a ctx it should be closed
        with self.createSummaryWriter() as writer:
            writer.add_scalar('test', 1)
        self.assertIs(writer.file_writer, None)

    def test_summary_writer_close(self):
        # Opening and closing a SummaryWriter must not leak file handles, which
        # would eventually surface as
        # OSError: [Errno 24] Too many open files
        try:
            writer = self.createSummaryWriter()
            writer.close()
        except OSError as e:
            # Report the actual error instead of an opaque boolean failure.
            self.fail(f"Opening and closing a SummaryWriter raised {e!r}")

    def test_pathlib(self):
        with tempfile.TemporaryDirectory(prefix="test_tensorboard_pathlib") as d:
            p = Path(d)
            with SummaryWriter(p) as writer:
                writer.add_scalar('test', 1)
302

303
class TestTensorBoardEmbedding(BaseTestCase):
    """Smoke tests for SummaryWriter.add_embedding with metadata and label images."""

    def _add_embeddings(self, all_images):
        """Log a 3-point embedding twice with ``all_images`` as label images.

        The first call uses plain float labels as metadata; the second uses
        (label, dataset) rows together with a metadata header.
        """
        all_features = torch.tensor([[1., 2., 3.], [5., 4., 1.], [3., 7., 7.]])
        all_labels = torch.tensor([33., 44., 55.])
        # Use the writer as a context manager so it is closed before tearDown()
        # removes its log directory.
        with self.createSummaryWriter() as w:
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            global_step=2)

            # One dataset tag per embedded point.
            dataset_label = ['test'] * 2 + ['train']
            all_labels = list(zip(all_labels, dataset_label))
            w.add_embedding(all_features,
                            metadata=all_labels,
                            label_img=all_images,
                            metadata_header=['digit', 'dataset'],
                            global_step=2)

    def test_embedding(self):
        self._add_embeddings(torch.zeros(3, 3, 5, 5))

    def test_embedding_64(self):
        self._add_embeddings(torch.zeros((3, 3, 5, 5), dtype=torch.float64))
342

343
class TestTensorBoardSummary(BaseTestCase):
    """Tests for the proto-building functions in torch.utils.tensorboard.summary."""

    def test_uint8_image(self):
        """
        Tests that uint8 image (pixel values in [0, 255]) is not changed
        """
        test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8)
        scale_factor = summary._calc_scale_factor(test_image)
        self.assertEqual(scale_factor, 1, msg='Values are already in [0, 255], scale factor should be 1')

    def test_float32_image(self):
        """
        Tests that float32 image (pixel values in [0, 1]) are scaled correctly
        to [0, 255]
        """
        test_image = np.random.rand(3, 32, 32).astype(np.float32)
        scale_factor = summary._calc_scale_factor(test_image)
        self.assertEqual(scale_factor, 255, msg='Values are in [0, 1], scale factor should be 255')

    def test_list_input(self):
        with self.assertRaises(Exception):
            summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow')

    def test_empty_input(self):
        with self.assertRaises(Exception):
            summary.histogram('dummy', np.ndarray(0), 'tensorflow')

    def test_image_with_boxes(self):
        proto = summary.image_boxes('dummy',
                                    tensor_N(shape=(3, 32, 32)),
                                    np.array([[10, 10, 40, 40]]))
        self.assertTrue(compare_image_proto(proto, self))

    def test_image_with_one_channel(self):
        proto = summary.image('dummy', tensor_N(shape=(1, 8, 8)), dataformats='CHW')
        self.assertTrue(compare_image_proto(proto, self))

    def test_image_with_one_channel_batched(self):
        proto = summary.image('dummy', tensor_N(shape=(2, 1, 8, 8)), dataformats='NCHW')
        self.assertTrue(compare_image_proto(proto, self))

    def test_image_with_3_channel_batched(self):
        proto = summary.image('dummy', tensor_N(shape=(2, 3, 8, 8)), dataformats='NCHW')
        self.assertTrue(compare_image_proto(proto, self))

    def test_image_without_channel(self):
        proto = summary.image('dummy', tensor_N(shape=(8, 8)), dataformats='HW')
        self.assertTrue(compare_image_proto(proto, self))

    def test_video(self):
        try:
            import moviepy  # noqa: F401
        except ImportError:
            # Report a skip rather than silently passing without testing anything.
            self.skipTest("moviepy is not installed")
        self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self))
        summary.video('dummy', np.random.rand(16, 48, 1, 28, 28))
        summary.video('dummy', np.random.rand(20, 7, 1, 8, 8))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_audio(self):
        self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_text(self):
        self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_auto(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_fd(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_histogram_doane(self):
        self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self))

    def test_custom_scalars(self):
        layout = {
            'Taiwan': {
                'twse': ['Multiline', ['twse/0050', 'twse/2330']]
            },
            'USA': {
                'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]
            }
        }
        summary.custom_scalars(layout)  # only smoke test. Because protobuf in python2/3 serialize dictionary differently.

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_mesh(self):
        v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float)
        c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int)
        f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int)
        mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None)
        self.assertTrue(compare_proto(mesh, self))

    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_scalar_new_style(self):
        scalar = summary.scalar('test_scalar', 1.0, new_style=True)
        self.assertTrue(compare_proto(scalar, self))
        # New-style scalars reject multi-element tensors.
        with self.assertRaises(AssertionError):
            summary.scalar('test_scalar2', torch.Tensor([1, 2, 3]), new_style=True)
459

460

461
def remove_whitespace(string):
    """Return ``string`` with every space, tab and newline character deleted.

    Used by compare_proto to compare proto text without caring about layout.
    Other whitespace (e.g. carriage returns) is left in place.
    """
    deletions = str.maketrans('', '', ' \t\n')
    return string.translate(deletions)
463

464
def get_expected_file(function_ptr):
    """Return the path of the ``.expect`` file for the running test.

    Args:
        function_ptr: the test case instance. Its class's module must be in
            ``sys.modules``, and its ``id()`` must end in the test method name.

    Returns:
        ``<dir of the test module's .py>/expect/TestTensorBoard.<method>.expect``
    """
    module_id = function_ptr.__class__.__module__
    test_file = sys.modules[module_id].__file__
    # Look for the .py file (since __file__ could be pyc). splitext strips only
    # the final extension, so dots in directory names are left intact.
    test_file = os.path.splitext(test_file)[0] + '.py'

    # Use realpath to follow symlinks appropriately.
    test_dir = os.path.dirname(os.path.realpath(test_file))
    function_name = function_ptr.id().split('.')[-1]
    return os.path.join(test_dir,
                        "expect",
                        'TestTensorBoard.' + function_name + ".expect")
476

477
def read_expected_content(function_ptr):
    """Return the text of the ``.expect`` file associated with ``function_ptr``.

    Asserts that the file exists first, using its path as the assertion message.
    """
    expect_path = get_expected_file(function_ptr)
    assert os.path.exists(expect_path), expect_path
    return Path(expect_path).read_text()
482

483
def compare_image_proto(actual_proto, function_ptr):
    """Check an image Summary proto against the test's expect file.

    In expecttest accept mode, the expect file is overwritten with the text
    form of ``actual_proto`` and True is returned. Otherwise the first summary
    value of each proto is compared by tag, image height, width and colorspace,
    and by the images decoded from their encoded strings.
    """
    if not expecttest.ACCEPT:
        expected_text = read_expected_content(function_ptr)
        expected_proto = Summary()
        text_format.Parse(expected_text, expected_proto)

        actual = actual_proto.value[0]
        expected = expected_proto.value[0]
        # Decode both encoded images so PIL compares the images themselves
        # rather than the raw encoded strings.
        actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string))
        expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string))

        actual_meta = (actual.tag, actual.image.height,
                       actual.image.width, actual.image.colorspace)
        expected_meta = (expected.tag, expected.image.height,
                         expected.image.width, expected.image.colorspace)
        return actual_meta == expected_meta and actual_img == expected_img

    expected_file = get_expected_file(function_ptr)
    with open(expected_file, 'w') as out:
        out.write(text_format.MessageToString(actual_proto))
    return True
504

505
def compare_proto(str_to_compare, function_ptr):
    """Compare ``str(str_to_compare)`` with the expect file, ignoring spaces/tabs/newlines.

    In expecttest accept mode the expect file is rewritten instead and the
    comparison reports success.
    """
    if expecttest.ACCEPT:
        write_proto(str_to_compare, function_ptr)
        return True
    expected_text = remove_whitespace(read_expected_content(function_ptr))
    actual_text = remove_whitespace(str(str_to_compare))
    return actual_text == expected_text
512

513
def write_proto(str_to_compare, function_ptr):
    """Overwrite the test's expect file with ``str(str_to_compare)``."""
    target_path = get_expected_file(function_ptr)
    with open(target_path, 'w') as handle:
        handle.write(str(str_to_compare))
517

518
class TestTensorBoardPytorchGraph(BaseTestCase):
    """Tests for exporting PyTorch modules as TensorBoard graphs."""

    def _assert_graph_matches_expected(self, actual_proto):
        """Compare ``actual_proto`` with the GraphDef stored in this test's expect file.

        Checks the node count and, node by node, the name, op, inputs, device
        and the set of attribute keys. Attribute values are not compared.
        """
        expected_proto = GraphDef()
        text_format.Parse(read_expected_content(self), expected_proto)

        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        # Lengths were asserted equal above, so zip() visits every node.
        for expected_node, actual_node in zip(expected_proto.node, actual_proto.node):
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(
                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))

    def test_pytorch_graph(self):
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        with self.createSummaryWriter() as w:
            w.add_graph(myLinear(), dummy_input)

        actual_proto, _ = graph(myLinear(), dummy_input)
        self._assert_graph_matches_expected(actual_proto)

    def test_nested_nn_squential(self):
        dummy_input = torch.randn(2, 3)

        class InnerNNSquential(torch.nn.Module):
            def __init__(self, dim1, dim2):
                super().__init__()
                self.inner_nn_squential = torch.nn.Sequential(
                    torch.nn.Linear(dim1, dim2),
                    torch.nn.Linear(dim2, dim1),
                )

            def forward(self, x):
                x = self.inner_nn_squential(x)
                return x

        class OuterNNSquential(torch.nn.Module):
            def __init__(self, dim1=3, dim2=4, depth=2):
                super().__init__()
                layers = [InnerNNSquential(dim1, dim2) for _ in range(depth)]
                self.outer_nn_squential = torch.nn.Sequential(*layers)

            def forward(self, x):
                x = self.outer_nn_squential(x)
                return x

        with self.createSummaryWriter() as w:
            w.add_graph(OuterNNSquential(), dummy_input)

        actual_proto, _ = graph(OuterNNSquential(), dummy_input)
        self._assert_graph_matches_expected(actual_proto)

    def test_pytorch_graph_dict_input(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        class ModelDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return {"out": self.l(x)}

        dummy_input = torch.zeros(1, 3)

        with self.createSummaryWriter() as w:
            w.add_graph(Model(), dummy_input)

        with self.createSummaryWriter() as w:
            w.add_graph(Model(), dummy_input, use_strict_trace=True)

        # expect error: Encountering a dict at the output of the tracer...
        with self.assertRaises(RuntimeError):
            with self.createSummaryWriter() as w:
                w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)

        with self.createSummaryWriter() as w:
            w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)

    def test_mlp_graph(self):
        dummy_input = (torch.zeros(2, 1, 28, 28),)

        # This MLP class with the above input is expected
        # to fail JIT optimizations as seen at
        # https://github.com/pytorch/pytorch/issues/18903
        #
        # However, it should not raise an error during
        # the add_graph call and still continue.
        class myMLP(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.input_len = 1 * 28 * 28
                self.fc1 = torch.nn.Linear(self.input_len, 1200)
                self.fc2 = torch.nn.Linear(1200, 1200)
                self.fc3 = torch.nn.Linear(1200, 10)

            def forward(self, x, update_batch_stats=True):
                h = torch.nn.functional.relu(
                    self.fc1(x.view(-1, self.input_len)))
                h = self.fc2(h)
                h = torch.nn.functional.relu(h)
                h = self.fc3(h)
                return h

        with self.createSummaryWriter() as w:
            w.add_graph(myMLP(), dummy_input)

    def test_wrong_input_size(self):
        # A (1, 9) input cannot feed Linear(3, 5), so tracing must raise.
        with self.assertRaises(RuntimeError):
            dummy_input = torch.rand(1, 9)
            model = torch.nn.Linear(3, 5)
            with self.createSummaryWriter() as w:
                w.add_graph(model, dummy_input)  # error

    @skipIfNoTorchVision
    def test_torchvision_smoke(self):
        model_input_shapes = {
            'alexnet': (2, 3, 224, 224),
            'resnet34': (2, 3, 224, 224),
            'resnet152': (2, 3, 224, 224),
            'densenet121': (2, 3, 224, 224),
            'vgg16': (2, 3, 224, 224),
            'vgg19': (2, 3, 224, 224),
            'vgg16_bn': (2, 3, 224, 224),
            'vgg19_bn': (2, 3, 224, 224),
            'mobilenet_v2': (2, 3, 224, 224),
        }
        for model_name, input_shape in model_input_shapes.items():
            with self.createSummaryWriter() as w:
                model = getattr(torchvision.models, model_name)()
                w.add_graph(model, torch.zeros(input_shape))
685

686
class TestTensorBoardFigure(BaseTestCase):
    """Tests for SummaryWriter.add_figure and its ``close`` behaviour."""

    @skipIfNoMatplotlib
    def test_figure(self):
        # The context manager closes the writer even if an assertion fails.
        with self.createSummaryWriter() as writer:
            figure, axes = plt.figure(), plt.gca()
            circle1 = plt.Circle((0.2, 0.5), 0.2, color='r')
            circle2 = plt.Circle((0.8, 0.5), 0.2, color='g')
            axes.add_patch(circle1)
            axes.add_patch(circle2)
            plt.axis('scaled')
            plt.tight_layout()

            # close=False must leave the figure open.
            writer.add_figure("add_figure/figure", figure, 0, close=False)
            self.assertTrue(plt.fignum_exists(figure.number))

            # Default close=True must close it.
            writer.add_figure("add_figure/figure", figure, 1)
            if matplotlib.__version__ != '3.3.0':
                self.assertFalse(plt.fignum_exists(figure.number))
            else:
                print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")

    @skipIfNoMatplotlib
    def test_figure_list(self):
        with self.createSummaryWriter() as writer:
            figures = []
            for i in range(5):
                figure = plt.figure()
                plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i))
                plt.xlabel("X")
                # Label the y axis (previously xlabel was set twice).
                plt.ylabel("Y")
                plt.legend()
                plt.tight_layout()
                figures.append(figure)

            writer.add_figure("add_figure/figure_list", figures, 0, close=False)
            self.assertTrue(all(plt.fignum_exists(fig.number) for fig in figures))

            writer.add_figure("add_figure/figure_list", figures, 1)
            if matplotlib.__version__ != '3.3.0':
                self.assertFalse(any(plt.fignum_exists(fig.number) for fig in figures))
            else:
                print("Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163")
734

735
class TestTensorBoardNumpy(BaseTestCase):
    """Tests for ``make_np`` conversion of scalars and unsupported inputs."""

    @unittest.skipIf(IS_WINDOWS, "Skipping on windows, see https://github.com/pytorch/pytorch/pull/109349 ")
    @unittest.skipIf(IS_MACOS, "Skipping on mac, see https://github.com/pytorch/pytorch/pull/109349 ")
    def test_scalar(self):
        """Every scalar input must come back as an ndarray of shape (1,)."""
        scalars = [
            1.1,
            # uint64_max. Parenthesized because ``1 << 64 - 1`` parses as
            # ``1 << 63``.
            (1 << 64) - 1,
            np.float16(1.00000087),
            np.float128(1.00008 + 9),
            np.int64(100000000000),
        ]
        for value in scalars:
            with self.subTest(value=value):
                res = make_np(value)
                # Two separate assertions: joining them with ``and`` skipped
                # the shape check, because assertIsInstance returns None.
                self.assertIsInstance(res, np.ndarray)
                self.assertEqual(res.shape, (1,))

    def test_pytorch_np_expect_fail(self):
        """Unsupported input types (here a dict) must raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            make_np({'pytorch': 1.0})


class TestTensorProtoSummary(BaseTestCase):
    """Tests for ``tensor_proto``: value serialization and dtype mapping."""

    @parametrize(
        "tensor_type,proto_type",
        [
            (torch.float16, DataType.DT_HALF),
            (torch.bfloat16, DataType.DT_BFLOAT16),
        ],
    )
    @skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version")
    def test_half_tensor_proto(self, tensor_type, proto_type):
        """16-bit float tensors round-trip through ``half_val``."""
        float_values = [1.0, 2.0, 3.0]
        actual_proto = tensor_proto(
            "dummy",
            torch.tensor(float_values, dtype=tensor_type),
        ).value[0].tensor
        # half_val entries are decoded with int_to_half before comparison.
        self.assertSequenceEqual(
            [int_to_half(x) for x in actual_proto.half_val],
            float_values,
        )
        self.assertEqual(actual_proto.dtype, proto_type)

    def test_float_tensor_proto(self):
        """float32 tensors are stored in ``float_val`` with DT_FLOAT."""
        float_values = [1.0, 2.0, 3.0]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(float_values)).value[0].tensor
        )
        self.assertEqual(actual_proto.float_val, float_values)
        self.assertEqual(actual_proto.dtype, DataType.DT_FLOAT)

    def test_int_tensor_proto(self):
        """int32 tensors are stored in ``int_val`` with DT_INT32."""
        int_values = [1, 2, 3]
        actual_proto = (
            tensor_proto("dummy", torch.tensor(int_values, dtype=torch.int32))
            .value[0]
            .tensor
        )
        self.assertEqual(actual_proto.int_val, int_values)
        self.assertEqual(actual_proto.dtype, DataType.DT_INT32)

    def test_scalar_tensor_proto(self):
        """A 0-d tensor serializes its single value into ``float_val``."""
        scalar_value = 0.1
        actual_proto = (
            tensor_proto("dummy", torch.tensor(scalar_value)).value[0].tensor
        )
        # Approximate comparison: 0.1 is not exactly representable in float32.
        self.assertAlmostEqual(actual_proto.float_val[0], scalar_value)

    def test_complex_tensor_proto(self):
        """Complex tensors are flattened into interleaved (real, imag) pairs."""
        real = torch.tensor([1.0, 2.0])
        imag = torch.tensor([3.0, 4.0])
        actual_proto = (
            tensor_proto("dummy", torch.complex(real, imag)).value[0].tensor
        )
        self.assertEqual(actual_proto.scomplex_val, [1.0, 3.0, 2.0, 4.0])

    def test_empty_tensor_proto(self):
        """An empty tensor yields an empty ``float_val``."""
        actual_proto = tensor_proto("dummy", torch.empty(0)).value[0].tensor
        self.assertEqual(actual_proto.float_val, [])

# Expand the @parametrize-decorated methods of TestTensorProtoSummary into
# concrete test methods.
instantiate_parametrized_tests(TestTensorProtoSummary)

if __name__ == '__main__':
    # Executed as a script: hand control to PyTorch's shared test runner.
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.