pytorch

Форк
0
/
test_profiler_tree.py 
1133 строки · 46.9 Кб
1
# Owner(s): ["oncall: profiler"]
2

3
import functools
4
import os
5
import re
6
import textwrap
7
import traceback
8
import unittest
9

10
import expecttest
11

12
import torch
13
from torch._C._profiler import _ExtraFields_PyCall, _ExtraFields_PyCCall
14
from torch.testing._internal.common_utils import (
15
    IS_ARM64,
16
    IS_WINDOWS,
17
    run_tests,
18
    skipIfTorchDynamo,
19
    TEST_WITH_CROSSREF,
20
    TestCase,
21
)
22
from torch.utils._pytree import tree_map
23

24

25
# The functions listed below vary with platform and build (e.g. with CUDA)
# and generally distract from the test rather than add to it.
#
# Prune levels, as interpreted by `ProfilerTree.format`:
#   PRUNE_ALL              - drop the node together with its whole subtree
#   KEEP_ELLIPSES          - replace the node (and its subtree) with "..."
#   KEEP_NAME_AND_ELLIPSES - keep the node name, collapse its children to "..."
PRUNE_ALL = 1
KEEP_ELLIPSES = 2
KEEP_NAME_AND_ELLIPSES = 3

PRUNE_FUNCTIONS = {
    # Python frames whose internals are not interesting to the tests.
    "torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES,
    "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES,
    "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES,
    "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES,
    "<built-in method __exit__ of torch._C.DisableTorchFunctionSubclass object at 0xXXXXXXXXXXXX>": PRUNE_ALL,
    "cudaStreamIsCapturing": PRUNE_ALL,
    # These show up only on CUDA, prune them so the CUDA and CPU expected results can be the same
    "cudaGetDeviceCount": PRUNE_ALL,
    "cudaGetDeviceProperties_v2": PRUNE_ALL,
}
42

43
# ROCTracer does not currently produce events that the profiler can extract.
# It should be brought up to parity with the CUPTI Kineto / profiler
# integration, but in the meantime there is still value in running the tests
# without checking that the values match the expected ones:
#  1) We will still catch runtime errors and assert failures
#  2) We can diff the output to see how far we are from parity
#
# TODO: We also fail to capture events for Windows on some platforms.
ALLOW_CUDA_FAILURE = torch.version.hip is not None or IS_WINDOWS
52

53

54
class TorchFunctionTensor(torch.Tensor):
    """Tensor subclass whose `__torch_function__` simply defers to the base class."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Forward every argument untouched to `torch.Tensor.__torch_function__`.
        result = super().__torch_function__(func, types, args, kwargs)
        return result
58

59

60
class TorchDispatchTensor(torch.Tensor):
    """Wrapper tensor subclass that handles ops in `__torch_dispatch__`.

    The tensor passed to the constructor is kept on the instance as `elem`.
    """

    @staticmethod
    def __new__(cls, elem):
        wrapper = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        wrapper.elem = elem
        return wrapper

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            # Replace wrappers by the tensor they hold; pass anything else through.
            if isinstance(x, TorchDispatchTensor):
                return x.elem
            return x

        def wrap(x):
            # Re-wrap tensor results; non-tensor results are returned as is.
            if isinstance(x, torch.Tensor):
                return TorchDispatchTensor(x)
            return x

        plain_args = tree_map(unwrap, args)
        plain_kwargs = tree_map(unwrap, kwargs or {})

        return tree_map(wrap, func(*plain_args, **plain_kwargs))
79

80

81
class ProfilerTree:
    """Helpers that turn a profiler event tree into comparable, indented text."""

    @staticmethod
    def test(f):
        """Mark unit test that will be using ProfilerTree to test traces.

        This decorator serves two purposes. First, it provides a method name
        that `format` can use to tell where the test runner (which is
        environment specific) ends and the unit test begins. Second, it runs
        the test with replicates and allows `assertTreesMatch` to adjust
        based on which replicate is running.

        Args:
            f: The test method to wrap; it is called as `f(self)`.

        Returns:
            A wrapper taking `(self, replicates=3)` that returns the value of
            the last executed replicate, or `None` if no replicate ran.
        """

        # NOTE: `format` searches node names for "begin_unit_test_marker", so
        # the name of this inner function must not be changed.
        @functools.wraps(f)
        def begin_unit_test_marker(self, replicates=3):
            # Defined up front so that `replicates=0` returns None instead of
            # failing on an unbound local.
            out = None
            try:
                for i in range(replicates):
                    self.tree_replicate = i
                    out = f(self)
                    # The test body may reset the attribute to None to request
                    # that the remaining replicates are skipped.
                    if self.tree_replicate is None:
                        break
                return out
            finally:
                # The attribute is absent when no replicate ran; an
                # unconditional delete would raise and mask the real outcome.
                if hasattr(self, "tree_replicate"):
                    delattr(self, "tree_replicate")

        return begin_unit_test_marker

    @classmethod
    def format(cls, profiler, indent: int = 0):
        """Render the profiler's experimental event tree as indented text.

        Every node is validated with `validate_node`, its name is normalized
        with `fmt_name`, and entries listed in `PRUNE_FUNCTIONS` are pruned
        according to their prune level. Trailing device-synchronize entries
        are removed, and only entries nested below `begin_unit_test_marker`
        are kept when that marker is present.

        Args:
            profiler: Object exposing
                `kineto_results.experimental_event_tree()`.
            indent: Number of spaces prepended to every non-blank output line.

        Returns:
            The formatted tree, one node per line, two spaces per depth level.
        """

        def flatten(nodes, depth=0, out=None):
            # Depth-first walk producing a flat list of (depth, name) pairs.
            if out is None:
                out = []

            for node in nodes:
                cls.validate_node(node)
                name = cls.fmt_name(node.name)
                prune_level = PRUNE_FUNCTIONS.get(name.strip(), None)
                if prune_level is None:
                    out.append((depth, name))
                    flatten(node.children, depth + 1, out)
                elif prune_level == KEEP_NAME_AND_ELLIPSES:
                    out.append((depth, name))
                    if node.children:
                        out.append((depth + 1, "..."))
                elif prune_level == KEEP_ELLIPSES:
                    out.append((depth, "..."))
                else:
                    # PRUNE_ALL: neither the node nor its subtree is emitted.
                    assert prune_level == PRUNE_ALL

            return out

        flat_nodes = flatten(profiler.kineto_results.experimental_event_tree())

        # Profiler inserts a `cudaDeviceSynchronize` at the end of profiling.
        # and may also insert 'Context Sync' CUDA synchronization event.
        # Two entries are required before looking at the second-to-last one;
        # a single-entry tree would otherwise raise IndexError.
        if len(flat_nodes) >= 2 and flat_nodes[-2][1] == "cudaDeviceSynchronize":
            flat_nodes = flat_nodes[:-2]

        if flat_nodes and flat_nodes[-1][1] == "cudaDeviceSynchronize":
            flat_nodes = flat_nodes[:-1]

        # Profiler inserts a `hipDeviceSynchronize` at the end of profiling.
        if flat_nodes and flat_nodes[-1][1] == "hipDeviceSynchronize":
            flat_nodes = flat_nodes[:-1]

        # Everything at or above the marker belongs to the test runner; when
        # no marker is found the whole tree is kept.
        min_depth = min(
            [d + 1 for d, name in flat_nodes if "begin_unit_test_marker" in name] or [0]
        )
        return textwrap.indent(
            "\n".join(
                [
                    f"{'  ' * (d - min_depth)}{name.rstrip()}"
                    for d, name in flat_nodes
                    if d >= min_depth
                ]
            ),
            " " * indent,
        )

    @staticmethod
    def fmt_name(name: str) -> str:
        """Normalize an event name so that it is stable across runs.

        Python frames (`<file>.py(<lineno>): <function>`) get their line
        number replaced by `...`, and a file name ending in this test file's
        name is reduced to just that name. For all other names, the template
        and argument lists of known GPU kernel names are collapsed and
        `object at 0x...>` addresses are masked.

        Args:
            name: The raw event name.

        Returns:
            The normalized name.
        """
        match = re.match(r"^(.*)\.py\(([0-9]+)\): (.*)$", name)
        if match:
            filename, _, fn = match.groups()

            # This test can appear as `test/profiler/test_profiler_tree.py`
            # depending on where it is run from.
            test_file = os.path.splitext(os.path.split(__file__)[1])[0]
            if filename.endswith(test_file):
                filename = test_file

            # We test against a string literal, so all paths have to look like POSIX paths.
            filename = filename.replace(os.sep, "/")

            # We don't want to have to update this test every time PyTorch changes.
            # At some point we should test some line numbers, but for now it's
            # too brittle.
            lineno = "..."

            return f"{filename}.py({lineno}): {fn}"

        for kernel_pattern in (
            "void at::native::elementwise_kernel",
            "void at::native::reduce_kernel",
            "void at::native::vectorized_elementwise_kernel",
            "void at::native::unrolled_elementwise_kernel",
            r"void [a-zA-Z0-9]+_kernel",  # Nvidia kernels.
        ):
            # Collapse `<template args>(call args)`; for the generic pattern
            # the kernel's own name is replaced by `...` as well.
            name = re.sub(
                rf"{kernel_pattern}<.+>\(.+\)$",
                f"{kernel_pattern.replace('[a-zA-Z0-9]+', '...')}<...>(...)",
                name,
            )

        return re.sub("object at 0x[0-9a-fA-F]+>", "object at 0xXXXXXXXXXXXX>", name)

    @classmethod
    def validate_node(cls, node):
        """Assert that a Python call node hangs below the frame that called it.

        Only nodes whose `extra_fields` are `_ExtraFields_PyCall` or
        `_ExtraFields_PyCCall` are checked; all other nodes are accepted.

        Args:
            node: An event tree node exposing `extra_fields` and `parent`.

        Raises:
            AssertionError: If the nearest Python-call ancestor's callsite
                differs from the caller recorded on `node`.
        """
        extra_fields = node.extra_fields
        if isinstance(extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)):
            # Check that the lineage established by the profiler matches the
            # caller recorded by the Python tracer.
            parent = node.parent
            while parent is not None:
                if isinstance(parent.extra_fields, _ExtraFields_PyCall):
                    break
                parent = parent.parent

            def to_string(frame_state):
                return f"{frame_state.file_name}(...): {frame_state.function_name}"

            # `parent` is None when no Python-call ancestor exists; there is
            # nothing to compare against in that case.
            if parent is not None:
                parent_name = to_string(parent.extra_fields.callsite)
                caller_name = to_string(extra_fields.caller)
                assert parent_name == caller_name, f"{parent_name} vs. {caller_name}"
215

216

217
@unittest.skipIf(IS_ARM64, "Not working on ARM")
218
class TestProfilerTree(TestCase):
219
    def assertTreesMatch(self, actual: str, expected: str, allow_failure: bool = False):
        """Compare a formatted profiler tree against its expected text.

        The first replicate is checked with `assertExpectedInline`; every
        later replicate must equal `expected` exactly.

        Args:
            actual: The formatted tree produced by the test.
            expected: The expected tree text.
            allow_failure: When true, a mismatch on the first replicate is
                printed instead of raised, and `tree_replicate` is reset to
                `None`.
        """
        # Warning: Here be dragons
        #   Python tracing behaves subtly differently from platform to
        #   platform. Observed differences include:
        #     1) Windows symbolicates names differently from posix
        #     2) The profile callback for c_call does not fire for Tensor.__pow__
        #        on certain platforms. This is not caused by the function tracer,
        #        but by cPython itself.
        #
        # These unit tests exist to make sure the profiler is doing reasonable
        # things. When such platform dependent variations occur, simply coerce
        # them into a platform independent form. If a change in the codebase
        # changes the trace that is produced, use EXPECTTEST_ACCEPT=1 to update
        # the tests to reflect the new structure.

        # expecttest will not show the diff view if `len(actual) < len(expected)`
        if not expecttest.ACCEPT:
            actual = actual.ljust(len(expected))
        self.maxDiff = None

        replicate = getattr(self, "tree_replicate", None)
        self.assertIsNotNone(
            replicate, "Please annotate test with `@ProfilerTree.test`"
        )

        # The profiler should be deterministic and return to a clean state
        # after every run, so only the first replicate (index 0) may update
        # `expected`. A mismatch in a later replicate is a profiler bug.
        if replicate:
            self.assertEqual(actual, expected)
            return

        try:
            self.assertExpectedInline(actual, expected, skip=1)
        except AssertionError as err:
            if not allow_failure:
                raise
            self.tree_replicate = None
            detail = traceback.format_exception_only(type(err), err)[0]
            print(detail.split("AssertionError:")[-1])
260

261
    # TODO: Add logic for CUDA version of test
262
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree(self):
        """Check the op tree recorded for a small forward and backward pass."""
        lhs = torch.ones(1, requires_grad=True)
        rhs = torch.ones(1, requires_grad=True)
        with torch.profiler.profile() as prof:
            total = torch.add(lhs, rhs)
            ones = torch.ones(1)
            loss = (ones - total) ** 2
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            aten::add
            aten::ones
              aten::empty
              aten::fill_
            aten::sub
            aten::pow
              aten::result_type
              aten::to
            aten::ones_like
              aten::empty_like
                aten::empty_strided
              aten::fill_
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  aten::copy_
                aten::mul
                  aten::mul
                    aten::to
                      aten::_to_copy
                        aten::empty_strided
                        aten::copy_
                aten::mul
            autograd::engine::evaluate_function: SubBackward0
              SubBackward0
                aten::neg
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::new_empty_strided
                  aten::empty_strided
                aten::copy_
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach""",
        )
315

316
    # TODO: Add logic for CUDA version of test
317
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree_with_record_function(self):
        """Check how nested `record_function` annotations appear in the tree."""
        record_function = torch.autograd.profiler.record_function

        with torch.profiler.profile() as prof:
            with record_function("Top level Annotation"):
                with record_function("First Annotation"):
                    leaf = torch.ones((1,), requires_grad=True)

                # Enter an annotation without ever calling `__exit__` to make
                # sure that this case is handled correctly.
                _ = record_function("Second Annotation").__enter__()

                shifted = leaf + 1
                with record_function("Third Annotation"):
                    shifted.backward()

        # NOTE(review): an earlier comment here attributed `aten::zeros` entries
        # to `at::cpp_custom_type_hack`; the expected tree below contains no
        # `aten::zeros`, so that remark looks outdated - confirm before relying
        # on it.
        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            Top level Annotation
              First Annotation
                aten::ones
                  aten::empty
                  aten::fill_
              Second Annotation
                aten::add
                  aten::to
                    aten::_to_copy
                      aten::empty_strided
                      aten::copy_
                Third Annotation
                  aten::ones_like
                    aten::empty_like
                      aten::empty_strided
                    aten::fill_
                  autograd::engine::evaluate_function: AddBackward0
                    AddBackward0
                  autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                    torch::autograd::AccumulateGrad
                      aten::new_empty_strided
                        aten::empty_strided
                      aten::copy_""",
        )
365

366
    # TODO: Add logic for CUDA version of test
367
    @ProfilerTree.test
    @unittest.skipIf(torch.cuda.is_available(), "Test not working for CUDA")
    def test_profiler_experimental_tree_with_memory(self):
        """Check the op tree, including `[memory]` events, with `profile_memory=True`."""
        lhs = torch.ones(1, requires_grad=True)
        rhs = torch.ones(1, requires_grad=True)
        with torch.profiler.profile(profile_memory=True) as prof:
            total = torch.add(lhs, rhs)
            ones = torch.ones(1)
            loss = (ones - total) ** 2
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            aten::add
              [memory]
            aten::ones
              aten::empty
                [memory]
              aten::fill_
            aten::sub
              [memory]
            aten::pow
              aten::result_type
              aten::to
              [memory]
            aten::ones_like
              aten::empty_like
                aten::empty_strided
                  [memory]
              aten::fill_
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
                  aten::copy_
                aten::mul
                  [memory]
                  aten::mul
                    aten::to
                      aten::_to_copy
                        aten::empty_strided
                          [memory]
                        aten::copy_
                    [memory]
                    [memory]
                  [memory]
                aten::mul
                  [memory]
                [memory]
                [memory]
              [memory]
            autograd::engine::evaluate_function: SubBackward0
              SubBackward0
                aten::neg
                  [memory]
              [memory]
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::new_empty_strided
                  aten::empty_strided
                    [memory]
                aten::copy_
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach
            [memory]""",
        )
439

440
    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_memory_and_stack(self):
        """Check the tree recorded with both Python stacks and memory events.

        The statements inside the profiled region are part of the expected
        trace, so they must not be restructured.
        """
        lhs = torch.ones(1, requires_grad=True)
        rhs = torch.ones(1, requires_grad=True)
        with torch.profiler.profile(with_stack=True, profile_memory=True) as prof:
            total = torch.add(lhs, rhs)
            ones = torch.ones(1)
            loss = torch.pow(ones - total, 2)
            loss.backward()

        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_memory_and_stack
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method add of type object at 0xXXXXXXXXXXXX>
                aten::add
                  [memory]
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                    [memory]
                  aten::fill_
              aten::sub
                [memory]
              <built-in method pow of type object at 0xXXXXXXXXXXXX>
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
              torch/_tensor.py(...): backward
                <built-in function _has_torch_function_unary>
                torch/autograd/__init__.py(...): backward
                  <built-in method _are_functorch_transforms_active of PyCapsule object at 0xXXXXXXXXXXXX>
                  <built-in function isinstance>
                  <built-in function isinstance>
                  <built-in function len>
                  torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
                  torch/autograd/__init__.py(...): _make_grads
                    typing.py(...): inner
                      typing.py(...): __hash__
                        <built-in function hash>
                    typing.py(...): cast
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in method ones_like of type object at 0xXXXXXXXXXXXX>
                      aten::ones_like
                        aten::empty_like
                          aten::empty_strided
                            [memory]
                        aten::fill_
                    <built-in method append of list object at 0xXXXXXXXXXXXX>
                  torch/autograd/graph.py(...): _engine_run_backward
                    logging/__init__.py(...): getEffectiveLevel
                    <built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
                      autograd::engine::evaluate_function: PowBackward0
                        PowBackward0
                          aten::pow
                            aten::result_type
                            aten::to
                            [memory]
                            aten::copy_
                          aten::mul
                            [memory]
                            aten::mul
                              aten::to
                                aten::_to_copy
                                  aten::empty_strided
                                    [memory]
                                  aten::copy_
                              [memory]
                              [memory]
                            [memory]
                          aten::mul
                            [memory]
                          [memory]
                          [memory]
                        [memory]
                      autograd::engine::evaluate_function: SubBackward0
                        SubBackward0
                          aten::neg
                            [memory]
                        [memory]
                      autograd::engine::evaluate_function: AddBackward0
                        AddBackward0
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::new_empty_strided
                            aten::empty_strided
                              [memory]
                          aten::copy_
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::detach
                            detach
                [memory]
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )
551

552
    @skipIfTorchDynamo("too slow")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_modules(self):
        """Check the tree recorded with Python stacks for nested `nn.Module` calls.

        The module is invoked twice inside the profiled region, and the
        expected tree contains both invocations.
        """

        class MyModule(torch.nn.Module):
            # Applies ReLU -> Linear(1, 1) -> ReLU, held in a plain Python list.
            def __init__(self) -> None:
                super().__init__()
                self.layers = [
                    torch.nn.ReLU(),
                    torch.nn.Linear(1, 1),
                    torch.nn.ReLU(),
                ]

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Keep this a plain loop: the calls made here are part of the
                # recorded trace.
                for layer in self.layers:
                    x = layer(x)
                return x

        module = MyModule()
        with torch.profiler.profile(with_stack=True) as prof:
            for _ in range(2):
                module(torch.ones((1,)))
        self.maxDiff = None
        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_modules
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                  aten::fill_
              nn.Module: MyModule_0
                torch/nn/modules/module.py(...): _call_impl
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  test_profiler_tree.py(...): forward
                    nn.Module: ReLU_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
                    nn.Module: Linear_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/linear.py(...): forward
                          torch/nn/modules/module.py(...): __getattr__
                          torch/nn/modules/module.py(...): __getattr__
                          <built-in function linear>
                            aten::linear
                              aten::reshape
                                aten::view
                              aten::t
                                aten::transpose
                                  aten::as_strided
                              aten::addmm
                                aten::expand
                                  aten::as_strided
                                aten::copy_
                                aten::resolve_conj
                                aten::resolve_conj
                                aten::resolve_conj
                              aten::view
                    nn.Module: ReLU_1
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
              <built-in method ones of type object at 0xXXXXXXXXXXXX>
                aten::ones
                  aten::empty
                  aten::fill_
              nn.Module: MyModule_0
                torch/nn/modules/module.py(...): _call_impl
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  test_profiler_tree.py(...): forward
                    nn.Module: ReLU_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
                    nn.Module: Linear_0
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/linear.py(...): forward
                          torch/nn/modules/module.py(...): __getattr__
                          torch/nn/modules/module.py(...): __getattr__
                          <built-in function linear>
                            aten::linear
                              aten::reshape
                                aten::view
                              aten::t
                                aten::transpose
                                  aten::as_strided
                              aten::addmm
                                aten::expand
                                  aten::as_strided
                                aten::copy_
                                aten::resolve_conj
                                aten::resolve_conj
                                aten::resolve_conj
                              aten::view
                    nn.Module: ReLU_1
                      torch/nn/modules/module.py(...): _call_impl
                        <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                        torch/nn/modules/activation.py(...): forward
                          torch/nn/functional.py(...): relu
                            <built-in function _has_torch_function_unary>
                            <built-in method relu of type object at 0xXXXXXXXXXXXX>
                              aten::relu
                                aten::clamp_min
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )
681

682
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_torch_function(self):
        """Check the Python stack recorded when `__torch_function__` handles an op."""
        subclassed = TorchFunctionTensor(torch.ones((1,)))
        plain = torch.ones((1,))

        # `__torch_function__` performs some lazy initialization. Run the op
        # once up front, otherwise the first profiled run would not match the
        # replicates.
        torch.add(subclassed, plain)

        with torch.profiler.profile(with_stack=True) as prof:
            torch.add(subclassed, plain)

        self.assertTreesMatch(
            ProfilerTree.format(prof.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_torch_function
              torch/profiler/profiler.py(...): __enter__
                ...
              <built-in method add of type object at 0xXXXXXXXXXXXX>
                test_profiler_tree.py(...): __torch_function__
                  torch/_tensor.py(...): __torch_function__
                    <built-in function all>
                      torch/_tensor.py(...): <genexpr>
                        <built-in function issubclass>
                      torch/_tensor.py(...): <genexpr>
                    <built-in method add of type object at 0xXXXXXXXXXXXX>
                      aten::add
                    torch/_tensor.py(...): _convert
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method as_subclass of Tensor object at 0xXXXXXXXXXXXX>
                        aten::alias
                      <built-in function isinstance>
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )
722

723
    # Profiles `x + y` where `x` is wrapped in `TorchDispatchTensor`, with
    # Python stack capture enabled, and compares the formatted event tree
    # against the expected call tree below.
    #
    # Skipped under TEST_WITH_CROSSREF because crossref intercepts calls and
    # changes the callsite, so the recorded stack would not match.
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @ProfilerTree.test
    def test_profiler_experimental_tree_with_stack_and_torch_dispatch(self):
        x = TorchDispatchTensor(torch.ones((1,)))
        y = torch.ones((1,))

        # warmup round
        with torch.profiler.profile(with_stack=True):
            x + y

        # Measured round: the bare `x + y` expression is intentional; only
        # the events it triggers matter, not its result.
        with torch.profiler.profile(with_stack=True) as p:
            x + y

        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_with_stack_and_torch_dispatch
              torch/profiler/profiler.py(...): __enter__
                ...
              aten::add
                torch/_library/simple_registry.py(...): find_torch_dispatch_rule
                  torch/_library/simple_registry.py(...): find
                  torch/_library/simple_registry.py(...): find
                    <built-in method get of dict object at 0xXXXXXXXXXXXX>
                test_profiler_tree.py(...): __torch_dispatch__
                  torch/utils/_pytree.py(...): tree_map
                    ...
                  torch/utils/_pytree.py(...): tree_map
                    ...
                  torch/_ops.py(...): __call__
                    <built-in method  of PyCapsule object at 0xXXXXXXXXXXXX>
                      aten::add
                  torch/utils/_pytree.py(...): tree_map
                    ...
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  ...""",
        )
763

764
    # Profiles a tiny CUDA forward/backward/SGD step with memory profiling
    # enabled (no Python stacks) and compares the formatted event tree
    # against the expected op / CUDA runtime / kernel / [memory] tree below.
    #
    # Currently skipped unconditionally; see the linked issue. Also requires
    # CUDA.
    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda(self):
        # NOTE(review): the expected tree interleaves [memory] events with
        # ops, so it presumably depends on the exact statement order and the
        # lifetime of each local below; do not inline or reorder without
        # regenerating the expected output.
        with torch.profiler.profile(profile_memory=True) as p:
            weight = torch.ones(1, device="cuda", requires_grad=True)
            x = torch.ones(1, device="cuda")
            y = torch.add(weight, x)
            loss = torch.pow(y, 2)
            loss.backward()
            torch.optim.SGD([weight], lr=0.01, momentum=0.9).step()

        # ALLOW_CUDA_FAILURE is set on ROCm builds and on Windows, where the
        # profiler is known to miss some events.
        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::add
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::pow
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              aten::result_type
              aten::to
              [memory]
            aten::ones_like
              aten::empty_like
                aten::empty_strided
                  [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            autograd::engine::evaluate_function: PowBackward0
              PowBackward0
                aten::pow
                  aten::result_type
                  aten::to
                  [memory]
                  aten::copy_
                    cudaMemcpyAsync
                      Memcpy DtoD (Device -> Device)
                aten::mul
                  [memory]
                  aten::mul
                    cudaLaunchKernel
                      void at::native::vectorized_elementwise_kernel<...>(...)
                    [memory]
                  [memory]
                aten::mul
                  cudaLaunchKernel
                    void at::native::vectorized_elementwise_kernel<...>(...)
                  [memory]
                [memory]
                [memory]
            autograd::engine::evaluate_function: AddBackward0
              AddBackward0
            autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
              torch::autograd::AccumulateGrad
                aten::detach
                  detach
            [memory]
            aten::zeros
              aten::zeros
                aten::empty
                  [memory]
                aten::zero_
            Optimizer.step#SGD.step
              aten::empty
                [memory]
              [memory]
              [memory]
              aten::clone
                aten::empty_strided
                  [memory]
                aten::copy_
                  cudaMemcpyAsync
                    Memcpy DtoD (Device -> Device)
              aten::detach
                detach
              aten::add_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            [memory]""",  # noqa: B950
            allow_failure=ALLOW_CUDA_FAILURE,
        )
861

862
    # Profiles `tanh(x) - x` issued once on each of three CUDA streams with
    # memory profiling enabled, and compares the formatted event tree against
    # the expected tree below (one tanh/sub group per stream).
    #
    # Currently skipped unconditionally; see the linked issue. Also requires
    # CUDA.
    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda_with_stream(self):
        streams = [torch.cuda.Stream() for _ in range(3)]
        results = []
        with torch.profiler.profile(profile_memory=True) as p:
            x = torch.ones((4, 4), device="cuda")
            for stream in streams:
                with torch.cuda.stream(stream):
                    # Results are kept in a list so they stay alive until
                    # after the profiled region ends.
                    results.append(torch.tanh(x) - x)
        del results
        # Make the current stream wait for the work queued on each side
        # stream.
        for s in streams:
            torch.cuda.current_stream().wait_stream(s)

        # ALLOW_CUDA_FAILURE is set on ROCm builds and on Windows, where the
        # profiler is known to miss some events.
        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            aten::ones
              aten::empty
                [memory]
              aten::fill_
                cudaLaunchKernel
                  void at::native::vectorized_elementwise_kernel<...>(...)
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]
            aten::tanh
              cudaMalloc
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            aten::sub
              cudaLaunchKernel
                void at::native::vectorized_elementwise_kernel<...>(...)
              [memory]
            [memory]""",
            allow_failure=ALLOW_CUDA_FAILURE,
        )
918

919
    # Profiles one full training step (forward, backward, SGD update) of a
    # 1x1 CUDA `Linear` model with both memory profiling and Python stack
    # capture enabled, and compares the formatted event tree against the
    # expected tree below.
    #
    # Currently skipped unconditionally; see the linked issue. Also skipped
    # under TEST_WITH_CROSSREF (it changes the callsite) and without CUDA.
    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(
        TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
    @ProfilerTree.test
    def test_profiler_experimental_tree_cuda_detailed(self):
        # Do lazy imports ahead of time to avoid it showing up in the tree
        import torch.nested._internal.nested_tensor

        model = torch.nn.modules.Linear(1, 1, device="cuda")
        model.train()
        opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

        # The helper's name appears verbatim in the expected tree
        # ("test_profiler_tree.py(...): step"), so it must not be renamed.
        def step():
            x = torch.ones((1, 1), device="cuda")
            loss = model(x)
            loss.backward()
            opt.step()

        # Warmup
        for _ in range(3):
            step()

        with torch.profiler.profile(profile_memory=True, with_stack=True) as p:
            step()

        # ALLOW_CUDA_FAILURE is set on ROCm builds and on Windows, where the
        # profiler is known to miss some events.
        self.assertTreesMatch(
            ProfilerTree.format(p.profiler, 12),
            """\
            test_profiler_tree.py(...): test_profiler_experimental_tree_cuda_detailed
              torch/profiler/profiler.py(...): __enter__
                ...
              test_profiler_tree.py(...): step
                <built-in method ones of type object at 0xXXXXXXXXXXXX>
                  aten::ones
                    aten::empty
                      [memory]
                    aten::fill_
                      cudaLaunchKernel
                        void at::native::vectorized_elementwise_kernel<...>(...)
                nn.Module: Linear_0
                  <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX>
                  torch/nn/modules/linear.py(...): forward
                    torch/nn/modules/module.py(...): __getattr__
                    torch/nn/modules/module.py(...): __getattr__
                    <built-in function linear>
                      aten::linear
                        aten::t
                          aten::transpose
                            aten::as_strided
                        aten::addmm
                          cudaMemcpyAsync
                            Memcpy DtoD (Device -> Device)
                          cudaLaunchKernel
                            void ..._kernel<...>(...)
                          [memory]
                          aten::expand
                            aten::as_strided
                torch/_tensor.py(...): backward
                  <built-in function _has_torch_function_unary>
                  torch/autograd/__init__.py(...): backward
                    <built-in function isinstance>
                    <built-in function isinstance>
                    <built-in function len>
                    torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
                    torch/autograd/__init__.py(...): _make_grads
                      typing.py(...): inner
                        typing.py(...): __hash__
                          <built-in function hash>
                      typing.py(...): cast
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
                      <built-in function isinstance>
                      <built-in function isinstance>
                      <built-in method ones_like of type object at 0xXXXXXXXXXXXX>
                        aten::ones_like
                          aten::empty_like
                            aten::empty_strided
                              [memory]
                          aten::fill_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                    <built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
                      autograd::engine::evaluate_function: AddmmBackward0
                        AddmmBackward0
                          aten::t
                            aten::transpose
                              aten::as_strided
                          aten::mm
                            cudaLaunchKernel
                              void ..._kernel<...>(...)
                            [memory]
                          aten::t
                            aten::transpose
                              aten::as_strided
                        aten::sum
                          aten::sum
                            cudaLaunchKernel
                              void at::native::reduce_kernel<...>(...)
                            [memory]
                        aten::view
                          aten::view
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::add_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                          [memory]
                      autograd::engine::evaluate_function: TBackward0
                        TBackward0
                          aten::t
                            aten::transpose
                              aten::as_strided
                      autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
                        torch::autograd::AccumulateGrad
                          aten::add_
                            cudaLaunchKernel
                              void at::native::vectorized_elementwise_kernel<...>(...)
                          [memory]
                  [memory]
                torch/optim/optimizer.py(...): wrapper
                  <built-in method format of str object at 0xXXXXXXXXXXXX>
                  torch/autograd/profiler.py(...): __init__
                    <built-in method zeros of type object at 0xXXXXXXXXXXXX>
                      aten::zeros
                        aten::zeros
                          aten::empty
                            [memory]
                          aten::zero_
                  torch/autograd/profiler.py(...): __enter__
                    torch/_ops.py(...): __call__
                      <built-in method _record_function_enter of PyCapsule object at 0xXXXXXXXXXXXX>
                        Optimizer.step#SGD.step
                          aten::empty
                            [memory]
                          [memory]
                    [memory]
                  torch/optim/optimizer.py(...): _use_grad
                    <built-in function is_grad_enabled>
                    torch/autograd/grad_mode.py(...): __init__
                      <built-in function is_grad_enabled>
                      <built-in function _set_grad_enabled>
                    torch/optim/sgd.py(...): step
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      <built-in method append of list object at 0xXXXXXXXXXXXX>
                      torch/optim/sgd.py(...): sgd
                        torch/optim/sgd.py(...): _single_tensor_sgd
                          <built-in method mul_ of Tensor object at 0xXXXXXXXXXXXX>
                            [memory]
                            aten::mul_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                            [memory]
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method mul_ of Tensor object at 0xXXXXXXXXXXXX>
                            [memory]
                            aten::mul_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                            [memory]
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                          <built-in method add_ of Tensor object at 0xXXXXXXXXXXXX>
                            aten::add_
                              cudaLaunchKernel
                                void at::native::vectorized_elementwise_kernel<...>(...)
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                      torch/_tensor.py(...): __hash__
                        <built-in function id>
                    torch/autograd/grad_mode.py(...): __init__
                      <built-in function is_grad_enabled>
                      <built-in function _set_grad_enabled>
                  torch/autograd/profiler.py(...): __exit__
                    torch/_ops.py(...): __call__
                      <built-in method _record_function_exit of PyCapsule object at 0xXXXXXXXXXXXX>
              [memory]
              [memory]
              torch/profiler/profiler.py(...): __exit__
                torch/profiler/profiler.py(...): stop
                  torch/profiler/profiler.py(...): _transit_action
                    <built-in method get of dict object at 0xXXXXXXXXXXXX>
                      enum.py(...): __hash__
                        <built-in function hash>
                    ...""",  # noqa: B950
            allow_failure=ALLOW_CUDA_FAILURE,
        )
1130

1131

1132
# Script entry point: hand control to the PyTorch test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    run_tests()
1134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.