pytorch

Форк
0
/
test_profiler.py 
2681 строка · 99.0 Кб
1
# Owner(s): ["oncall: profiler"]
2

3
import collections
4
import gc
5
import json
6
import mmap
7
import os
8
import pickle
9
import random
10
import re
11
import struct
12
import subprocess
13
import sys
14
import tempfile
15
import threading
16
import time
17
import unittest
18
from dataclasses import dataclass, field
19
from typing import List, Optional
20
from unittest.mock import patch
21

22
import expecttest
23

24
import torch
25
import torch.nn as nn
26
import torch.optim
27
import torch.utils.data
28
from torch._C._profiler import _ExperimentalConfig, _ExtraFields_PyCall
29
from torch.autograd.profiler import KinetoStepTracker, profile as _profile
30
from torch.autograd.profiler_legacy import profile as _profile_legacy
31
from torch.profiler import (
32
    _utils,
33
    DeviceType,
34
    kineto_available,
35
    profile,
36
    ProfilerAction,
37
    ProfilerActivity,
38
    record_function,
39
    supported_activities,
40
)
41
from torch.profiler._pattern_matcher import (
42
    Conv2dBiasFollowedByBatchNorm2dPattern,
43
    ExtraCUDACopyPattern,
44
    ForLoopIndexingPattern,
45
    FP32MatMulPattern,
46
    GradNotSetToNonePattern,
47
    MatMulDimInFP16Pattern,
48
    NamePattern,
49
    OptimizerSingleTensorPattern,
50
    Pattern,
51
    report_all_anti_patterns,
52
    SynchronizedDataLoaderPattern,
53
)
54
from torch.testing._internal.common_cuda import TEST_MULTIGPU
55
from torch.testing._internal.common_device_type import skipCUDAVersionIn
56
from torch.testing._internal.common_utils import (
57
    instantiate_parametrized_tests,
58
    IS_ARM64,
59
    IS_JETSON,
60
    IS_LINUX,
61
    IS_WINDOWS,
62
    parametrize,
63
    run_tests,
64
    serialTest,
65
    skipIfTorchDynamo,
66
    TemporaryDirectoryName,
67
    TemporaryFileName,
68
    TEST_WITH_ASAN,
69
    TEST_WITH_CROSSREF,
70
    TEST_WITH_ROCM,
71
    TestCase,
72
)
73

74

75
# When tqdm is imported and not shut down cleanly, its monitor thread stays
# alive. The multithreading test checks the TID of every recorded event, and
# events coming from such lingering threads carry the invalid TID
# (uint64_t)(-1). To avoid that, the monitor thread is disabled whenever tqdm
# is importable; this is safe because these are only unit tests.
try:
    import tqdm
except ImportError:
    pass
else:
    tqdm.tqdm.monitor_interval = 0

# psutil is optional; HAS_PSUTIL records whether it could be imported and
# psutil is bound to None when it is missing.
try:
    import psutil
except ModuleNotFoundError:
    HAS_PSUTIL = False
    psutil = None
else:
    HAS_PSUTIL = True
95

96

97
@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestProfilerCUDA(TestCase):
    """Profiler tests that need CUDA (and psutil for RSS sampling).

    Skipped without psutil, under ASAN, on Windows, and when CUDA is missing.
    """

    @skipCUDAVersionIn([(11, 5)])  # https://github.com/pytorch/pytorch/issues/69023
    def test_mem_leak(self):
        """Checks that there's no memory leak when using profiler with CUDA"""
        t = torch.rand(1, 1).cuda()
        p = psutil.Process()
        # Only the RSS samples taken after the last five sessions are kept.
        last_rss = collections.deque(maxlen=5)
        for _ in range(10):
            with _profile(use_cuda=True):
                for _ in range(1024):
                    t = torch.mm(t, t)

            gc.collect()
            torch.cuda.empty_cache()
            last_rss.append(p.memory_info().rss)

        # with CUDA events leaking the increase in memory was ~7 MB between
        # profiler invocations above
        samples = list(last_rss)
        deltas = [later - earlier for earlier, later in zip(samples, samples[1:])]
        is_increasing = all(delta > 0 for delta in deltas)
        max_diff = max(deltas, default=-1)
        # Fail only when RSS grew after every session AND some step exceeded
        # 100 KiB; small monotonic drift alone is tolerated.
        self.assertTrue(
            not (is_increasing and max_diff > 100 * 1024),
            msg=f"memory usage is increasing, {str(last_rss)}",
        )

    def test_custom_module_input_op_ids(self):
        """Smoke test: emit_nvtx must run with record_shapes enabled."""

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        def custom_layer(input_ten):
            return MyFunc.apply(input_ten)

        # Only testing that emit_nvtx runs when
        # record_shapes option is enabled.
        with torch.autograd.profiler.emit_nvtx(record_shapes=True):
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = custom_layer(z)
            q = s.sum()
            q.backward()

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_cudagraph_profiling_workaround(self):
        """Profiling graphed callables after _init_for_cuda_graphs must not fail.

        The repro runs in a child interpreter; `subprocess.check_call` raises
        if it exits non-zero or exceeds the 60 s timeout.
        """
        # repro taken from #75504
        # Launch in a separate process to catch hanging/illegal memory errors
        # and to make sure CUPTI isn't already initialized.
        subprocess.check_call(
            [
                sys.executable,
                "-c",
                """
import os
import torch
from torch.profiler import ProfilerActivity, profile

def add_one(in_: torch.Tensor):
    return in_ + 1

sample_arg = torch.zeros(10, device="cuda").requires_grad_(True)

# add this before cuda graphs are created
torch.profiler._utils._init_for_cuda_graphs()

add_one_graphed = torch.cuda.graphs.make_graphed_callables(add_one, sample_args=(sample_arg,))
zeros = torch.zeros(10, device="cuda")
out = add_one_graphed(zeros)
assert out[0] == 1

with profile(activities=[ProfilerActivity.CPU]):
    add_one_graphed(zeros)

with profile(activities=[ProfilerActivity.CUDA]):
    add_one_graphed(zeros)
""",
            ],
            universal_newlines=True,
            timeout=60,
        )
        # ^ this will throw an exception if the script fails.
196

197

198
@unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
class TestProfilerITT(TestCase):
    """Smoke tests for torch.autograd.profiler.emit_itt."""

    def test_custom_module_input_op_ids(self):
        # Identity autograd function whose backward returns the saved input.
        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return x

        # Only checks that emit_itt runs while record_shapes is enabled;
        # nothing about the emitted ranges is asserted.
        with torch.autograd.profiler.emit_itt(record_shapes=True):
            lhs = torch.randn(10, 10, requires_grad=True)
            rhs = torch.randn(10, 10, requires_grad=True)
            passthrough = MyFunc.apply(lhs + rhs)
            passthrough.sum().backward()
224

225

226
@instantiate_parametrized_tests
227
class TestProfiler(TestCase):
228
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_source(self):
        """Checks that source code attribution works for eager, TS and autograd mode"""
        # Turn off graph-executor optimization to avoid automatic inlining of
        # the scripted functions. The setting is not scoped to this test, so
        # it is restored in ``finally`` even when an assertion fails.
        prev_opt = torch._C._get_graph_executor_optimize()
        torch._C._set_graph_executor_optimize(False)
        try:

            @torch.jit.script
            def ts_method_2(x, y):
                return torch.matmul(x, y)

            @torch.jit.script
            def ts_method_1(x, y, z):
                a = x + z
                w = ts_method_2(x, y) + a
                return w.sum()

            class DummyModule(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv = torch.nn.Conv2d(
                        3, 2, kernel_size=1, stride=2, padding=3, bias=False
                    )

                def forward(self, x):
                    return self.conv(x)

            mod = DummyModule()

            def call_module(x):
                return mod(x)

            with _profile(
                with_stack=True,
                use_kineto=kineto_available(),
                experimental_config=_ExperimentalConfig(verbose=True),
            ) as p:
                x = torch.randn(10, 10, requires_grad=True)
                y = torch.randn(10, 10, requires_grad=True)
                z = x + y
                w = ts_method_1(x, y, z)
                v = 2 * w
                v.backward()
                a = torch.randn(2, 3, 2, 2, requires_grad=True)
                b = call_module(a)
                c = b.sum()
                c.backward()

            # Every add (forward or backward) must carry a stack pointing back
            # into this file and into either this test or a scripted function.
            for e in p.function_events:
                if "aten::add" in e.name or "AddBackward" in e.name:
                    self.assertTrue(any("test_profiler" in entry for entry in e.stack))
                    self.assertTrue(
                        any(
                            (
                                "test_source" in entry
                                or "ts_method_1" in entry
                                or "ts_method_2" in entry
                            )
                            for entry in e.stack
                        )
                    )

            # TODO: https://github.com/pytorch/kineto/issues/617
            if kineto_available() and not IS_WINDOWS:
                with TemporaryFileName(mode="w+") as fname:
                    p.export_chrome_trace(fname)
                    with open(fname) as f:
                        events = json.load(f)["traceEvents"]

                    def extract(pattern: str):
                        """Return the single trace event whose name matches."""
                        matches = [e for e in events if re.search(pattern, e["name"])]
                        self.assertEqual(
                            len(matches), 1, repr([e["name"] for e in matches])
                        )
                        return matches[0]

                    # The module call event must be parented to the Python
                    # wrapper function that invoked it.
                    module_event = extract(r"DummyModule_0")
                    wrapper_event = extract(r"call_module")
                    self.assertEqual(
                        module_event["args"]["Python parent id"],
                        wrapper_event["args"]["Python id"],
                    )
        finally:
            torch._C._set_graph_executor_optimize(prev_opt)
314

315
    @parametrize(
        "name,thread_spec",
        {
            "basic": ((False, False),),
            "multiple_preexisting": ((False, False),) * 2,
            "open_in_scope": ((True, False),),
            "close_in_scope": ((False, True),),
            "complex": (
                # Large number of background threads
                (False, False),
                (False, False),
                (False, False),
                (False, False),
                # some of which finish during profiling
                (False, True),
                (False, True),
                # And the profiled section is also multithreaded
                (True, False),
                (True, True),
            ),
        }.items(),
        name_fn=lambda name, thread_spec: name,
    )
    @serialTest()
    @parametrize("work_in_main_thread", [True, False])
    def test_source_multithreaded(self, name, thread_spec, work_in_main_thread):
        """Test various threading configurations.

        `thread_spec` is a Tuple[Tuple[bool, bool], ...] where each pair is a
        thread. The first bool indicates if the thread should be started under
        the profiler context and the second is if it should be joined under the
        profiler context.

        `work_in_main_thread` selects whether the main thread also runs the
        workload (`Task._run`) under the profiler, or only waits on the two
        barriers alongside the workers.
        """

        timeout = 15
        num_threads = len(thread_spec) + 1  # Main thread
        # Every worker plus the main thread must reach each barrier.
        start_barrier = threading.Barrier(num_threads, timeout=timeout)
        end_barrier = threading.Barrier(num_threads, timeout=timeout)

        class Task(threading.Thread):
            """Daemon worker that starts itself on construction and runs `_run`."""

            def __init__(self) -> None:
                self._end_gate = threading.Event()
                super().__init__(daemon=True)
                self.start()

            def run(self):
                self._run(self._end_gate)

            def release(self):
                # Allows the worker to return after it has passed end_barrier.
                self._end_gate.set()

            @staticmethod
            def _run(end_gate=None):
                def known_preexisting_function():
                    start_barrier.wait()

                # Fixed point that we can use to test capture of functions
                # which are already running when profiling is enabled.
                known_preexisting_function()

                model = torch.nn.Sequential(
                    torch.nn.Linear(10, 10),
                    torch.nn.ReLU(),
                )

                def invoked_during_run():
                    pass

                invoked_during_run()

                _ = model(torch.rand(4, 10))
                end_barrier.wait()

                # Workers block here until released; the main thread calls
                # _run() with end_gate=None and returns immediately.
                if end_gate is not None:
                    end_gate.wait(timeout=timeout)

        # Maps thread_spec index -> Task.
        threads = {}

        def add_threads(context: bool):
            """Start the workers whose start-under-profiler flag == context."""
            for idx, (start_under_profiler, _) in enumerate(thread_spec):
                if start_under_profiler == context:
                    assert idx not in threads
                    threads[idx] = Task()

        def join_threads(context: bool):
            """Release, then join, the workers whose end flag == context."""
            for idx, (_, end_under_profiler) in enumerate(thread_spec):
                if end_under_profiler == context:
                    threads[idx].release()

            for idx, (_, end_under_profiler) in enumerate(thread_spec):
                t = threads[idx]
                if end_under_profiler == context:
                    t.join(timeout=timeout)

        try:
            add_threads(False)
            with torch.profiler.profile(with_stack=True) as prof:
                # Threads added while the profiler are running will not be observed
                # since there is no way to hook into Python's thread start call to
                # register the observer. These are here purely to verify safety.
                add_threads(True)

                if work_in_main_thread:
                    Task._run()
                else:
                    start_barrier.wait()
                    end_barrier.wait()

                join_threads(True)
            join_threads(False)

        finally:
            # It is very important that we clean up everything because the
            # Python tracer will detect ALL active threads. (Even orphans from
            # prior failed tests.) If we don't clean up properly we can
            # contaminate subsequent tests.
            start_barrier.abort()
            end_barrier.abort()
            for t in threads.values():
                t.release()

            for t in threads.values():
                t.join(timeout=timeout)

            for t in threads.values():
                self.assertFalse(t.is_alive())

        roots = prof.profiler.kineto_results.experimental_event_tree()
        nodes = [
            node
            for node in _utils.traverse_dfs(roots)
            if isinstance(node.extra_fields, _ExtraFields_PyCall)
        ]
        tid_counts = collections.Counter(node.start_tid for node in nodes)

        # Only threads that already existed when profiling began (plus the
        # main thread) are expected to produce Python call events.
        prior_threads = sum(
            not start_under_profiler for start_under_profiler, _ in thread_spec
        )
        expected_threads = prior_threads + 1
        self.assertEqual(
            len(tid_counts), expected_threads, f"{expected_threads}, {tid_counts}"
        )
        self.assertEqual(len(nodes), sum(tid_counts.values()))

        # Profiler uses uint64_t max as a placeholder until TID can be determined.
        no_tid = 2**64 - 1
        self.assertNotIn(no_tid, tid_counts)

        worker_threads = prior_threads + (1 if work_in_main_thread else 0)

        # Each workload-running thread must report each marker exactly once.
        observed_preexisting = [
            node.start_tid
            for node in nodes
            if "known_preexisting_function" in node.name
        ]
        self.assertEqual(len(observed_preexisting), worker_threads)
        self.assertEqual(len(observed_preexisting), len(set(observed_preexisting)))

        observed_during_run = [
            node.start_tid for node in nodes if "invoked_during_run" in node.name
        ]
        self.assertEqual(len(observed_during_run), worker_threads)
        self.assertEqual(len(observed_during_run), len(set(observed_during_run)))
479

480
    def payload(self, use_cuda=False):
        """Run a tiny ``mm`` followed by an ``add``; returns nothing.

        Both 10x10 operands are created on the CPU and, when ``use_cuda`` is
        set, moved with ``.cuda()``; the result is then copied back with
        ``.cpu()``. The explicit copies are part of the workload.
        """

        def make_operand():
            operand = torch.randn(10, 10)
            return operand.cuda() if use_cuda else operand

        lhs = make_operand()
        rhs = make_operand()
        result = torch.mm(lhs, rhs) + rhs
        if use_cuda:
            result.cpu()
491

492
    def _check_stats(self, profiler_stats):
        """Assert that every recorded profiler statistic is strictly positive.

        ``profiler_stats`` is the stats object of a finished profiling session
        (the autograd profiler's ``_stats`` attribute, or the result of
        ``torch.profiler.profile._stats()``).
        """
        for stat_name in (
            "profiling_window_duration_sec",
            "number_of_events",
            "profiler_prepare_call_duration_us",
            "profiler_enable_call_duration_us",
            "profiler_disable_call_duration_us",
            "parse_kineto_call_duration_us",
            "function_events_build_tree_call_duration_us",
        ):
            # Include the stat name so a failure identifies which one was <= 0.
            self.assertGreater(getattr(profiler_stats, stat_name), 0, msg=stat_name)
502

503
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_kineto(self):
        """Profile a small mm/add payload through Kineto and check its events.

        With CUDA, a GEMM kernel (name containing "gemm" or "Cijk") and a
        memcpy must be recorded; CPU-only, an ``aten::mm`` event must be.
        """
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with _profile(use_cuda=use_cuda, use_kineto=True):
            self.payload(use_cuda=use_cuda)

        # rerun to avoid initial start overhead
        with _profile(use_cuda=use_cuda, use_kineto=True) as p:
            self.payload(use_cuda=use_cuda)

        self.assertIn("aten::mm", str(p))

        # Exercises table rendering of the aggregated events; the rendered
        # text itself is not inspected.
        p.key_averages().table(
            sort_by="self_cuda_time_total" if use_cuda else "self_cpu_time_total",
            row_limit=-1,
        )

        event_names = [e.name for e in p.function_events]
        found_mm = any("aten::mm" in name for name in event_names)
        found_gemm = any(
            "gemm" in name.lower() or "Cijk" in name for name in event_names
        )
        found_memcpy = any("memcpy" in name.lower() for name in event_names)
        if use_cuda:
            self.assertTrue(found_gemm)
            self.assertTrue(found_memcpy)
        else:
            self.assertTrue(found_mm)
        self._check_stats(p._stats)
537

538
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(not TEST_MULTIGPU, "Multiple GPUs needed")
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_kineto_multigpu(self):
        """Matmuls on GPUs 0 and 1 must each show up as a CUDA GEMM event.

        At least one CPU-side event whose name mentions "cuda" is also
        required, and the session's profiler stats must all be positive.
        """
        activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
        with profile(activities=activities) as prof:
            for device_idx in [0, 1]:
                lhs = torch.randn(10, 10).cuda(device_idx)
                rhs = torch.randn(10, 10).cuda(device_idx)
                lhs.matmul(rhs)

        gemm_devices = set()
        found_cuda = False
        for event in prof.events():
            lowered = event.name.lower()
            if "gemm" in lowered and event.device_type == DeviceType.CUDA:
                gemm_devices.add(event.device_index)
            if "cuda" in lowered and event.device_type == DeviceType.CPU:
                found_cuda = True

        self.assertTrue(0 in gemm_devices)
        self.assertTrue(1 in gemm_devices)
        self.assertTrue(found_cuda)
        self._check_stats(prof._stats())
564

565
    def test_memory_profiler(self):
        """Check memory attribution of allocs/deallocs to ops and user scopes.

        Covers CPU tensors, the Chrome trace ``[memory]`` events (Kineto
        only), CUDA tensors and MKLDNN tensors when available, and top-level
        memory events recorded outside any user scope.
        """

        def run_profiler(tensor_creation_fn):
            """Profile one allocation and one deallocation in named scopes."""
            # collecting allocs / deallocs
            with _profile(
                profile_memory=True,
                record_shapes=True,
                use_kineto=kineto_available(),
            ) as prof:
                x = None
                with record_function("test_user_scope_alloc"):
                    x = tensor_creation_fn()
                with record_function("test_user_scope_dealloc"):
                    del x
            return prof.key_averages(group_by_input_shape=True)

        def check_metrics(stats, metric, allocs=None, deallocs=None):
            """Require ``metric`` > 0 for every alloc key and < 0 for dealloc keys."""
            # Event key -> value of ``metric``; a repeated key keeps the last
            # entry.
            stat_metrics = {stat.key: getattr(stat, metric) for stat in stats}
            for alloc_fn in allocs or ():
                self.assertIn(alloc_fn, stat_metrics)
                self.assertGreater(
                    stat_metrics[alloc_fn], 0, f"alloc_fn = {alloc_fn}"
                )
            for dealloc_fn in deallocs or ():
                self.assertIn(dealloc_fn, stat_metrics)
                self.assertLess(
                    stat_metrics[dealloc_fn], 0, f"dealloc_fn = {dealloc_fn}"
                )

        def create_cpu_tensor():
            return torch.rand(10, 10)

        def create_cuda_tensor():
            return torch.rand(10, 10).cuda()

        def create_mkldnn_tensor():
            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

        stats = run_profiler(create_cpu_tensor)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=[
                "aten::empty",
                "aten::rand",
                "test_user_scope_alloc",
            ],
            deallocs=[
                "test_user_scope_dealloc",
            ],
        )

        if kineto_available():
            with TemporaryFileName(mode="w+") as fname:
                with profile(profile_memory=True) as prof:
                    x = None
                    with record_function("test_user_scope_alloc"):
                        x = create_cpu_tensor()
                    with record_function("test_user_scope_dealloc"):
                        del x
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    trace = json.load(f)
                self.assertIn("traceEvents", trace)
                found_memory_events = False
                for evt in trace["traceEvents"]:
                    self.assertIn("name", evt)
                    if evt["name"] != "[memory]":
                        continue
                    found_memory_events = True
                    self.assertIn("args", evt)
                    for key in ("Addr", "Device Type", "Device Id", "Bytes"):
                        self.assertIn(key, evt["args"])

                    # Memory should be an instantaneous event.
                    # NOTE(review): these keys are looked up in evt["args"],
                    # not on the event itself -- confirm that is intended.
                    self.assertNotIn("dur", evt["args"])
                    self.assertNotIn("cat", evt["args"])
                self.assertTrue(found_memory_events)

        if torch.cuda.is_available():
            # Warm-up allocation outside the profiler; presumably so one-time
            # CUDA initialization is not captured below -- confirm.
            create_cuda_tensor()
            stats = run_profiler(create_cuda_tensor)
            check_metrics(
                stats,
                "device_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::to",
                    "aten::empty_strided",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "aten::rand",
                    "aten::empty",
                ],
            )

        if torch.backends.mkldnn.is_available():
            create_mkldnn_tensor()
            stats = run_profiler(create_mkldnn_tensor)
            check_metrics(
                stats,
                "cpu_memory_usage",
                allocs=[
                    "test_user_scope_alloc",
                    "aten::rand",
                    "aten::empty",
                    "aten::to_mkldnn",
                ],
                deallocs=[
                    "test_user_scope_dealloc",
                ],
            )

        # check top-level memory events
        with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
            x = torch.rand(10, 10)
            del x
            if torch.cuda.is_available():
                y = torch.rand(10, 10).cuda()
                del y
            gc.collect()
        stats = prof.key_averages(group_by_input_shape=True)
        check_metrics(
            stats,
            "cpu_memory_usage",
            allocs=["aten::rand", "aten::empty"],
            deallocs=["[memory]"],
        )
        if torch.cuda.is_available():
            check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
709

710
    @unittest.skipIf(
        IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
    )
    def test_oom_tracing(self):
        """A failed CUDA allocation must appear as an [OutOfMemory] trace event.

        Requests a 1024**4-element float32 tensor (4 TiB) on cuda:0, expects
        the allocation error, and inspects the exported Chrome trace. Does
        nothing when CUDA is unavailable.
        """

        def run_profiler(tensor_creation_fn):
            """Profile ``tensor_creation_fn`` and require it to fail to allocate."""
            with _profile(profile_memory=True, record_shapes=True) as prof:
                with self.assertRaisesRegex(RuntimeError, ".*[tT]ried to allocate.*"):
                    tensor_creation_fn()
            return prof

        def create_cuda_tensor_oom():
            device = torch.device("cuda:0")
            return torch.empty(
                1024, 1024, 1024, 1024, dtype=torch.float32, device=device
            )

        def check_trace(prof, fname):
            """Export ``prof`` to ``fname`` and validate its [OutOfMemory] events."""
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
            self.assertIn("traceEvents", trace)
            found_out_of_memory_events = False
            for evt in trace["traceEvents"]:
                self.assertIn("name", evt)
                if evt["name"] != "[OutOfMemory]":
                    continue
                found_out_of_memory_events = True
                self.assertIn("args", evt)
                for key in ("Device Type", "Device Id", "Bytes"):
                    self.assertIn(key, evt["args"])

                # Memory should be an instantaneous event.
                self.assertNotIn("dur", evt["args"])
                self.assertNotIn("cat", evt["args"])
            self.assertTrue(found_out_of_memory_events)

        if torch.cuda.is_available():
            with TemporaryFileName(mode="w+") as fname:
                prof = run_profiler(create_cuda_tensor_oom)
                check_trace(prof, fname)
751

752
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_module_hierarchy(self):
        """Ops of a scripted model must report the expected "Module Hierarchy".

        Only trace events whose name is a key of ``op_to_module_hierarchy``
        and that carry a "Module Hierarchy" arg are checked.
        """

        class A(nn.Module):
            def my_new_method(self, x):
                return x * 3

            def forward_impl_(self, x, y):
                return self.my_new_method(x) + y

            def forward(self, x, y):
                y = y - 2
                return self.forward_impl_(x, y)

        class B(nn.Module):
            def forward(self, x):
                return x + 2

        class C(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.A0 = A()
                self.B0 = B()

            def call_b(self, x):
                return self.B0.forward(x)

            def forward(self, x, y):
                return self.A0.forward(x, y) + self.call_b(x)

        model = C()
        model = torch.jit.script(model)
        input_a = torch.rand(128, 128)
        input_b = torch.rand(128, 128)
        # Op name -> every hierarchy string that op may legitimately report.
        op_to_module_hierarchy = {
            "aten::sub": ["TOP(C)::forward.A0(A)::forward."],
            "aten::mul": [
                "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.SELF(A)::my_new_method."
            ],
            "aten::add": [
                "TOP(C)::forward.A0(A)::forward.SELF(A)::forward_impl_.",
                "TOP(C)::forward.SELF(C)::call_b.B0(B)::forward.",
                "TOP(C)::forward.",
            ],
        }
        with TemporaryFileName(mode="w+") as fname:
            with profile(
                activities=[torch.profiler.ProfilerActivity.CPU],
                with_modules=True,
            ) as prof:
                model(input_a, input_b)
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)
            self.assertIn("traceEvents", trace)
            for evt in trace["traceEvents"]:
                self.assertIn("name", evt)
                args = evt.get("args", {})
                if "Module Hierarchy" not in args:
                    continue
                op_name = evt["name"]
                if op_name in op_to_module_hierarchy:
                    self.assertIn(
                        args["Module Hierarchy"], op_to_module_hierarchy[op_name]
                    )
815

816
    def test_high_level_trace(self):
        """Checks that python side high level events are recorded.

        A small training loop is profiled and the number of DataLoader
        ``__next__`` and optimizer ``step``/``zero_grad`` events is compared
        against expected counts, including after pickling the optimizer and
        when using an ``SGD`` subclass.
        """

        class RepeatedDataset(torch.utils.data.Dataset):
            def __init__(self, N, D_in, D_out):
                self.N = N
                self.x = torch.randn(N, D_in)
                self.y = torch.randn(N, D_out)

            def __len__(self):
                return self.N

            def __getitem__(self, idx):
                return self.x, self.y

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)

            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        class CustomSGD(torch.optim.SGD):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

        def train():
            # `optimizer` is looked up in the enclosing scope at call time, so
            # rebinding it further down switches the optimizer used by later runs.
            for _, data in enumerate(dataloader):
                x, y = data[0], data[1]
                y_pred = model(x)
                loss = criterion(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        N, D_in, H, D_out = 8, 10, 5, 2
        model = TwoLayerNet(D_in, H, D_out)
        criterion = torch.nn.MSELoss(reduction="sum")
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
        ds = RepeatedDataset(N, D_in, D_out)
        dataloader = torch.utils.data.DataLoader(ds, batch_size=1)

        try:
            train()
        except Exception as e:
            self.fail(f"Expected no exception without profiling, got {e!r}")

        # Create multiple instances, expect each func is hooked only one time.
        # Nested wrappers (repeated patching) would make the counts below fail.
        optimizer_duplicate = torch.optim.SGD(model.parameters(), lr=1e-4)  # noqa: F841
        dataloader_duplicate = torch.utils.data.DataLoader(ds, batch_size=1)  # noqa: F841

        def judge(expected_event_count, prof):
            """Assert every expected event name was recorded exactly `count` times."""
            actual_event_count = collections.Counter(
                e.name for e in prof.function_events if e.name in expected_event_count
            )
            for key, count in expected_event_count.items():
                self.assertIn(key, actual_event_count)
                self.assertEqual(
                    count, actual_event_count[key], f"Unexpected event count for {key}"
                )

        with _profile(use_kineto=kineto_available()) as prof:
            train()
        expected_event_count = {
            # "+1" because the final iteration will enter __next__ but skip the loop body.
            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
            "Optimizer.step#SGD.step": N,
            "Optimizer.zero_grad#SGD.zero_grad": N,
        }
        judge(expected_event_count, prof)

        # Test on pickle/unpickle. Expect to work in multi-processing.
        # (Only round-trips an optimizer created in this test; no untrusted data.)
        optimizer = pickle.loads(pickle.dumps(optimizer))
        with _profile(use_kineto=kineto_available()) as prof:
            train()
        judge(expected_event_count, prof)

        # Test on customized optimizer.
        optimizer = CustomSGD(model.parameters(), lr=1e-4)
        with _profile(use_kineto=kineto_available()) as prof:
            train()
        expected_event_count = {
            "enumerate(DataLoader)#_SingleProcessDataLoaderIter.__next__": (N + 1),
            "Optimizer.step#CustomSGD.step": N,
            "Optimizer.zero_grad#CustomSGD.zero_grad": N,
        }
        judge(expected_event_count, prof)
913

914
    def test_flops(self):
        """FLOP estimates appear in the profiler table (legacy path, then Kineto+CUDA)."""
        layers = [
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        ]
        model = torch.nn.Sequential(*layers)
        inputs = torch.randn(40, 16, 18, 260)
        nt = torch.nested.nested_tensor(
            [torch.randn((2, 5)), torch.randn((3, 5))], layout=torch.jagged
        )
        with _profile(
            record_shapes=True, with_flops=True, use_kineto=kineto_available()
        ) as prof:
            model(inputs)
            # An op on a nested tensor must not make the FLOP computation raise.
            nt = nt + nt
        table = prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=10
        )
        self.assertIn("Total MFLOPs", table)

        # The second half needs both Kineto and a CUDA device.
        if kineto_available() and torch.cuda.is_available():
            gpu_activities = [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
            with profile(
                activities=gpu_activities,
                record_shapes=True,
                with_flops=True,
            ) as kineto_prof:
                model(inputs)
            table = kineto_prof.key_averages().table(
                sort_by="self_cuda_time_total", row_limit=-1
            )
            self.assertIn("Total MFLOPs", table)
951

952
    def test_kineto_profiler_api(self):
        """Exercise the Kineto profiler with and without a schedule.

        Also checks the step-to-ProfilerAction mapping of
        ``torch.profiler.schedule`` with ``skip_first`` and ``repeat``.
        """
        called_num = [0]

        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        # NOTE(review): presumably a warm-up run so one-time setup is not in the
        # scheduled traces below — confirm.
        with profile(activities=supported_activities()):
            self.payload(use_cuda=use_cuda)

        sort_key = "self_cuda_time_total" if use_cuda else "self_cpu_time_total"

        def trace_handler(p):
            # Render the table to exercise key_averages()/table() on each trace.
            p.key_averages().table(sort_by=sort_key, row_limit=-1)
            called_num[0] += 1

        initial_step = KinetoStepTracker.current_step()

        with profile(
            activities=supported_activities(),
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
            on_trace_ready=trace_handler,
        ) as p:
            for _ in range(8):
                self.payload(use_cuda=use_cuda)
                p.step()

        # 8 steps over a 4-step (wait=1, warmup=1, active=2) cycle -> 2 traces.
        self.assertEqual(called_num[0], 2)
        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)

        # case without schedule
        with profile(activities=supported_activities()) as p:
            self.payload(use_cuda=use_cuda)
            self.payload(use_cuda=use_cuda)
        p.key_averages().table(sort_by=sort_key, row_limit=-1)

        test_schedule = torch.profiler.schedule(
            skip_first=2, wait=1, warmup=1, active=2, repeat=2
        )
        test_schedule_expected_outputs = [
            # skip_first=2
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            # first cycle
            ProfilerAction.NONE,
            ProfilerAction.WARMUP,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # second cycle
            ProfilerAction.NONE,
            ProfilerAction.WARMUP,
            ProfilerAction.RECORD,
            ProfilerAction.RECORD_AND_SAVE,
            # repeat=2 exhausted
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
            ProfilerAction.NONE,
        ]
        for step, expected in enumerate(test_schedule_expected_outputs):
            self.assertEqual(test_schedule(step), expected)
1013

1014
    def test_kineto_profiler_multiple_steppers(self):
        """KinetoStepTracker copes with a second step source besides the profiler.

        Every batch advances a simulated optimizer-hook counter
        ("yet_another_step"), for 2 * niters batches in total; the profiled
        half additionally calls ``p.step()``. The final assertion expects the
        tracker's current step to advance by exactly 2 * niters
        (presumably it follows the most-advanced requester — confirm).
        """
        niters = 8
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        net = SimpleNet()
        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
        opt.zero_grad()
        inputs = torch.rand(10)

        with profile(activities=supported_activities()):
            self.payload(use_cuda=use_cuda)

        initial_step = KinetoStepTracker.current_step()

        def run_batch():
            loss = torch.nn.functional.cross_entropy(net(inputs), torch.rand(2))
            loss.backward()
            opt.step()
            # Manually simulate an optimizer step() hook. TODO: Remove this once we
            # add the profiler step hooks in the Optimizer class.
            # See https://github.com/pytorch/pytorch/issues/88446
            KinetoStepTracker.increment_step("yet_another_step")

        # Unprofiled batches.
        for _ in range(niters):
            run_batch()

        scheduled = torch.profiler.schedule(wait=1, warmup=1, active=2)
        with profile(activities=supported_activities(), schedule=scheduled) as p:
            for _ in range(niters):
                run_batch()
                p.step()

        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)
1053

1054
    def test_export_stacks(self):
        """export_stacks writes a non-empty file whose records end in a positive int.

        Uses TestCase assertions instead of bare ``assert`` so the checks
        still run when Python is started with ``-O``.
        """
        with _profile(
            with_stack=True,
            use_kineto=kineto_available(),
            experimental_config=_ExperimentalConfig(verbose=True),
        ) as p:
            x = torch.randn(10, 10)
            y = torch.randn(10, 10)
            z = torch.mm(x, y)
            z = z + y

        with TemporaryFileName(mode="w+") as fname:
            p.export_stacks(fname)
            with open(fname) as f:
                lines = f.readlines()
            self.assertGreater(len(lines), 0, "Empty stacks file")
            for line in lines:
                # The last space-separated field must parse as a positive integer.
                try:
                    value = int(line.split(" ")[-1])
                except ValueError:
                    self.fail(f"Invalid stacks record: {line!r}")
                self.assertGreater(value, 0, "Invalid stacks record")
1078

1079
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_tensorboard_trace_handler(self):
        """tensorboard_trace_handler writes one correctly named file per saved trace.

        Checked for plain JSON output (profiler used as a context manager) and
        for gzip output (explicit start()/stop()).
        """
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with _profile(use_cuda=use_cuda, use_kineto=True):
            self.payload(use_cuda=use_cuda)

        activities = [torch.profiler.ProfilerActivity.CPU] + (
            [torch.profiler.ProfilerActivity.CUDA] if use_cuda else []
        )

        def check_trace_files(dname, suffix):
            """Assert `dname` holds exactly 3 files named `<...>.<positive int>.<suffix>`."""
            self.assertTrue(os.path.exists(dname))
            file_names = os.listdir(dname)
            for file_name in file_names:
                parts = file_name.split(".")
                self.assertTrue(len(parts) > 4)
                # The field right before the suffix must be a positive integer.
                number = parts[-len(suffix) - 1]
                self.assertTrue(
                    number.isdigit() and int(number) > 0,
                    "Wrong tracing file name pattern",
                )
                self.assertEqual(parts[-len(suffix) :], suffix)
            # schedule(..., repeat=3) is expected to save exactly three traces.
            self.assertEqual(len(file_names), 3)

        with TemporaryDirectoryName() as dname:
            with profile(
                activities=activities,
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(dname),
            ) as p:
                for _ in range(18):
                    self.payload(use_cuda=use_cuda)
                    p.step()
            check_trace_files(dname, ["pt", "trace", "json"])

        # test case for gzip file format
        with TemporaryDirectoryName() as dname:
            p = profile(
                activities=activities,
                schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=3),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    dname, use_gzip=True
                ),
            )
            p.start()
            try:
                for _ in range(18):
                    self.payload(use_cuda=use_cuda)
                    p.step()
            finally:
                # Stop even on failure so the profiler is not left running for
                # subsequent tests.
                p.stop()
            check_trace_files(dname, ["pt", "trace", "json", "gz"])
1137

1138
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    def test_profiler_metadata(self):
        """Metadata added during profiling appears at the top level of the exported trace.

        Uses TestCase assertions instead of bare ``assert`` so the checks
        still run when Python is started with ``-O``.
        """
        t1, t2 = torch.ones(1), torch.ones(1)
        with profile() as prof:
            torch.add(t1, t2)
            prof.add_metadata("test_key1", "test_value1")
            # The JSON variant is expected to surface as a parsed list, not a string.
            prof.add_metadata_json("test_key2", "[1,2,3]")

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)

        self.assertIn("test_key1", trace)
        self.assertEqual(trace["test_key1"], "test_value1")
        self.assertIn("test_key2", trace)
        self.assertEqual(trace["test_key2"], [1, 2, 3])
1154

1155
    def _test_profiler_tracing(self, use_kineto):
        """Export Chrome traces and make sure each one is valid JSON.

        Covers a tiny CPU workload, an empty profiling session (where, with
        Kineto, any "WARNING" list must mention "No Valid Trace Events") and,
        when CUDA profiling is supported, a tiny CUDA workload.
        """

        def export_json(profiler):
            # json.load raises (failing the test) if the exported trace is not valid JSON.
            with TemporaryFileName(mode="w+") as fname:
                profiler.export_chrome_trace(fname)
                with open(fname) as f:
                    return json.load(f)

        with _profile(use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1), torch.ones(1)
            torch.add(t1, t2)
        export_json(prof)

        # test empty trace
        with _profile(use_kineto=use_kineto) as prof:
            pass
        if use_kineto:
            contents = export_json(prof)
            # Some builds may not have logger observer, so skip if not.
            if "WARNING" in contents:
                self.assertTrue(
                    any("No Valid Trace Events" in w for w in contents["WARNING"])
                )
        else:
            # Saving an empty trace must not raise; its content is not checked.
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)

        # Same test but for cuda.
        if torch.profiler.ProfilerActivity.CUDA not in supported_activities():
            return

        device = torch.device("cuda:0")
        with _profile(use_cuda=True, use_kineto=use_kineto) as prof:
            t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
            torch.add(t1, t2)
        export_json(prof)
1200

1201
    def test_profiler_tracing(self):
        """Run the trace-export checks with the legacy profiler, then with Kineto if built."""
        for use_kineto in (False, True):
            if use_kineto and not kineto_available():
                break
            self._test_profiler_tracing(use_kineto)
1205

1206
    def test_profiler_op_event_args(self):
        """cpu_op events in the Chrome trace carry input types, concrete inputs,
        input dims and a record-function id."""
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        with _profile(record_shapes=True) as prof:
            a = torch.ones((64, 32), dtype=torch.float32)
            torch.cat([a, a]).sin()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)

        op_events = [e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op"]
        for e in op_events:
            args = e["args"]
            if e["name"] == "aten::ones":
                # NOTE(review): presumably one entry per schema argument
                # (size, dtype, layout, device, pin_memory); empty = not given.
                self.assertEqual(
                    args["Input type"],
                    ["ScalarList", "Scalar", "", "", "Scalar"],
                )
                self.assertEqual(
                    args["Concrete Inputs"], ["[64, 32]", "6", "", "", "False"]
                )

            if e["name"] == "aten::cat":
                self.assertEqual(args["Input Dims"], [[[64, 32], [64, 32]], []])
                self.assertEqual(args["Input type"], ["TensorList", "Scalar"])

            # check that each op has record function id
            self.assertGreaterEqual(
                args.get("Record function id", -1),
                0,
                f"Failed finding record function for op = {e}",
            )
1239

1240
    def test_profiler_strides(self):
        """Input strides of aten::add on strided views are exported in the trace."""
        torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
        storage = torch.randn(1024, dtype=torch.float32)
        # Two non-contiguous 16x16 views into the same buffer.
        lhs = storage.as_strided((16, 16), (17, 1), 0)
        rhs = storage.as_strided((16, 16), (25, 2), 272)
        with _profile(record_shapes=True) as prof:
            torch.add(lhs, rhs)

        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace = json.load(f)

        for event in trace["traceEvents"]:
            if event.get("cat", "") != "cpu_op":
                continue
            event_args = event["args"]
            if event["name"] == "aten::add":
                self.assertEqual(event_args["Input Strides"], [[17, 1], [25, 2], []])
1259

1260
    def test_profiler_fwd_bwd_link(self):
        """Forward/backward flow events ("fwdbwd") start at the expected forward ops.

        Two flows (ids 1 and 2) are expected. Each flow's start ("s") and
        finish ("f") timestamps must coincide with a complete ("X") event.
        Flow 1 starts at the BCE-with-logits op and flow 2 at ``aten::add``.
        """
        with _profile(use_kineto=True) as prof:
            t1 = torch.ones(1, requires_grad=True)
            t2 = torch.ones(1, requires_grad=True)
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                events = json.load(f)["traceEvents"]

        ts_to_name = {}
        flow_s_to_ts = {}  # flow id -> ts of its "s" (start) event
        flow_f_to_ts = {}  # flow id -> ts of its "f" (finish) event
        for e in events:
            if e["ph"] == "X":
                ts_to_name[e["ts"]] = e["name"]
            if e.get("cat") == "fwdbwd" and e.get("name") == "fwdbwd":
                if e["ph"] == "s":
                    flow_s_to_ts[e["id"]] = e["ts"]
                elif e["ph"] == "f":
                    flow_f_to_ts[e["id"]] = e["ts"]

        self.assertEqual(len(flow_s_to_ts), 2)
        self.assertEqual(len(flow_f_to_ts), 2)
        for flow_id in (1, 2):
            self.assertIn(flow_id, flow_s_to_ts)
            self.assertIn(flow_id, flow_f_to_ts)
        s_ts_1 = flow_s_to_ts[1]
        f_ts_1 = flow_f_to_ts[1]
        s_ts_2 = flow_s_to_ts[2]
        f_ts_2 = flow_f_to_ts[2]
        for ts in (s_ts_1, f_ts_1, s_ts_2, f_ts_2):
            self.assertIn(ts, ts_to_name)
        self.assertEqual(ts_to_name[s_ts_1], "aten::binary_cross_entropy_with_logits")
        self.assertEqual(ts_to_name[s_ts_2], "aten::add")
1311

1312
    def test_profiler_disable_fwd_bwd_link(self):
        """With fwd/bwd linking switched off, no "fwdbwd" events are exported.

        The global toggle is always switched back on afterwards.
        """
        c_profiler = torch._C._profiler
        try:
            c_profiler._set_fwd_bwd_enabled_val(False)

            with _profile(use_kineto=True) as prof:
                t1, t2 = torch.ones(1, requires_grad=True), torch.ones(
                    1, requires_grad=True
                )
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    torch.add(t1, t2), torch.ones(1)
                )
                loss.backward()

            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    trace_events = json.load(f)["traceEvents"]
                for category in (event.get("cat", None) for event in trace_events):
                    self.assertNotEqual(category, "fwdbwd")
        finally:
            c_profiler._set_fwd_bwd_enabled_val(True)
1335

1336
    # This test is broken on Windows, the likely reason is that kineto/CUPTI
    # is not supported in that particular environment. Once the CI stabilizes
    # we can narrow the condition so Windows is checked as well (TODO)
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @unittest.skipIf(IS_WINDOWS, "Test does not work on Windows")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_cuda_sync_events(self):
        """cuda_sync events are exported when enabled via config or global toggle."""
        device = torch.device("cuda:0")
        t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)

        def workload() -> None:
            torch.add(t1, t2)
            torch.cuda.synchronize()
            torch.add(t1, t2)

        def exported_categories(exp_config):
            """Profile `workload` and return the set of event categories in its trace."""
            with _profile(
                use_kineto=True,
                use_cuda=True,
                experimental_config=exp_config,
            ) as prof:
                workload()
            with TemporaryFileName(mode="w+") as fname:
                prof.export_chrome_trace(fname)
                with open(fname) as f:
                    trace_events = json.load(f)["traceEvents"]
            return {event.get("cat", None) for event in trace_events}

        def trace_and_check(exp_config: Optional[_ExperimentalConfig]) -> None:
            cats = exported_categories(exp_config)
            self.assertTrue(
                "cuda_sync" in cats,
                "Expected to find cuda_sync event" f" found = {cats}",
            )

        print("Testing enable_cuda_sync_events in _ExperimentalConfig")
        trace_and_check(exp_config=_ExperimentalConfig(enable_cuda_sync_events=True))

        print("Testing _profiler._set_cuda_sync_enabled_val()")
        c_profiler = torch._C._profiler
        try:
            c_profiler._set_cuda_sync_enabled_val(True)
            trace_and_check(exp_config=None)
        finally:
            c_profiler._set_cuda_sync_enabled_val(False)
1379

1380
    def test_profiler_type(self):
        """_profiler_type() reports which profiler, if any, is currently active."""
        active_type = torch._C._autograd._profiler_type
        kinds = torch._C._profiler.ActiveProfilerType
        self.assertEqual(active_type(), kinds.NONE)

        # Autograd (legacy) profiler, then Kineto profiler.
        for make_profiler, expected in (
            (_profile_legacy, kinds.LEGACY),
            (profile, kinds.KINETO),
        ):
            with make_profiler():
                self.assertEqual(active_type(), expected)
1392

1393
    def test_profiler_correlation_id(self):
        """
        We expect the correlation_id to be unique across multiple invocations of the
        profiler, so id_uniqueness_set is shared across all runs. Each id must also
        stay below 2**32 - 1.
        """
        id_uniqueness_set = set()
        model = torch.nn.Sequential(
            nn.Conv2d(16, 33, 18),
            nn.ReLU(),
            nn.Linear(243, 243),
            nn.ReLU(),
        )
        inputs = torch.randn(40, 16, 18, 260)
        uint32_max = 2**32 - 1
        for _ in range(5):
            with profile() as prof:
                model(inputs)
            for event in prof.profiler.kineto_results.events():
                corr_id = event.correlation_id()
                # Falsy ids (e.g. 0) are skipped; only CPU events are checked.
                if corr_id and event.device_type() == DeviceType.CPU:
                    self.assertNotIn(corr_id, id_uniqueness_set)
                    id_uniqueness_set.add(corr_id)
                    self.assertLess(corr_id, uint32_max)
1416

1417
    def test_nested_tensor_with_shapes(self):
        """Input shapes are recorded for the matmul behind linear() on a nested tensor."""
        a, b, c = (torch.randn(4, 4) for _ in range(3))
        nested = torch.nested.nested_tensor([a, b])
        with torch.profiler.profile(record_shapes=True) as prof:
            torch.nn.functional.linear(nested, c, None)

        # Intentionally vague checks, to protect against possible future changes
        # of mm to addmm or other impl, or changing internal order of args.
        matmul_events = [
            e for e in prof.events() if e.name in ("aten::mm", "aten::addmm")
        ]
        for e in matmul_events:
            self.assertTrue(len(e.input_shapes) > 0)
            self.assertTrue(len(e.input_shapes[0]) > 0)
1430

1431
    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
1432
    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
1433
    def test_kineto_profiler_with_environment_variable(self):
1434
        script = """
1435
import torch
1436
import torch.nn as nn
1437
from torch.profiler import supported_activities, profile
1438
from torch.autograd.profiler import KinetoStepTracker
1439

1440
class SimpleNet(nn.Module):
1441
    def __init__(self) -> None:
1442
        super().__init__()
1443
        self.fc1 = nn.Linear(10, 5)
1444
        self.fc2 = nn.Linear(5, 2)
1445

1446
    def forward(self, x):
1447
        return self.fc2(self.fc1(x))
1448

1449

1450
def payload(use_cuda=False):
1451
    x = torch.randn(10, 10)
1452
    if use_cuda:
1453
        x = x.cuda()
1454
    y = torch.randn(10, 10)
1455
    if use_cuda:
1456
        y = y.cuda()
1457
    z = torch.mm(x, y)
1458
    z = z + y
1459
    if use_cuda:
1460
        z = z.cpu()
1461

1462
niters = 8
1463
use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
1464
net = SimpleNet()
1465
opt = torch.optim.SGD(net.parameters(), lr=0.01)
1466
opt.zero_grad()
1467
inputs = torch.rand(10)
1468

1469
with profile(activities=supported_activities()):
1470
    payload(use_cuda=use_cuda)
1471

1472
initial_step = KinetoStepTracker.current_step()
1473

1474
def run_batch():
1475
    out = net(inputs)
1476
    loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
1477
    loss.backward()
1478
    opt.step()
1479

1480
for _ in range(niters):
1481
    run_batch()
1482

1483
with profile(
1484
    activities=supported_activities(),
1485
    schedule=torch.profiler.schedule(
1486
        wait=1,
1487
        warmup=1,
1488
        active=2),
1489
) as p:
1490
    for _ in range(niters):
1491
        run_batch()
1492
        p.step()
1493
assert KinetoStepTracker.current_step() == initial_step + 2 * niters
1494
"""
1495
        try:
1496
            subprocess.check_output(
1497
                [sys.executable, "-W", "always", "-c", script],
1498
                cwd=os.path.dirname(os.path.realpath(__file__)),
1499
            )
1500
        except subprocess.CalledProcessError as e:
1501
            if e.returncode != 0:
1502
                self.assertTrue(
1503
                    False,
1504
                    "Kineto is not working properly with the Dynolog environment variable",
1505
                )
1506

1507
    def test_concrete_inputs_profiling(self):
        """With ``record_shapes=True``, ``aten::as_strided`` events record the
        input tensor shape plus the concrete size (index 1) and stride (index 2)
        arguments in ``concrete_inputs``."""
        x = torch.rand(2, 6)
        with profile(record_shapes=True) as p:
            x.as_strided([4, 3], [1, 4])

        found = False
        for e in p.events():
            # Exact comparison: `e.name in ("aten::as_strided")` would be a
            # substring test (parentheses alone do not build a tuple), matching
            # unrelated names such as "" or "aten::as".
            if e.name == "aten::as_strided":
                found = True
                self.assertTrue(len(e.input_shapes) > 0)
                self.assertTrue(len(e.concrete_inputs) > 0)
                self.assertEqual([2, 6], e.input_shapes[0])
                self.assertEqual([4, 3], e.concrete_inputs[1])
                self.assertEqual([1, 4], e.concrete_inputs[2])

        self.assertTrue(found, "Expected to find aten::as_strided but did not")
1523

1524
    def test_concrete_inputs_profiling_toggling(self):
        """Flipping the concrete-input recording flag while a profile is active
        must still yield ``aten::as_strided`` events with input shapes.

        Both orders (enabled -> disabled, disabled -> enabled) are exercised.
        The flag is restored to enabled afterwards, even on failure.
        """
        try:
            for before, after in [(True, False), (False, True)]:
                x = torch.rand(2, 6)
                torch._C._profiler._set_record_concrete_inputs_enabled_val(before)
                with profile(record_shapes=True) as p:
                    x.as_strided([4, 3], [1, 4])
                    # Toggle mid-profile.
                    torch._C._profiler._set_record_concrete_inputs_enabled_val(after)

                found = False
                for e in p.events():
                    # Exact comparison; `in ("aten::as_strided")` is a substring
                    # test because the parentheses do not make a tuple.
                    if e.name == "aten::as_strided":
                        found = True
                        self.assertTrue(len(e.input_shapes))

                self.assertTrue(found, "Expected to find aten::as_strided but did not")
        finally:
            torch._C._profiler._set_record_concrete_inputs_enabled_val(True)
1542

1543
    def test_record_function_fast(self):
        """Check events emitted by ``torch._C._profiler._RecordFunctionFast``.

        Each scenario enters the context four times and expects at least four
        matching events:
          * rf1: no optional arguments -> no input shapes, no kwinputs;
          * rf2: positional tensor inputs plus keyword values on a reused
            context object -> both recorded;
          * rf3: ``input_values``/``keyword_values`` passed by keyword while a
            ``ValueError`` escapes the context -> events still recorded and
            nothing after the raise is profiled;
          * rf4/rf5: nested contexts under ``profile()`` without
            ``record_shapes`` -> rf4 has no input shapes;
          * rf6: positional inputs passed as a tuple.
        """
        x, y = (torch.rand((4, 4)) for _ in range(2))

        def matching(prof, name):
            # Events of `prof` whose name is exactly `name`.
            return [ev for ev in prof.events() if ev.name == name]

        # rf1: no optional args.
        with profile(record_shapes=True) as p:
            for _ in range(4):
                with torch._C._profiler._RecordFunctionFast("add_test_fast_rf1"):
                    x.add(y)
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf1")), 4)
        for ev in matching(p, "add_test_fast_rf1"):
            self.assertTrue(ev.input_shapes == [])
            self.assertTrue(ev.kwinputs == {})

        # rf2: positional inputs and keyword values.
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_fast_rf2", [x, y], {"stream": 0, "grid": "lambda x : x + 1"}
            )
            for _ in range(4):
                with cm:
                    x.add(y)
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf2")), 4)
        for ev in matching(p, "add_test_fast_rf2"):
            self.assertTrue(ev.input_shapes == [[4, 4], [4, 4]])
            self.assertTrue(ev.kwinputs == {"stream": 0, "grid": "lambda x : x + 1"})

        # rf3: keyword-form arguments; an exception leaves the context each time.
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_fast_rf3", input_values=["hi"], keyword_values={"hi": "hello"}
            )
            for _ in range(4):
                try:
                    with cm:
                        x.add(y)
                        raise ValueError
                        x.relu()  # never reached; no relu event may appear
                except ValueError:
                    pass
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf3")), 4)
        self.assertFalse(any((e.name and "relu" in e.name) for e in p.events()))
        for ev in matching(p, "add_test_fast_rf3"):
            self.assertTrue(ev.input_shapes == [[]])

        # rf4 wrapping a nested rf5, profiled without record_shapes.
        with profile() as p:
            for _ in range(4):
                with torch._C._profiler._RecordFunctionFast("add_test_fast_rf4", [x, y]):
                    x.add(y)
                    with torch._C._profiler._RecordFunctionFast("add_test_fast_rf5"):
                        x.relu()
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf4")), 4)
        for ev in matching(p, "add_test_fast_rf4"):
            self.assertTrue(ev.input_shapes == [])
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf5")), 4)

        # rf6: positional inputs given as a tuple.
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast("add_test_fast_rf6", (x, y))
            for _ in range(4):
                with cm:
                    x.add(y)
        self.assertGreaterEqual(len(matching(p, "add_test_fast_rf6")), 4)
        for ev in matching(p, "add_test_fast_rf6"):
            self.assertTrue(ev.input_shapes == [[4, 4], [4, 4]])
1639

1640
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_op_event_kwargs(self):
        """Keyword values given to ``_RecordFunctionFast`` appear as string
        ``args`` on the op's ``cpu_op`` event in the exported Chrome trace.

        Per the expectations below, ``0`` is exported as ``"0"`` and the value
        ``'debug"'`` (which contains a double quote) is exported as ``"None"``.
        """
        x, y = (torch.rand((4, 4)) for _ in range(2))
        with profile(record_shapes=True) as p:
            cm = torch._C._profiler._RecordFunctionFast(
                "add_test_kwinputs",
                [x, y],
                {"stream": 0, "grid": "lambda x : x + 1", "debug": 'debug"'},
            )
            for _ in range(4):
                with cm:
                    x.add(y)
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                j = json.load(f)

        kw_events = [
            e
            for e in j["traceEvents"]
            if e.get("cat", "") == "cpu_op" and e.get("name") == "add_test_kwinputs"
        ]
        # Without this, every check below would pass vacuously if the event
        # were missing from the trace.
        self.assertGreater(len(kw_events), 0, "add_test_kwinputs not in trace")
        for e in kw_events:
            args = e["args"]
            self.assertIn("stream", args)
            self.assertIn("grid", args)
            self.assertEqual(args["stream"], "0")
            self.assertEqual(args["grid"], "lambda x : x + 1")
            self.assertEqual(args["debug"], "None")
1667

1668
    def test_is_profiler_enabled(self):
        """``_is_profiler_enabled`` is set only inside an active profiler, for
        both ``torch.profiler.profile`` and ``torch.autograd.profiler.profile``."""
        for make_profiler in (profile, torch.autograd.profiler.profile):
            self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
            with make_profiler():
                self.assertTrue(torch.autograd.profiler._is_profiler_enabled)
        self.assertFalse(torch.autograd.profiler._is_profiler_enabled)
1680

1681
    def test_guarded_record_function_fast(self):
        """A ``_RecordFunctionFast`` entered only when ``_is_profiler_enabled``
        is true still yields one event per iteration under ``profile()``."""
        x, y = (torch.rand((4, 4)) for _ in range(2))

        with profile() as p:
            cm = torch._C._profiler._RecordFunctionFast("guarded_rff")
            for _ in range(4):
                if not torch.autograd.profiler._is_profiler_enabled:
                    x.add(y)
                    continue
                with cm:
                    x.add(y)

        guarded = [e for e in p.events() if e.name == "guarded_rff"]
        self.assertGreaterEqual(len(guarded), 4)
1696

1697
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_event_list(self):
        """Sanity checks for an ``EventList`` built from profiler events.

        EventList presumably belongs to the legacy / non-Kineto path. This only
        guards against obvious regressions: the Chrome trace export must load as
        JSON, and ``table()`` must not raise.
        """
        x, y = (torch.rand((4, 4), requires_grad=True, device="cuda") for _ in range(2))
        with profile(with_stack=True) as p:
            (x @ y).relu().sum().backward()

        event_list = torch.autograd.profiler_util.EventList(p.events())

        with TemporaryFileName(mode="w+") as fname:
            event_list.export_chrome_trace(fname)
            # json.load raising is the failure signal here.
            with open(fname) as trace_file:
                json.load(trace_file)

        event_list.table()
1715

1716
    def _check_all_gpu_present(self, gpu_dict, max_gpu_count):
        """Assert that ``gpu_dict`` counts exactly 1 for each of ``"GPU 0"``
        through ``"GPU {max_gpu_count - 1}"``."""
        for gpu_idx in range(max_gpu_count):
            self.assertEqual(gpu_dict[f"GPU {gpu_idx}"], 1)
1719

1720
    # Do JSON sanity testing: checks that all events lie within the profiler /
    # record window and that GPU label events are followed by their sort index.
    def _validate_basic_json(self, traceEvents, cuda_available=False):
        """Sanity-check the ``traceEvents`` list of an exported Chrome trace.

        The trace is expected to end with the ``PyTorch Profiler (0)`` span at
        index -4, the ``Iteration Start: PyTorch Profiler`` marker at -2 and the
        ``Record Window End`` marker at -1 (index -3 is not inspected).

        Checks:
          * the profiler span starts at/after the iteration-start marker and
            ends at/before the record-window-end marker;
          * every other event with a ``ts`` starts at/after the profiler span;
          * every other event with a ``dur`` ends at/before the record-window-end
            marker;
          * every event labelled ``GPU <n>`` is immediately followed by an event
            whose ``args["sort_index"]`` is ``5000000 + n``.

        Args:
            traceEvents: the ``"traceEvents"`` list of the exported JSON.
            cuda_available: currently unused (GPU presence is not checked yet,
                see the TODO at the end).
        """
        PROFILER_IDX = -4
        RECORD_END = -1
        RECORD_START = -2
        # Max PID offset is 5M, based from pytorch/kineto include header:
        # https://github.com/pytorch/kineto/blob/8681ff11e1fa54da39023076c5c43eddd87b7a8a/libkineto/include/output_base.h#L35
        kExceedMaxPid = 5000000

        traceEventProfiler = traceEvents[PROFILER_IDX]
        record_start = traceEvents[RECORD_START]
        record_end = traceEvents[RECORD_END]

        # assertEqual (rather than assertTrue(a == b)) reports the actual values.
        self.assertEqual(traceEventProfiler["name"], "PyTorch Profiler (0)")
        self.assertEqual(record_end["name"], "Record Window End")
        self.assertEqual(record_start["name"], "Iteration Start: PyTorch Profiler")
        # check that the profiler starts/ends within the record interval
        self.assertGreaterEqual(
            traceEventProfiler["ts"],
            record_start["ts"],
            "Profiler starts before record!",
        )
        self.assertLessEqual(
            traceEventProfiler["ts"] + traceEventProfiler["dur"],
            record_end["ts"],
            "Profiler ends after record end!",
        )

        # The two record markers are the reference points, so they are skipped.
        record_marker_indices = {
            len(traceEvents) + RECORD_END,
            len(traceEvents) + RECORD_START,
        }
        gpu_dict = collections.defaultdict(int)
        for i, traceEvent in enumerate(traceEvents):
            if i in record_marker_indices:
                continue
            # make sure all valid trace events are within the bounds of the profiler
            if "ts" in traceEvent:
                self.assertGreaterEqual(
                    traceEvent["ts"],
                    traceEventProfiler["ts"],
                    "Trace event is out of bounds",
                )
            # some python events seem to go a little past record end probably because
            # of some clock inaccuracies so just compare events ending to RECORD_END
            if "dur" in traceEvent:
                self.assertLessEqual(
                    traceEvent["ts"] + traceEvent["dur"],
                    record_end["ts"],
                    "Trace event ends too late!",
                )
            gpu_value = traceEvent.get("args", {}).get("labels", None)
            if gpu_value and "GPU" in gpu_value:
                gpu_dict[gpu_value] += 1
                self.assertEqual(
                    traceEvents[i + 1]["args"]["sort_index"],
                    kExceedMaxPid + int(gpu_value.split()[1]),
                )

        # TODO add checking gpu count if cpuOnly_ is true or not
1780

1781
    def _test_chrome_trace_basic_helper(self, with_cuda=False):
        """Profile one ``torch.add`` on CPU (or CUDA if ``with_cuda``), export a
        Chrome trace and validate it with ``_validate_basic_json``."""
        device = "cuda" if with_cuda else "cpu"
        x, y = (torch.rand(4, 4).to(device) for _ in range(2))

        with profile(with_stack=True) as p:
            torch.add(x, y)
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as trace_file:
                trace_events = json.load(trace_file)["traceEvents"]
                self._validate_basic_json(trace_events, with_cuda)
1795

1796
    @unittest.skipIf(not kineto_available(), "Kineto is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_basic_chrome_trace(self):
        """Validate a basic Chrome trace on CPU, then on CUDA when available."""
        self._test_chrome_trace_basic_helper(with_cuda=False)
        if not torch.cuda.is_available():
            return
        self._test_chrome_trace_basic_helper(with_cuda=True)
1802

1803
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_time_scale(self):
        """Durations from ``p.events()`` (``cpu_time``) and from the exported
        Chrome trace (``dur``) are in microseconds.

        A span wrapping ``WAIT_TIME`` one-second sleeps must measure
        ``WAIT_TIME`` seconds, within ``MARGIN_ERROR``.
        """
        MARGIN_ERROR = 0.5
        SEC_TO_US = 1000 * 1000
        WAIT_TIME = 10
        with profile() as p:
            with torch.profiler.record_function("test_span"):
                for _ in range(WAIT_TIME):
                    torch.rand(4, 4)
                    time.sleep(1)
        events = p.events()

        # make sure function events are scaled appropriately
        # (relies on the outermost user span being listed first)
        test_span = events[0]
        self.assertEqual(test_span.name, "test_span")
        self.assertGreaterEqual(
            test_span.cpu_time / SEC_TO_US,
            WAIT_TIME - MARGIN_ERROR,
            "event out of range",
        )
        self.assertLessEqual(
            test_span.cpu_time / SEC_TO_US,
            WAIT_TIME + MARGIN_ERROR,
            "event out of range",
        )

        # make sure tracing is scaled appropriately
        with TemporaryFileName(mode="w+") as fname:
            p.export_chrome_trace(fname)
            with open(fname) as f:
                report = json.load(f)
        trace_spans = [
            event for event in report["traceEvents"] if event.get("name") == "test_span"
        ]
        # Otherwise the range checks below would pass vacuously.
        self.assertGreater(len(trace_spans), 0, "test_span missing from trace")
        for event in trace_spans:
            self.assertGreaterEqual(
                event["dur"] / SEC_TO_US,
                WAIT_TIME - MARGIN_ERROR,
                "profiling out of range",
            )
            self.assertLessEqual(
                event["dur"] / SEC_TO_US,
                WAIT_TIME + MARGIN_ERROR,
                "profiling out of range",
            )
1847

1848
    def _schedule_helper(self, warmup, active, repeat, acc_events=True):
        """Profile 100 steps of ``torch.add(1, 2)`` under a schedule with
        ``skip_first=0``, ``wait=0`` and the given warmup/active/repeat.

        Returns:
            ``count`` of the first ``aten::add`` entry in ``key_averages()``,
            or 0 when there is none.
        """
        schedule = torch.profiler.schedule(
            skip_first=0, wait=0, warmup=warmup, active=active, repeat=repeat
        )
        with profile(schedule=schedule, acc_events=acc_events) as prof:
            for _ in range(100):
                torch.add(1, 2)
                prof.step()
        return next(
            (avg.count for avg in prof.key_averages() if avg.key == "aten::add"), 0
        )
1867

1868
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_schedule_function_count(self):
        """Expected ``aten::add`` counts over 100 scheduled steps, with and
        without ``acc_events``."""
        cases = [
            # (warmup, active, repeat, acc_events, expected count)
            (0, 1, 1, True, 1),
            (0, 5, 0, True, 100),
            (0, 5, 10, True, 50),
            (1, 5, 0, True, 83),
            (10, 10, 4, True, 40),
            (50, 1, 0, True, 1),
            (0, 5, 0, False, 0),
            (10, 10, 4, False, 10),
        ]
        for warmup, active, repeat, acc_events, expected in cases:
            count = self._schedule_helper(
                warmup=warmup, active=active, repeat=repeat, acc_events=acc_events
            )
            self.assertEqual(count, expected)
1882

1883
    def _step_helper_func(self, prof):
        """One profiler step: sleep 0.1 s, allocate a random 1x3x224x224 tensor,
        then advance ``prof``. (The name is matched in the trace by
        ``test_cpu_annotation_overlap``.)"""
        time.sleep(0.1)
        _ = torch.randn(1, 3, 224, 224)
        prof.step()
1887

1888
    def _partial_overlap(self, prof_step, step_helper_func):
        """Return True if two trace events (dicts with ``ts``/``dur``) partially
        overlap.

        Partial overlap means one span starts strictly inside the other and ends
        strictly after it. Containment, identical bounds, touching endpoints and
        disjoint spans all give False.
        """
        p_lo, p_hi = prof_step["ts"], prof_step["ts"] + prof_step["dur"]
        h_lo = step_helper_func["ts"]
        h_hi = step_helper_func["ts"] + step_helper_func["dur"]

        # prof_step straddles the start of step_helper_func ...
        straddles_start = p_lo < h_lo < p_hi < h_hi
        # ... or straddles its end.
        straddles_end = h_lo < p_lo < h_hi < p_hi
        return straddles_start or straddles_end
1899

1900
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_cpu_annotation_overlap(self):
        """``ProfilerStep`` spans and ``_step_helper_func`` Python spans in the
        exported trace must never partially overlap one another."""
        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            with_stack=True,
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1),
        ) as prof:
            for _ in range(5):
                self._step_helper_func(prof)
        with TemporaryFileName(mode="w+") as fname:
            prof.export_chrome_trace(fname)
            with open(fname) as f:
                trace_events = json.load(f)["traceEvents"]
            prof_steps = [ev for ev in trace_events if "ProfilerStep" in ev["name"]]
            step_helper_funcs = [
                ev for ev in trace_events if "step_helper_func" in ev["name"]
            ]
            self.assertEqual(len(prof_steps), 5)
            self.assertEqual(len(step_helper_funcs), 5)
            for prof_step in prof_steps:
                for helper_span in step_helper_funcs:
                    self.assertFalse(self._partial_overlap(prof_step, helper_span))
1928

1929
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_user_annotation(self):
        """Only the ``record_function`` span is flagged ``is_user_annotation``
        among the ``key_averages()`` entries."""
        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
        with profile(activities=supported_activities()) as p:
            with torch.profiler.record_function("test_user_annotation"):
                self.payload(use_cuda=use_cuda)

        for evt in p.key_averages():
            is_annotation = evt.key == "test_user_annotation"
            check = self.assertTrue if is_annotation else self.assertFalse
            check(evt.is_user_annotation)
1941

1942
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_dynamic_toggle(self):
        """``toggle_collection_dynamic(False, ...)`` stops collecting the given
        activities.

        Disabling CUDA keeps ``aten`` events but drops any ``cuda``/``kernel``
        ones. Disabling CPU and CUDA leaves no events at all.
        """

        def run_workload():
            with torch.profiler.record_function("test_user_annotation"):
                x, y = (torch.rand(4, 4).to("cuda") for _ in range(2))
                torch.add(x, y)

        # Baseline: everything is collected.
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
            run_workload()
        names = [e.name for e in p.events()]
        self.assertTrue(any("aten" in n for n in names))
        self.assertTrue(any("cuda" in n for n in names))
        self.assertTrue(any("kernel" in n for n in names))

        # CUDA collection switched off before the workload runs.
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p1:
            p1.toggle_collection_dynamic(False, [ProfilerActivity.CUDA])
            run_workload()
        names = [e.name for e in p1.events()]
        self.assertTrue(any("aten" in n for n in names))
        self.assertTrue(all("cuda" not in n for n in names))
        self.assertTrue(all("kernel" not in n for n in names))

        # Both activities switched off.
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p2:
            p2.toggle_collection_dynamic(
                False, [ProfilerActivity.CUDA, ProfilerActivity.CPU]
            )
            run_workload()
        self.assertTrue(len(p2.events()) == 0)
1976

1977
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_lazy_build_tree(self):
        """The function-event tree is built only once ``events()`` is called,
        as observed through the stats object returned by ``p._stats()``."""
        with profile() as p:
            self.payload()

        stats = p._stats()
        # No one has asked for events yet: no build time, no events.
        self.assertEqual(stats.function_events_build_tree_call_duration_us, 0)
        self.assertEqual(stats.number_of_events, 0)

        # Requesting events triggers the build; the same stats object reflects it.
        p.events()
        self.assertGreater(stats.function_events_build_tree_call_duration_us, 0)
        self.assertGreater(stats.number_of_events, 0)
1991

1992

1993
class SimpleNet(nn.Module):
    """Two stacked linear layers, 10 -> 5 -> 2, with no activation between."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)
2001

2002

2003
@dataclass(frozen=True)
class MockKinetoEvent:
    """Immutable stand-in for a Kineto event, used to build mock
    ``kineto_results.events()`` lists in the tests below.

    Times are stored in microseconds and reported in nanoseconds. A
    ``_device_type`` of 1 maps to ``DeviceType.CUDA``; any other value maps
    to ``DeviceType.CPU``.
    """

    _name: str
    _start_us: int
    _duration_us: int
    _linked_correlation_id: int
    _device_type: int

    @property
    def name(self) -> str:
        return self._name

    def start_ns(self) -> int:
        return 1000 * self._start_us

    def duration_ns(self) -> int:
        return 1000 * self._duration_us

    def linked_correlation_id(self) -> int:
        return self._linked_correlation_id

    def device_type(self) -> DeviceType:
        if self._device_type == 1:
            return DeviceType.CUDA
        return DeviceType.CPU
2026

2027

2028
@dataclass(frozen=True)
class MockProfilerEvent:
    """Immutable stand-in for a node of the profiler's event tree.

    ``parent`` and ``children`` can be wired up after construction with
    ``__post__init__``. Despite the name, that method is not the dataclass
    ``__post_init__`` hook, so it runs only when called explicitly. It writes
    through ``object.__setattr__`` to bypass the frozen-instance guard.
    """

    _name: str
    id: int
    start_time_ns: int
    duration_time_ns: int
    correlation_id: int = 0
    children: List["MockProfilerEvent"] = field(default_factory=list)
    parent: Optional["MockProfilerEvent"] = None

    @property
    def end_time_ns(self):
        """Start time plus duration, in nanoseconds."""
        return self.start_time_ns + self.duration_time_ns

    @property
    def name(self) -> str:
        return self._name

    def __post__init__(self, parent, children):
        for attr, value in (("parent", parent), ("children", children)):
            object.__setattr__(self, attr, value)
2049

2050

2051
class MockNode:
    """Minimal tree node exposing ``name`` and ``children``.

    ``children`` is a mapping ``{child_name: child_mapping}``. It is turned
    recursively into ``MockNode`` instances, in mapping order.
    """

    def __init__(self, name, children) -> None:
        self.name = name
        self.children = []
        for child_name, grandchildren in children.items():
            self.children.append(MockNode(child_name, grandchildren))
2055

2056

2057
class TestExperimentalUtils(TestCase):
2058
    def make_tree(self) -> List[MockNode]:
        """Build the two-root ``MockNode`` forest used by the traversal tests::

            root_0 -> 1 -> 2
                   -> 3 -> 4, 5
            root_1 -> 6, 7
                   -> 8 -> 9 -> 10
        """
        root_0 = {"1": {"2": {}}, "3": {"4": {}, "5": {}}}
        root_1 = {"6": {}, "7": {}, "8": {"9": {"10": {}}}}
        forest = {"root_0": root_0, "root_1": root_1}
        return [MockNode(root_name, subtree) for root_name, subtree in forest.items()]
2076

2077
    def test_dfs(self) -> None:
        """``_utils.traverse_dfs`` visits the forest in pre-order, root by root."""
        visited = [node.name for node in _utils.traverse_dfs(self.make_tree())]
        self.assertEqual(" ".join(visited), "root_0 1 2 3 4 5 root_1 6 7 8 9 10")
2082

2083
    def test_bfs(self) -> None:
        """``_utils.traverse_bfs`` visits the forest level by level across both
        roots."""
        visited = [node.name for node in _utils.traverse_bfs(self.make_tree())]
        self.assertEqual(" ".join(visited), "root_0 root_1 1 3 6 7 8 2 4 5 9 10")
2088

2089
    @staticmethod
    def generate_mock_profile():
        """Return a ``Mock`` profiler populated with hand-written events.

        ``kineto_results.events()`` returns six ``cudaLaunchKernel`` events
        (device type 0) followed by six ``GPU`` events (device type 1). They
        are paired by correlation ids 1-6 and each lasts 100 µs.
        ``kineto_results.experimental_event_tree()`` returns twelve flat
        ``MockProfilerEvent``s with ids 1-12 and times in nanoseconds.
        """
        launch_starts_us = (400, 500, 600, 700, 800, 1500)
        gpu_starts_us = (900, 1000, 1100, 1200, 1300, 1700)
        cuda_events = [
            MockKinetoEvent("cudaLaunchKernel", start, 100, corr_id, 0)
            for corr_id, start in enumerate(launch_starts_us, start=1)
        ] + [
            MockKinetoEvent("GPU", start, 100, corr_id, 1)
            for corr_id, start in enumerate(gpu_starts_us, start=1)
        ]

        cpu_spans = [
            # (name, start_ns, duration_ns); ids are assigned 1..12 in order.
            ("CPU (Before cudaLaunchKernel)", 0, 100000),
            ("CPU (Before cudaLaunchKernel)", 100000, 100000),
            ("CPU (Before cudaLaunchKernel)", 200000, 100000),
            ("CPU (Before cudaLaunchKernel)", 300000, 100000),
            ("CPU (After cudaLaunchKernel)", 400000, 100000),
            ("CPU (After cudaLaunchKernel)", 500000, 100000),
            ("CPU (After cudaLaunchKernel)", 600000, 100000),
            ("CPU (After cudaLaunchKernel)", 700000, 100000),
            ("CPU (After GPU)", 800000, 100000),
            ("CPU (After GPU)", 900000, 100000),
            ("CPU (After GPU)", 1100000, 100000),
            ("CPU (After GPU)", 1200000, 500000),
        ]
        cpu_events = [
            MockProfilerEvent(name, event_id, start, duration)
            for event_id, (name, start, duration) in enumerate(cpu_spans, start=1)
        ]

        kineto_results = unittest.mock.Mock()
        kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=cpu_events
        )
        profiler = unittest.mock.Mock()
        profiler.kineto_results = kineto_results
        return profiler
2127

2128
    @staticmethod
    def load_mock_profile():
        """Return a ``Mock`` profiler backed by ``profiler_utils_mock_events.json``.

        In expecttest accept mode (``expecttest.ACCEPT``) with CUDA available,
        the fixture is first regenerated from a real profile of chained
        4096x4096 matmuls with a small element-wise Python loop in the middle.
        The fixture (asserted to exist) holds ``[kineto_events,
        profiler_events]``:

        * kineto events become ``MockKinetoEvent``s for
          ``kineto_results.events()``;
        * profiler events become ``MockProfilerEvent``s. Their ``parent`` and
          ``children`` ids are resolved to objects, and only root events (no
          parent) are returned by ``experimental_event_tree()``.
        """
        json_file_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "profiler_utils_mock_events.json",
        )
        if expecttest.ACCEPT and torch.cuda.is_available():

            def garbage_code(x):
                for i in range(5):
                    x[0, i] = i

            x = torch.ones((4096, 4096), device="cuda")
            x = x @ x
            with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                with_stack=True,
            ) as prof:
                for _ in range(5):
                    x = x @ x
                garbage_code(x)
                for _ in range(5):
                    x = x @ x

            results = prof.profiler.kineto_results
            kineto_events = [
                {
                    "_name": e.name,
                    "_start_ns": e.start_ns(),
                    "_duration_ns": e.duration_ns(),
                    "_linked_correlation_id": e.linked_correlation_id(),
                    "_device_type": 1 if e.device_type() == DeviceType.CUDA else 0,
                }
                for e in results.events()
            ]

            def walk_event_tree(roots):
                # Iterative DFS: children are pushed in order and popped last-first.
                pending = list(roots)
                while pending:
                    node = pending.pop()
                    yield node
                    pending.extend(node.children)

            profiler_events = [
                {
                    "_name": e.name,
                    "id": e.id,
                    "start_time_ns": e.start_time_ns,
                    "duration_time_ns": e.duration_time_ns,
                    "correlation_id": e.correlation_id,
                    "children": [child.id for child in e.children],
                    "parent": e.parent.id if e.parent else None,
                }
                for e in walk_event_tree(results.experimental_event_tree())
            ]

            with open(json_file_path, "w") as f:
                json.dump([kineto_events, profiler_events], f)

        assert os.path.exists(json_file_path)
        with open(json_file_path) as f:
            kineto_events, profiler_events = json.load(f)

        # NOTE(review): fields are matched *positionally* here, and the writer
        # above stores "_start_ns"/"_duration_ns" values, while MockKinetoEvent's
        # corresponding fields are "_start_us"/"_duration_us" (it multiplies by
        # 1000). Kept as-is because the checked-in fixture and the expected
        # outputs presumably depend on it -- confirm before changing either side.
        cuda_events = [MockKinetoEvent(*event.values()) for event in kineto_events]

        all_cpu_events = [MockProfilerEvent(**e) for e in profiler_events]
        id_map = {event.id: event for event in all_cpu_events}
        for event in all_cpu_events:
            # `parent`/`children` still hold raw ids at this point; swap in objects.
            event.__post__init__(
                None if event.parent is None else id_map[event.parent],
                [id_map[child] for child in event.children],
            )
        root_events = [event for event in all_cpu_events if event.parent is None]

        profiler = unittest.mock.Mock()
        profiler.kineto_results = unittest.mock.Mock()
        profiler.kineto_results.events = unittest.mock.Mock(return_value=cuda_events)
        profiler.kineto_results.experimental_event_tree = unittest.mock.Mock(
            return_value=root_events
        )
        return profiler
2216

2217
    def test_utils_compute_self_time(self):
        """Check BasicEvaluation's self time against the event tree.

        Profiles a tiny forward/backward pass, then verifies that for every
        event key in ``BasicEvaluation.metrics`` the reported ``self_time_ns``
        equals the event's ``duration_time_ns`` minus the summed
        ``duration_time_ns`` of its direct children.
        """
        with profile() as prof:
            t1 = torch.ones(1, requires_grad=True)
            t2 = torch.ones(1, requires_grad=True)
            z = torch.add(t1, t2)
            y = torch.ones(1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
            loss.backward()
        basic_eval = _utils.BasicEvaluation(prof.profiler)
        metrics = basic_eval.metrics
        # assertGreater reports the actual size on failure, unlike
        # assertTrue(len(...) > 0) which only says "False is not true".
        self.assertGreater(len(metrics), 0)
        for event_key, event_metrics in metrics.items():
            event = event_key.event
            self.assertEqual(
                event_metrics.self_time_ns,
                event.duration_time_ns
                - sum(child.duration_time_ns for child in event.children),
            )
2235

2236
    def test_utils_intervals_overlap(self):
        """Check ``EventKey.intervals_overlap`` against a fixed interval list.

        The mock event is built as ``MockProfilerEvent("Event 1", 1, 5, 5)``.
        Assuming the positional fields are (name, id, start_time_ns,
        duration_time_ns) -- TODO confirm against the MockProfilerEvent
        definition -- it spans times 5..10, and the expected overlap with the
        intervals below is 5.
        """
        event = _utils.EventKey(MockProfilerEvent("Event 1", 1, 5, 5))
        intervals = [
            _utils.Interval(0, 9),
            _utils.Interval(1, 2),
            _utils.Interval(2, 3),
            _utils.Interval(3, 4),
            _utils.Interval(4, 5),
            _utils.Interval(8, 12),
        ]
        self.assertEqual(event.intervals_overlap(intervals), 5)
2248

2249
    def test_utils_compute_queue_depth(self):
2250
        def format_queue_depth(queue_depth_list, events):
2251
            res = ""
2252
            for data, event in zip(queue_depth_list, events):
2253
                res += f"{data.queue_depth} [{event.name}]\n"
2254
            return res
2255

2256
        # We have to use Mock because time series data is too flaky to test
2257
        profiler = self.generate_mock_profile()
2258
        basic_evaluation = _utils.BasicEvaluation(profiler)
2259
        self.assertExpectedInline(
2260
            format_queue_depth(
2261
                basic_evaluation.queue_depth_list, basic_evaluation.cuda_events
2262
            ),
2263
            """\
2264
1 [cudaLaunchKernel]
2265
2 [cudaLaunchKernel]
2266
3 [cudaLaunchKernel]
2267
4 [cudaLaunchKernel]
2268
5 [cudaLaunchKernel]
2269
4 [GPU]
2270
3 [GPU]
2271
2 [GPU]
2272
1 [GPU]
2273
0 [GPU]
2274
1 [cudaLaunchKernel]
2275
0 [GPU]
2276
""",
2277
        )
2278
        self.assertExpectedInline(
2279
            format_queue_depth(
2280
                [basic_evaluation.metrics[k] for k in basic_evaluation.event_keys],
2281
                basic_evaluation.events,
2282
            ),
2283
            """\
2284
0 [CPU (Before cudaLaunchKernel)]
2285
0 [CPU (Before cudaLaunchKernel)]
2286
0 [CPU (Before cudaLaunchKernel)]
2287
0 [CPU (Before cudaLaunchKernel)]
2288
1 [CPU (After cudaLaunchKernel)]
2289
2 [CPU (After cudaLaunchKernel)]
2290
3 [CPU (After cudaLaunchKernel)]
2291
4 [CPU (After cudaLaunchKernel)]
2292
5 [CPU (After GPU)]
2293
4 [CPU (After GPU)]
2294
2 [CPU (After GPU)]
2295
1 [CPU (After GPU)]
2296
""",
2297
        )
2298

2299
    def test_utils_compute_queue_depth_when_no_cuda_events(self):
        """A CPU-only trace should produce an empty (falsy) queue depth result."""
        mat = torch.ones((1024, 1024))
        with profile() as prof:
            for _ in range(5):
                mat = mat @ mat
        evaluation = _utils.BasicEvaluation(prof.profiler)
        # Only CPU tensors were used above, so there is nothing to queue.
        self.assertFalse(evaluation.compute_queue_depth())
2307

2308
    def test_utils_compute_idle_time(self):
2309
        profiler = self.generate_mock_profile()
2310
        basic_evaluation = _utils.BasicEvaluation(profiler)
2311
        expected_output = "\n".join(
2312
            [
2313
                f"{basic_evaluation.metrics[event_key].idle_time_ns} [{event_key.event.name}]"
2314
                for event_key in basic_evaluation.event_keys
2315
            ]
2316
        )
2317
        self.assertExpectedInline(
2318
            expected_output,
2319
            """\
2320
100000 [CPU (Before cudaLaunchKernel)]
2321
100000 [CPU (Before cudaLaunchKernel)]
2322
100000 [CPU (Before cudaLaunchKernel)]
2323
100000 [CPU (Before cudaLaunchKernel)]
2324
0 [CPU (After cudaLaunchKernel)]
2325
0 [CPU (After cudaLaunchKernel)]
2326
0 [CPU (After cudaLaunchKernel)]
2327
0 [CPU (After cudaLaunchKernel)]
2328
0 [CPU (After GPU)]
2329
0 [CPU (After GPU)]
2330
0 [CPU (After GPU)]
2331
100000 [CPU (After GPU)]""",
2332
        )
2333

2334
    @unittest.skipIf(IS_JETSON, "JSON not behaving as expected on Jetson")
2335
    def test_utils_get_optimizable_events(self):
2336
        basic_evaluation = _utils.BasicEvaluation(self.load_mock_profile())
2337
        optimizable_events = basic_evaluation.get_optimizable_events(
2338
            2, print_enable=False
2339
        )
2340
        expected_output = "\n".join(
2341
            [f"{event_key.event.name}" for event_key in optimizable_events]
2342
        )
2343
        self.assertExpectedInline(
2344
            expected_output,
2345
            """\
2346
<built-in function _cuda_synchronize>
2347
aten::copy_""",
2348
        )
2349

2350
    def test_profiler_name_pattern(self):
2351
        x = torch.ones((4096, 4096))
2352
        with profile() as prof:
2353
            for _ in range(5):
2354
                x = x @ x
2355
                x = x + x
2356
        matched_events = NamePattern(prof, "aten::mm").matched_events()
2357
        output = "\n".join([f"{event.name}" for event in matched_events])
2358
        self.assertExpectedInline(
2359
            output,
2360
            """\
2361
aten::mm
2362
aten::mm
2363
aten::mm
2364
aten::mm
2365
aten::mm""",
2366
        )
2367

2368
    # TODO: Add logic for CUDA version of test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_pattern_match_helper(self):
        """Exercise Pattern's tree-navigation helpers on a small CPU trace.

        Checks ``siblings_of`` (index 0 holds the preceding siblings, index 1
        the following ones), ``root_of``, ``next_of`` and ``prev_of`` against
        the experimental event tree.
        """
        mat = torch.ones((100, 100))
        with profile() as prof:
            for _ in range(5):
                mat = mat @ mat
                mat = mat + mat
        roots = prof.profiler.kineto_results.experimental_event_tree()
        helper = Pattern(prof)
        first_root = roots[0]

        # First top-level event: nothing before it, every other root after it.
        self.assertEqual([], helper.siblings_of(first_root)[0])
        self.assertEqual(roots[1:], helper.siblings_of(first_root)[1])

        # Same relationship one level down, among the first root's children.
        kids = first_root.children
        self.assertEqual([], helper.siblings_of(kids[0])[0])
        self.assertEqual(kids[1:], helper.siblings_of(kids[0])[1])

        # A grandchild resolves to the top-level event it descends from.
        grandchild = first_root.children[0].children[0]
        self.assertEqual(first_root, helper.root_of(grandchild))

        # Navigation across top-level siblings.
        self.assertEqual(None, helper.next_of(roots[-1]))
        self.assertEqual(roots[1], helper.next_of(first_root))
        self.assertEqual(first_root, helper.prev_of(roots[1]))
2389

2390
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern(self):
        """ExtraCUDACopyPattern should flag CPU-allocate-then-move-to-CUDA code.

        Each case pairs the expected match count with a tensor factory. Tensors
        created directly on the device, or only converted in dtype, are
        expected not to match.
        """
        cases = (
            (0, lambda: torch.ones((100, 100), device="cuda")),
            (1, lambda: torch.ones((100, 100)).to("cuda")),
            (1, lambda: torch.zeros((100, 100)).to("cuda")),
            (1, lambda: torch.empty((100, 100)).fill_(5).to("cuda")),
            (1, lambda: torch.ones((100, 100)).cuda()),
            (1, lambda: torch.zeros((100, 100)).cuda()),
            (1, lambda: torch.empty((100, 100)).fill_(5).cuda()),
            (1, lambda: torch.rand((100, 100)).cuda()),
            (1, lambda: torch.randn((100, 100)).cuda()),
            (1, lambda: torch.full((100, 100), 10).cuda()),
            (0, lambda: torch.rand((100, 100)).to(dtype=torch.float16)),
            (0, lambda: torch.rand((100, 100)).half()),
            (0, lambda: torch.rand((100, 100), device="cuda").half()),
        )
        observed = []
        for _, make_tensor in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                make_tensor()
            observed.append(len(ExtraCUDACopyPattern(prof).matched_events()))
        expected = [count for count, _ in cases]
        self.assertEqual(observed, expected)
2417

2418
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    def test_profiler_for_loop_indexing_pattern(self):
        """ForLoopIndexingPattern should flag row-by-row indexing in Python loops.

        Loops that index ``x`` one row at a time (assignment, accumulation,
        row-slice assignment) expect one match; a loop doing only whole-matrix
        multiplies expects none. Some cases mutate ``x`` in place, so later
        cases see the updated values.
        """
        x = torch.ones((100, 100))

        def assign_rows():
            for i in range(100):
                x[i] = i

        def sum_rows():
            acc = 0
            for i in range(100):
                acc += x[i]

        def multiply_rows():
            acc = 1
            for i in range(100):
                acc *= x[i]

        def chain_matmuls():
            acc = x
            for _ in range(100):
                acc = acc @ x

        def assign_row_slices():
            for i in range(100):
                x[i, :] = torch.arange(100) + i

        cases = (
            (1, assign_rows),
            (1, sum_rows),
            (1, multiply_rows),
            (0, chain_matmuls),
            (1, assign_row_slices),
        )
        observed = []
        for _, body in cases:
            with profile(with_stack=True) as prof:
                body()
            observed.append(len(ForLoopIndexingPattern(prof).matched_events()))
        self.assertEqual(observed, [count for count, _ in cases])
2455

2456
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_fp32_matmul_pattern(self):
        """FP32MatMulPattern should flag one fp32 CUDA matmul unless it skips.

        When the pattern reports ``skip`` (presumably TF32 is not applicable
        on this device -- confirm in FP32MatMulPattern), zero matches are
        expected; otherwise exactly one.
        """
        mat = torch.ones((100, 100), device="cuda")
        with profile(with_stack=True) as prof:
            mat = mat @ mat
        pattern = FP32MatMulPattern(prof)
        expected_matches = 0 if pattern.skip else 1
        self.assertEqual(len(pattern.matched_events()), expected_matches)
2465

2466
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_extra_cuda_copy_pattern_benchmark(self):
        """Benchmarking two extra-copy matches of different shapes gives two entries."""
        with profile(with_stack=True, record_shapes=True) as prof:
            _ = torch.ones((100, 100)).to("cuda")
            _ = torch.ones((50, 50)).to("cuda")
        pattern = ExtraCUDACopyPattern(prof)
        matches = pattern.matched_events()
        factor_by_shape = pattern.benchmark(matches)
        self.assertEqual(len(factor_by_shape), 2)
2474

2475
    def test_profiler_optimizer_single_tensor_pattern(self):
        """OptimizerSingleTensorPattern should flag optimizers not using foreach.

        Each case pairs the expected match count with a factory that builds an
        optimizer from the parameters of a freshly created model. Optimizers
        constructed with ``foreach=True`` are expected not to match.
        """
        x = torch.ones((100, 100))
        # Factories take the parameters explicitly. Closing over ``model``
        # would rely on a name that is only bound later, inside the loop.
        cases = (
            (1, lambda params: torch.optim.Adam(params)),
            (1, lambda params: torch.optim.SGD(params, lr=0.01)),
            (1, lambda params: torch.optim.AdamW(params)),
            (0, lambda params: torch.optim.Adam(params, foreach=True)),
            (0, lambda params: torch.optim.SGD(params, lr=0.01, foreach=True)),
            (0, lambda params: torch.optim.AdamW(params, foreach=True)),
        )
        num_matched = []
        for _, make_optimizer in cases:
            with profile(with_stack=True) as prof:
                model = nn.Sequential(
                    nn.Linear(100, 100),
                    nn.ReLU(),
                    nn.Linear(100, 10),
                )
                optimizer = make_optimizer(model.parameters())
                optimizer.zero_grad()
                y_hat = model(x)
                loss = torch.nn.functional.cross_entropy(
                    y_hat, torch.randint(0, 10, (100,))
                )
                loss.backward()
                optimizer.step()
            pattern = OptimizerSingleTensorPattern(prof)
            num_matched.append(len(pattern.matched_events()))
        self.assertEqual(num_matched, [i for i, _ in cases])
2504

2505
    def test_profiler_synchronized_dataloader_pattern(self):
        """SynchronizedDataLoaderPattern should flag only the worker-less DataLoader.

        One batch is pulled from a DataLoader without workers and one from a
        DataLoader with 4 workers; exactly one match is expected.
        """
        data = torch.rand((100, 100))
        loader_inline = torch.utils.data.DataLoader(data, batch_size=10)
        loader_workers = torch.utils.data.DataLoader(
            data, batch_size=10, num_workers=4
        )
        with profile(with_stack=True) as prof:
            for loader in (loader_inline, loader_workers):
                next(iter(loader))
        matches = SynchronizedDataLoaderPattern(prof).matched_events()
        self.assertEqual(len(matches), 1)
2517

2518
    @skipIfTorchDynamo(
        "pattern checks for aten::_zero op which might not be there with torch.compile'd graph"
    )
    def test_profiler_grad_not_set_to_none_pattern(self):
        """GradNotSetToNonePattern should flag zero_grad(set_to_none=False) only.

        Each case runs one training step and then clears gradients with the
        case's callable. The default ``zero_grad()`` calls on the optimizer and
        the model are expected not to match.
        """
        x = torch.ones((100, 100))
        model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 10))
        optimizer = torch.optim.Adam(model.parameters())
        cases = (
            (0, lambda: optimizer.zero_grad()),
            (0, lambda: model.zero_grad()),
            (1, lambda: optimizer.zero_grad(set_to_none=False)),
            (1, lambda: model.zero_grad(set_to_none=False)),
        )
        observed = []
        for _, clear_grads in cases:
            with profile(with_stack=True) as prof:
                targets = None
                logits = model(x)
                targets = torch.randint(0, 10, (100,))
                torch.nn.functional.cross_entropy(logits, targets).backward()
                optimizer.step()
                clear_grads()
            observed.append(len(GradNotSetToNonePattern(prof).matched_events()))
        self.assertEqual(observed, [count for count, _ in cases])
2548

2549
    def test_profiler_conv2d_bias_followed_by_batchnorm2d_pattern(self):
        """A biased Conv2d feeding BatchNorm2d should match; other setups should not.

        Neither a bias-free Conv2d followed by BatchNorm2d, nor a biased Conv2d
        with no BatchNorm2d after it, is expected to match.
        """
        inp = torch.randn((1, 3, 32, 32))
        biased_conv_bn = nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1), nn.BatchNorm2d(3))
        unbiased_conv_bn = nn.Sequential(
            nn.Conv2d(3, 3, 3, 1, 1, bias=False), nn.BatchNorm2d(3)
        )
        conv_only = nn.Sequential(nn.Conv2d(3, 3, 3, 1, 1))
        cases = ((1, biased_conv_bn), (0, unbiased_conv_bn), (0, conv_only))
        observed = []
        for _, net in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                net(inp)
            matches = Conv2dBiasFollowedByBatchNorm2dPattern(prof).matched_events()
            observed.append(len(matches))
        self.assertEqual(observed, [count for count, _ in cases])
2563

2564
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    def test_profiler_matmul_dim_fp16_pattern(self):
        """MatMulDimInFP16Pattern on fp16 CUDA matmuls of various sizes.

        Judging by the cases, inner dimensions of 201 or 97 are expected to
        match, while 200 is not (presumably an alignment rule -- confirm in
        MatMulDimInFP16Pattern).
        """

        def half_cuda(*shape):
            return torch.randn(shape, device="cuda", dtype=torch.float16)

        cases = (
            (1, half_cuda(201, 201)),
            (1, half_cuda(3, 97, 97)),
            (0, half_cuda(200, 200)),
            (0, half_cuda(3, 200, 200)),
        )
        observed = []
        for _, mat in cases:
            with profile(with_stack=True, record_shapes=True) as prof:
                mat @ mat
            observed.append(len(MatMulDimInFP16Pattern(prof).matched_events()))
        self.assertEqual(observed, [count for count, _ in cases])
2579

2580
    @skipIfTorchDynamo("profiler gets ignored if dynamo activated")
    def test_profiler_pattern_matcher_json_report(self):
        """report_all_anti_patterns should write a well-formed JSON report.

        Profiles one training step and writes the report into a temporary
        directory. Then checks that ``torchtidy_report.json`` has exactly one
        key ending in "test_profiler.py", that its entry list is non-empty, and
        that each entry carries exactly the fields line_number, name, url and
        message.
        """
        x = torch.ones((100, 100))
        model = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )
        optimizer = torch.optim.Adam(model.parameters())
        with profile(with_stack=True, record_shapes=True) as prof:
            y_hat = model(x)
            loss = torch.nn.functional.cross_entropy(
                y_hat, torch.randint(0, 10, (100,))
            )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        with tempfile.TemporaryDirectory() as tmpdir:
            report_all_anti_patterns(prof, json_report_dir=tmpdir, print_enable=False)

            with open(os.path.join(tmpdir, "torchtidy_report.json")) as f:
                report = json.load(f)

            # It is platform dependent whether the path will include "profiler/"
            keys = [k for k in report if k.endswith("test_profiler.py")]
            self.assertEqual(len(keys), 1, f"{keys}")
            entry = report[keys[0]]

            # assertGreater reports the actual length on failure.
            self.assertGreater(len(entry), 0)
            expected_fields = sorted(["line_number", "name", "url", "message"])
            for event in entry:
                actual_fields = sorted(event.keys())
                self.assertEqual(expected_fields, actual_fields)
2614

2615
    @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding")
    def test_fuzz_symbolize(self):
        """Fuzz the native symbolizers with random addresses in mapped ``.text`` code.

        Collects the ``.text`` section of each distinct shared object mapped
        into this process and draws 200 addresses from them with a fixed seed.
        It then runs the "fast", "dladdr" and "addr2line" symbolizer modes over
        those addresses. The goal is only that the symbolizers do not throw or
        crash; each mode must return one result per address.
        """

        def get_text_sections():
            # Return (start address, size) of the .text section, one pair per
            # distinct ".so" file. The start is the mapping start of whichever
            # /proc/self/map_files entry ("<start>-<end>" in hex) was seen first
            # for that file, plus the section's file offset. It is therefore
            # only approximate, which is good enough for fuzzing.
            text_sections = []
            seen = set()
            for filename in os.listdir("/proc/self/map_files"):
                library = os.readlink("/proc/self/map_files/" + filename)
                if ".so" not in library or library in seen:
                    continue
                seen.add(library)
                # readlink yields an absolute path, so it can be opened as is.
                # The context managers close the mapping on every path,
                # including the non-ELF `continue` below.
                with open(library, "rb") as f, mmap.mmap(
                    f.fileno(), 0, prot=mmap.PROT_READ
                ) as mm:

                    def unpack(fmt, offset):
                        return struct.unpack_from(fmt, mm, offset)

                    # Only 64-bit ELF objects (EI_CLASS == ELFCLASS64) are
                    # parsed; the offsets below follow the ELF64 layout in
                    # native byte order.
                    if mm[:4] != b"\x7fELF" or mm[4:5] != b"\x02":
                        continue
                    (section_headers_start,) = unpack("Q", 40)  # e_shoff
                    (section_header_size,) = unpack("H", 58)  # e_shentsize
                    (num_section_headers,) = unpack("H", 60)  # e_shnum
                    (shstrndx,) = unpack("H", 62)  # e_shstrndx
                    # sh_offset of the section-name string table.
                    (shstrtab_offset,) = unpack(
                        "Q", section_headers_start + shstrndx * section_header_size + 24
                    )
                    for i in range(num_section_headers):
                        header = section_headers_start + i * section_header_size
                        (section_name_offset,) = unpack("I", header)  # sh_name
                        name_start = shstrtab_offset + section_name_offset
                        if mm[name_start : name_start + 6] != b".text\0":
                            continue
                        (section_offset,) = unpack("Q", header + 24)  # sh_offset
                        (section_size,) = unpack("Q", header + 32)  # sh_size
                        start = int(filename.split("-")[0], 16) + section_offset
                        text_sections.append((start, section_size))
                        break
            return text_sections

        rng = random.Random(1)
        text_sections = get_text_sections()
        if not text_sections:
            # randrange(0, 0) would raise below; there is nothing to fuzz.
            self.skipTest("no .text sections found in mapped shared objects")
        addrs = []
        for _ in range(200):
            start, size = text_sections[rng.randrange(0, len(text_sections))]
            addrs.append(rng.randrange(start, start + size))
        fast = torch._C._profiler.symbolize_addresses(addrs, "fast")
        dladdr = torch._C._profiler.symbolize_addresses(addrs, "dladdr")
        addr2line = torch._C._profiler.symbolize_addresses(addrs, "addr2line")
        self.assertEqual(len(fast), len(addrs))
        self.assertEqual(len(dladdr), len(fast))
        self.assertEqual(len(addr2line), len(fast))
2678

2679

2680
if __name__ == "__main__":
    # Run this module's tests through PyTorch's common test runner.
    run_tests()
2682

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.