# Owner(s): ["oncall: profiler"]
import functools
import gc
import itertools as it
import textwrap
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch._C._profiler import _EventType, _TensorMetadata
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.utils import _pytree as pytree


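# Memory profiling needs shapes, memory, and stack recording all enabled;
# `prof._memory_profile()` raises a ValueError otherwise (see
# `test_config_check` below), so the tests share this pre-configured profile.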
profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
class TestMemoryProfiler(TestCase):
    def test_config_check(self) -> None:
        with torch.profiler.profile() as prof:
            pass

        pattern = r"record_shapes=True, profile_memory=True, with_stack=True"
        with self.assertRaisesRegex(ValueError, pattern):
            prof._memory_profile()

        with torch.profiler.profile(record_shapes=True, with_stack=True) as prof:
            pass

        pattern = r"^profile_memory=True required for memory profiling\.$"
        with self.assertRaisesRegex(ValueError, pattern):
            prof._memory_profile()

        with profile() as prof:
            pass

        self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile)


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


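# Minimal lazily-initialized linear layer: its parameters are only created on
# the first forward call, letting the tests exercise profiling of parameters
# that do not yet exist when profiling starts.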
class LazyLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x) -> torch.Tensor:
        if getattr(self, "weight", None) is None:
            self.weight = torch.nn.Parameter(
                torch.empty((self.out_features, self.in_features))
            )
            self.bias = torch.nn.Parameter(torch.empty(self.out_features))

        return torch.nn.functional.linear(x, self.weight, self.bias)


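# Dispatch mode used by `TestMemoryProfilerE2E._run_and_format_categories` to
# record, for every dispatched op, the (TensorImpl*, storage data_ptr) pairs of
# its Tensor inputs and outputs. These pairs are later matched against the
# memory profiler's TensorKeys to build the expected category summaries.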
class RecordInputOutputDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __init__(self) -> None:
        self.results = []

    def mark_region(self, name: str):
        self.results.append((name, (), ()))

    @staticmethod
    def flat_ids(args):
        flat_args = pytree.tree_leaves(args)
        return tuple(
            (t._cdata, t.storage().data_ptr())
            for t in flat_args
            if isinstance(t, torch.Tensor) and t.storage()
        )

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        args = args or []
        kwargs = kwargs or {}
        flat_inputs = self.flat_ids(args) + self.flat_ids(kwargs)
        out = func(*args, **kwargs)
        flat_outputs = self.flat_ids(out)
        if (
            flat_inputs or flat_outputs
        ) and "_record_function_enter" not in func.name():
            self.results.append((func.name(), flat_inputs, flat_outputs))
        return out


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> bool:
        # This is not an exhaustive check, but for the purpose of unit testing
        # it is sufficient.
        def key_matches_tensor(key, tensor) -> bool:
            # Vacuous case.
            if tensor is None:
                return True

            if key is None:
                return False

            return tensor.storage().data_ptr() == key.storage.ptr

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))
                        return True

        return False

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage.ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                z = x.expand(4) * w0
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = dict(model.named_parameters())
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.assertNotEqual(
                    self.gradient_detected(prof, _EventType.PyCall, p.grad, p),
                    cold_start,
                    name,
                )

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.assertFalse(
                    self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
                )
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:
        x = torch.ones((1,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                z = x.expand(4) * w0
                (z * w1).sum().backward()
                optimizer.step()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                for _ in range(2):
                    optimizer.zero_grad(set_to_none=set_to_none)
                    z = x.expand(4) * w0
                    (z * w1).sum().backward()
                    optimizer.step()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # see them.
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
                set_to_none,
            )

            self.assertNotEqual(
                self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),
                set_to_none,
            )

            if set_to_none:
                with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                    self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()
            optimizer.step()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
class TestDataFlow(TestCase):
    def setUp(self) -> None:
        super().setUp()
        self.maxDiff = None

    @staticmethod
    def formatSchemas(
        prof: torch.profiler.profile, indent: int = 12
    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
        tree = prof.profiler.kineto_results.experimental_event_tree()
        out: List[Tuple[str, Tuple[bool, ...]]] = []
        for node in _utils.traverse_dfs(tree):
            if node.tag == _EventType.TorchOp:
                e = node.extra_fields
                schemas = _memory_profiler.SchemaMatcher.match_schemas(e)
                name = node.name
                if len(schemas) == 1:
                    name = f"{name}.{schemas[0].overload_name}"
                elif len(schemas) > 1:
                    name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}"

                out.append((name, _memory_profiler.SchemaMatcher.inputs_are_mutable(e)))
        return tuple(out)

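    # Renders the memory profiler's data flow graph as text. Each line has the
    # form `event_name  inputs  ->  outputs`, where `T<id>(v<version>)` names a
    # tensor version and a trailing `*` marks an input whose storage is freed
    # by that node.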
    @staticmethod
    def _run_and_format_data_flow(
        inputs: Dict[str, torch.Tensor],
        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
        indent: int = 12,
    ) -> str:
        with profile() as prof:
            outputs = f(**inputs) or {}
            gc.collect()

        memory_profile = prof._memory_profile()
        graph = memory_profile._data_flow_graph
        storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}

        lines: List[str] = []
        for name, t in it.chain(inputs.items(), outputs.items()):
            lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
            if t.grad is not None:
                grad_id = storage_to_id[t.grad.storage().data_ptr()]
                lines.append(f"{name + '.grad:':<9} T{grad_id}")

        if lines:
            lines.append("")

        for node in graph.flow_nodes:
            destroyed = {k for k, v in node._edges.items() if v.is_deletion}

            inputs: List[str] = []
            for key, (_, v) in node.inputs.items():
                inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")

            outputs = [f"T{key.id}(v{v})" for key, v in node.outputs.items()]
            if inputs or outputs:
                event_name = node._event.name.replace("torch::autograd::", "")
                lines.append(
                    f"{event_name:<25} {', '.join(inputs):<15}  ->  {', '.join(outputs)}"
                )

        return textwrap.indent("\n".join([l.rstrip() for l in lines]), " " * indent)

    def test_match_schemas(self) -> None:
        with profile() as prof:
            x = torch.ones((1,)).mul(2).add_(2)
            _ = torch.sin(x, out=torch.empty_like(x))

        self.assertEqual(
            self.formatSchemas(prof),
            (
                ("aten::ones.", (False,) * 5),
                ("aten::empty.memory_format", (False,) * 6),
                #
                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
                ("aten::fill_.Scalar", (True, False)),
                ("aten::mul.Tensor", (False, False)),
                ("aten::to.dtype", (False,) * 5),
                ("aten::_to_copy.", (False,) * 7),
                ("aten::empty_strided.", (False,) * 6),
                #
                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
                ("aten::copy_.", (True, False, False)),
                #
                # add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
                ("aten::add_.Tensor", (True, False, False)),
                ("aten::to.dtype", (False,) * 5),
                ("aten::_to_copy.", (False,) * 7),
                ("aten::empty_strided.", (False,) * 6),
                #
                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
                ("aten::copy_.", (True, False, False)),
                ("aten::empty_like.", (False,) * 6),
                ("aten::empty_strided.", (False,) * 6),
                #
                # sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
                ("aten::sin.out", (False, True)),
            ),
        )

    def test_match_schemas_backward(self) -> None:
        x = torch.ones((1,))
        w = torch.ones((1,), requires_grad=True)
        with profile() as prof:
            torch.mul(x, w).backward()

        self.assertEqual(
            self.formatSchemas(prof),
            (
                ("aten::mul.Tensor", (False, False)),
                ("aten::ones_like.", (False,) * 6),
                ("aten::empty_like.", (False,) * 6),
                ("aten::empty_strided.", (False,) * 6),
                #
                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
                ("aten::fill_.Scalar", (True, False)),
                ("autograd::engine::evaluate_function: MulBackward0", ()),
                ("MulBackward0", (None,)),
                ("aten::mul.Tensor", (False, False)),
                (
                    "autograd::engine::evaluate_function: torch::autograd::AccumulateGrad",
                    (),
                ),
                ("torch::autograd::AccumulateGrad", (None,)),
                ("aten::detach.", (False,)),
                ("detach", (None,)),
            ),
        )

    def test_match_schemas_tensorlist(self) -> None:
        x = torch.ones((1,))
        y = torch.ones((1,))
        with profile() as prof:
            torch.cat([x, y], axis=0)

        self.assertEqual(
            self.formatSchemas(prof),
            (("aten::cat.", (False, False)),),
        )

    def test_data_flow_graph_with_annotations(self) -> None:
        def f(x, y):
            # torch._C._jit_get_schemas_for_operator will reject any name that
            # is missing a namespace (denoted by the presence of "::"). We want
            # to check that we skip both annotations which have no schema
            # (return an empty tuple from SchemaMatcher.lookup_schemas) and
            # annotations which cannot have a schema (return None from
            # SchemaMatcher.lookup_schemas).
            with torch.profiler.record_function("Namespaced::Annotation"):
                with torch.profiler.record_function("My Annotation"):
                    x.zero_()
                    y.zero_()
                    return {"x0": torch.ones_like(x), "y0": torch.zeros_like(y)}

        inputs = {"x": torch.ones((1,)), "y": torch.ones((1,))}
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f),
            """\
            x:       T0
            y:       T1
            x0:      T2
            y0:      T3

            aten::zero_               T0(v0)           ->  T0(v1)
            aten::zero_               T1(v0)           ->  T1(v1)
            aten::ones_like           T0(v1)           ->  T2(v0)
            aten::zeros_like          T1(v1)           ->  T3(v0)""",
        )

    def test_data_flow_graph_non_op_allocations(self) -> None:
        def f(x):
            x.mul(2)

        # The python arg parser will convert the python scalar `2` to a Tensor
        # to pass to `aten::mul`. As a result there is no op that "owns" the
        # allocation. The Tensor deletions also do not happen in an op; they
        # are collected as a result of the Python objects going out of scope.
        self.assertExpectedInline(
            self._run_and_format_data_flow({"x": torch.ones((1,))}, f),
            """\
            x:       T1

            [memory]                                   ->  T0(v0)
            aten::mul                 T0(v0), T1(v0)   ->
            [memory]                  T0(v0*)          ->""",
        )

    def test_data_flow_graph_simple(self) -> None:
        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}

        def f0(x, y):
            z = x.mul(y)
            return {"z": z.view_as(z)}

        def f1(x, y):
            with torch.no_grad():
                return f0(x, y)

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            """\
            x:       T0
            y:       T1
            z:       T2

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::view_as             T2(v0)           ->""",
        )

        # Out of place is identical regardless of Autograd.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            """\
            x:       T0
            y:       T1
            z:       T2

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::view_as             T2(v0)           ->""",
        )

    def test_data_flow_graph_simple_inplace(self) -> None:
        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}

        def f0(x, y):
            x.mul_(y)

        def f1(x, y):
            with torch.no_grad():
                return f0(x, y)

        # When Autograd is enabled a second Tensor `T2` is created to store
        # the values of T0(v0) which are needed for backwards.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            """\
            x:       T0
            y:       T1

            aten::mul_                T0(v0), T1(v0)   ->  T0(v1), T2(v0)""",
        )

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f1),
            """\
            x:       T0
            y:       T1

            aten::mul_                T0(v0), T1(v0)   ->  T0(v1)""",
        )

    def test_data_flow_graph_simple_backward(self) -> None:
        inputs = {
            "x": torch.ones((1,)),
            "w": torch.ones((1,), requires_grad=True),
        }
        self.assertExpectedInline(
            self._run_and_format_data_flow(
                inputs, lambda x, w: (x * w).sin().backward()
            ),
            """\
            x:       T0
            w:       T1
            w.grad:   T7

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::sin                 T2(v0)           ->  T3(v0)
            aten::ones_like           T3(v0)           ->  T4(v0)
            SinBackward0              T2(v0), T4(v0)   ->  T6(v0)
            [memory]                  T2(v0*)          ->
            MulBackward0              T0(v0), T6(v0)   ->  T7(v0)
            [memory]                  T6(v0*)          ->
            AccumulateGrad            T7(v0)           ->
            [memory]                  T4(v0*)          ->
            [memory]                  T3(v0*)          ->""",
        )

    def test_data_flow_graph_complicated(self) -> None:
        def f():
            x = torch.ones((25,))
            y = x.mul(2).add_(2)
            z = torch.sin(y, out=torch.empty_like(y))
            return {"x": x, "y": y, "z": z}

        # T1 is the `2` in `.mul(2)`. The Python arg parser automatically
        # converts Scalar arguments to Tensors. The same is true for `T4`
        # and `.add_(2)`.
        self.assertExpectedInline(
            self._run_and_format_data_flow({}, f),
            """\
            x:       T0
            y:       T3
            z:       T6

            aten::ones                                 ->  T0(v0)
            [memory]                                   ->  T1(v0)
            aten::mul                 T0(v0), T1(v0)   ->  T3(v0)
            [memory]                  T1(v0*)          ->
            [memory]                                   ->  T4(v0)
            aten::add_                T3(v0), T4(v0)   ->  T3(v1)
            [memory]                  T4(v0*)          ->
            aten::empty_like          T3(v1)           ->  T6(v0)
            aten::sin                 T3(v1), T6(v0)   ->  T6(v1)""",
        )

        with profile() as prof:
            f()

        # `aten::mul` creates a temporary Tensor (T2), which is why the output
        # has ID three rather than two.
        mul_node = prof._memory_profile()._data_flow_graph.flow_nodes[2]
        self.assertEqual(mul_node._event.name, "aten::mul")
        self.assertEqual(len(mul_node.intermediates), 1)
        self.assertEqual(mul_node.intermediates[0].id, 2)

    def test_data_flow_graph_stacked(self) -> None:
        inputs = {
            "x": torch.ones((25,)),
            "w0": torch.ones((1,), requires_grad=True),
            "w1": torch.ones((1,), requires_grad=True),
        }

        def f(x, w0, w1):
            return x.mul(w0).relu().mul(w1).relu().sum()

        def f_fwd(**kwargs):
            with torch.no_grad():
                return {"loss": f(**kwargs)}

        def f_fwd_bwd(**kwargs):
            loss = f(**kwargs)
            loss.backward()
            return {"loss": loss}

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd),
            """\
            x:       T0
            w0:      T1
            w1:      T4
            loss:    T7

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            [memory]                  T3(v0*)          ->
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            [memory]                  T6(v0*)          ->""",
        )

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd_bwd),
            """\
            x:       T0
            w0:      T1
            w0.grad:  T15
            w1:      T4
            w1.grad:  T12
            loss:    T7

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            aten::ones_like           T7(v0)           ->  T8(v0)
            SumBackward0              T8(v0)           ->
            ReluBackward0             T6(v0), T8(v0)   ->  T9(v0)
            [memory]                  T6(v0*)          ->
            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T10(v0), T11(v0)
            aten::sum                 T10(v0)          ->  T12(v0)
            [memory]                  T10(v0*)         ->
            [memory]                  T9(v0*)          ->
            AccumulateGrad            T12(v0)          ->
            ReluBackward0             T3(v0), T11(v0)  ->  T13(v0)
            [memory]                  T11(v0*)         ->
            [memory]                  T3(v0*)          ->
            MulBackward0              T0(v0), T13(v0)  ->  T14(v0)
            aten::sum                 T14(v0)          ->  T15(v0)
            [memory]                  T14(v0*)         ->
            [memory]                  T13(v0*)         ->
            AccumulateGrad            T15(v0)          ->
            [memory]                  T8(v0*)          ->""",
        )

        # Second time grads are already initialized.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd_bwd),
            """\
            x:       T0
            w0:      T1
            w0.grad:  T17
            w1:      T4
            w1.grad:  T13
            loss:    T7

            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            aten::ones_like           T7(v0)           ->  T8(v0)
            SumBackward0              T8(v0)           ->
            ReluBackward0             T6(v0), T8(v0)   ->  T9(v0)
            [memory]                  T6(v0*)          ->
            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T10(v0), T11(v0)
            aten::sum                 T10(v0)          ->  T12(v0)
            [memory]                  T10(v0*)         ->
            [memory]                  T9(v0*)          ->
            AccumulateGrad            T12(v0*), T13(v0)  ->  T13(v1)
            ReluBackward0             T3(v0), T11(v0)  ->  T14(v0)
            [memory]                  T11(v0*)         ->
            [memory]                  T3(v0*)          ->
            MulBackward0              T0(v0), T14(v0)  ->  T15(v0)
            aten::sum                 T15(v0)          ->  T16(v0)
            [memory]                  T15(v0*)         ->
            [memory]                  T14(v0*)         ->
            AccumulateGrad            T16(v0*), T17(v0)  ->  T17(v1)
            [memory]                  T8(v0*)          ->""",
        )

        return

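        # NOTE: Everything below the early return above is currently
        # unreachable, and it relies on a `self._format_graph` helper that is
        # not defined on this class.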
        x = torch.ones((25,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        with profile() as prof_no_grad:
            with torch.no_grad():
                x.mul(w0).relu().mul(w1).relu().sum()

        # TODO: one with `.logsumexp(dim=0)`

        self.assertExpectedInline(
            self._format_graph(prof_no_grad),
            """\
            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            [memory]                  T3(v0*)          ->
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            [memory]                  T6(v0*)          ->
            [memory]                  T7(v0*)          ->""",
        )

        with profile() as prof_grad:
            loss = x.mul(w0).relu().mul(w1).relu().sum()
            loss.backward()

        self.assertExpectedInline(
            self._format_graph(prof_grad),
            """\
            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            aten::ones_like           T7(v0)           ->  T8(v0)
            SumBackward0              T8(v0)           ->  T8(v1)
            ReluBackward0             T6(v0), T8(v1)   ->  T8(v2), T9(v0)
            [memory]                  T6(v0*)          ->
            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T9(v1), T10(v0), T11(v0)
            aten::sum                 T10(v0)          ->  T12(v0)
            [memory]                  T10(v0*)         ->
            [memory]                  T9(v1*)          ->
            AccumulateGrad            T12(v0)          ->  T12(v1)
            ReluBackward0             T3(v0), T11(v0)  ->  T11(v1), T13(v0)
            [memory]                  T11(v1*)         ->
            [memory]                  T3(v0*)          ->
            MulBackward0              T0(v0), T13(v0)  ->  T13(v1), T14(v0)
            aten::sum                 T14(v0)          ->  T15(v0)
            [memory]                  T14(v0*)         ->
            [memory]                  T13(v1*)         ->
            AccumulateGrad            T15(v0)          ->  T15(v1)
            [memory]                  T8(v2*)          ->""",
        )

        # Second time grads are already initialized.
        with profile() as prof_grad:
            loss = x.mul(w0).relu().mul(w1).relu().sum()
            loss.backward()

        self.assertExpectedInline(
            self._format_graph(prof_grad),
            """\
            aten::mul                 T0(v0), T1(v0)   ->  T2(v0)
            aten::relu                T2(v0)           ->  T3(v0)
            [memory]                  T2(v0*)          ->
            aten::mul                 T3(v0), T4(v0)   ->  T5(v0)
            aten::relu                T5(v0)           ->  T6(v0)
            [memory]                  T5(v0*)          ->
            aten::sum                 T6(v0)           ->  T7(v0)
            aten::ones_like           T7(v0)           ->  T8(v0)
            SumBackward0              T8(v0)           ->  T8(v1)
            ReluBackward0             T6(v0), T8(v1)   ->  T8(v2), T9(v0)
            [memory]                  T6(v0*)          ->
            MulBackward0              T3(v0), T4(v0), T9(v0)  ->  T9(v1), T10(v0), T11(v0)
            aten::sum                 T10(v0)          ->  T12(v0)
            [memory]                  T10(v0*)         ->
            [memory]                  T9(v1*)          ->
            AccumulateGrad            T12(v0*), T13(v0)  ->  T13(v1)
            ReluBackward0             T3(v0), T11(v0)  ->  T11(v1), T14(v0)
            [memory]                  T11(v1*)         ->
            [memory]                  T3(v0*)          ->
            MulBackward0              T0(v0), T14(v0)  ->  T14(v1), T15(v0)
            aten::sum                 T15(v0)          ->  T16(v0)
            [memory]                  T15(v0*)         ->
            [memory]                  T14(v1*)         ->
            AccumulateGrad            T16(v0*), T17(v0)  ->  T17(v1)
            [memory]                  T8(v2*)          ->""",
        )


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestMemoryProfilerE2E(TestCase):
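    # Maps every (TensorKey, version) pair backed by `t`'s storage to its
    # assigned category, restricted to the most recent allocation_id so that
    # only the live allocation is reported.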
    @staticmethod
    def _lookup_tensor_categories(
        t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
    ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
        storage = t.storage()
        if storage is None:
            raise ValueError("Cannot look up uninitialized Tensor.")

        snapshot = memory_profile._category_snapshot()
        ids = {
            key.storage.allocation_id
            for key, _ in snapshot
            if key.storage.ptr == storage.data_ptr() and key.device == storage.device
        }

        return {
            (key, version): category
            for (key, version), category in memory_profile._category_snapshot().items()
            #
            # If a Tensor is live we want the most recent ID
            if key.storage.allocation_id == max(ids | {-1})
        }

    def _run_and_check_parameters_and_gradients(
        self, inner_fn, model, grads_none: bool = False
    ):
        with profile() as prof:
            inner_fn()

        memory_profile = prof._memory_profile()

        def assert_category(
            t: torch.Tensor,
            category: _memory_profiler.Category,
            should_be_none: bool = False,
        ):
            if should_be_none:
                assert t is None, "tensor should be None but is not."
                return
            self.assertIsNotNone(t)
            categories = self._lookup_tensor_categories(t, memory_profile)
            self.assertGreater(len(categories), 0)
            self.assertTrue(all(c == category for c in categories.values()), categories)

        for p in model.parameters():
            assert_category(p, _memory_profiler.Category.PARAMETER)
            assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)

        # Rely on internal asserts
        _ = memory_profile.timeline

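    # Each formatted line reads `op_name  input_ids -> output_ids`, where an id
    # is rendered as `allocation_id (CATEGORY,...)`; `???` means the Tensor
    # could not be matched to a TensorKey or was never categorized.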
    def _run_and_format_categories(self, fn, indent=12):
        """Generate summary of assigned categories for expecttest."""

        # Use `__torch_dispatch__` to collect ground truth.
        with RecordInputOutputDispatchMode() as record_ops, profile() as prof:
            fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))

        memory_profile = prof._memory_profile()
        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
        snapshot = memory_profile._category_snapshot()

        # Build map from observed live Tensors to the memory profiler's
        # TensorKey representation.
        for op in memory_profile._op_tree.dfs():
            if op.typed[0] == _EventType.TorchOp:
                inputs = pytree.tree_leaves(op.typed[1].inputs)
                for t in (i for i in inputs if isinstance(i, _TensorMetadata)):
                    key = _memory_profiler.TensorKey.from_tensor(t)
                    if key:
                        ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key

        def format_categories(ptr_pair: Tuple[int, int]):
            target_key = ptr_pair_to_key.get(ptr_pair, None)
            if target_key is None:
                return "???"

            matches = tuple(
                (version, category.name if category else "???")
                for (key, version), category in snapshot.items()
                if key == target_key
            )
            assert matches, "Failed to lookup Tensor"

            # Deduplicate version bumps which don't change the category.
            categories = [matches[0][1]]
            for _, category in matches:
                if category != categories[-1]:
                    categories.append(category)

            return f"{target_key.storage.allocation_id} ({','.join(categories)})"

        out: List[str] = []
        for name, inputs, outputs in record_ops.results:
            if inputs or outputs:
                # PyTorch ops
                inputs_str = ", ".join(format_categories(i) for i in inputs)
                outputs_str = ", ".join(format_categories(i) for i in outputs)
                out.append(f"{name:<40} {inputs_str:<45} -> {outputs_str}")

            else:
                # Marked regions.
                out.append(f"\n{name}")

        return textwrap.indent("\n".join(out), " " * indent)

    def test_parameters_and_gradients(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 2), ScaleLayer(), torch.nn.Linear(2, 1), ScaleLayer()
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        def fwd_only():
            _ = model(torch.ones((2, 2)))

        def fwd_bwd_step():
            optimizer.zero_grad()
            y = model(torch.ones((2, 2)))
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
            optimizer.step()

        # If we profile the first step then gradients will not have been
        # created when we call `model.forward`, so if we don't call `.backward`
        # then gradients are never created.
        self._run_and_check_parameters_and_gradients(
            inner_fn=fwd_only, model=model, grads_none=True
        )

        # On the first step we must rely on `AccumulateGrad`, since gradients
        # did not exist when `model.forward` was called.
        self.assertTrue(all(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        # After one step the python tracer will also flag gradients.
        self.assertTrue(not any(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        # The parameter gradients are not used but we still detect them with
        # the python tracer.
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)

    def test_parameters_and_gradients_set_to_none(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        def fwd_bwd_step():
            for _ in range(3):
                # zero grads at the start so gradients are still live to be
                # checked.
                optimizer.zero_grad(set_to_none=True)

                y = model(torch.ones((2, 2)))
                torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
                optimizer.step()

        fwd_bwd_step()
        self.assertTrue(not any(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        optimizer.zero_grad(set_to_none=True)
        self.assertTrue(all(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

    def test_inputs_fwd(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        inputs = [torch.ones((2, 2)) for _ in range(2)]

        with profile() as prof:
            # Inputs which were allocated before profiling began
            for x in inputs:
                _ = model(x)

            # Inputs which were allocated after profiling began
            for _ in range(2):
                x = torch.ones((2, 2))
                inputs.append(x)
                _ = model(x)

        memory_profile = prof._memory_profile()
        for x in inputs:
            categories = self._lookup_tensor_categories(x, memory_profile)
            self.assertGreater(len(categories), 0)
            self.assertTrue(
                all(i == _memory_profiler.Category.INPUT for i in categories.values()),
                categories,
            )

        snapshot = memory_profile._category_snapshot()
        self.assertTrue(_memory_profiler.Category.INPUT in snapshot.values())

    def test_inputs_fwd_lazy(self):
        model = torch.nn.Sequential(LazyLinear(2, 2), LazyLinear(2, 1))
        inputs = [torch.ones((2, 2)) for _ in range(2)]

        with profile() as prof:
            # Inputs which were allocated before profiling began
            for x in inputs:
                _ = model(x)

            # Inputs which were allocated after profiling began
            for _ in range(2):
                x = torch.ones((2, 2))
                inputs.append(x)
                _ = model(x)

        # For now we can't make any meaningful statements without a backward
        # pass. Here we simply ensure that passes don't generate false positive
        # category classifications.
        memory_profile = prof._memory_profile()
        for x in inputs:
            categories = self._lookup_tensor_categories(x, memory_profile)
            self.assertGreater(len(categories), 0)
            self.assertTrue(all(i is None for i in categories.values()), categories)

        snapshot = memory_profile._category_snapshot()
        self.assertFalse(_memory_profiler.Category.INPUT in snapshot.values())

    def test_inputs_fwd_bwd(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        inputs_targets = [(torch.ones((2, 2)), torch.rand((2, 1))) for _ in range(2)]

        def fwd_bwd_step(x, targets):
            y = model(x)
            torch.nn.functional.mse_loss(y, targets).backward()
            optimizer.step()
            optimizer.zero_grad()

        with profile() as prof:
            # Inputs which were allocated before profiling began
            for x, targets in inputs_targets:
                fwd_bwd_step(x, targets)

            # Inputs which were allocated after profiling began
            for _ in range(2):
                x = torch.ones((2, 2))
                targets = torch.rand((2, 1))
                inputs_targets.append((x, targets))
                fwd_bwd_step(x, targets)

        memory_profile = prof._memory_profile()

        def check(t):
            categories = self._lookup_tensor_categories(t, memory_profile)
            self.assertGreater(len(categories), 0)
            self.assertTrue(
                all(i == _memory_profiler.Category.INPUT for i in categories.values())
            )

        for x, targets in inputs_targets:
            check(x)
            check(targets)

    def test_lazily_initialized(self) -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 2),
            torch.nn.ReLU(),
            LazyLinear(2, 2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 1),
        )

        self.assertEqual(len(list(model.parameters())), 4)

        def inner_fn():
            y = model(torch.ones((2, 2)))
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            optimizer.zero_grad()
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
            optimizer.step()

        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
        self.assertEqual(len(list(model.parameters())), 6)

    def test_manual_optimizer_step(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))

        def inner_fn():
            y = model(torch.ones((2, 2)))
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()

            with torch.no_grad():
                for p in model.parameters():
                    grad = p.grad
                    self.assertIsNotNone(grad)
                    p.add_(grad, alpha=-0.1)

        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)

    def test_categories_e2e_simple_fwd(self) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def step_fn(_):
            x = torch.ones((2, 2))
            y = torch.cat([x * w0, x * w1], dim=1)

        # NOTE: We expect all categories to be unknown. This is simply a sanity
        #       check to ensure that we do not over-label.
        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            """\
            aten::ones                                                                             -> 1 (???)
            aten::mul.Tensor                         1 (???), 2 (???)                              -> 3 (???)
            aten::mul.Tensor                         1 (???), 4 (???)                              -> 5 (???)
            aten::cat                                3 (???), 5 (???)                              -> ???""",
        )

    def test_categories_e2e_simple_fwd_bwd(self) -> None:
1131
        w0 = torch.ones((1,), requires_grad=True)
1132
        w1 = torch.ones((1,), requires_grad=True)
1133

1134
        def step_fn(mark_region):
1135
            x = torch.ones((2, 2))
1136
            targets = torch.ones((2, 4))
1137

1138
            mark_region("Forward & loss")
1139
            y = torch.cat([x * w0, x * w1], dim=1)
1140
            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)
1141

1142
            mark_region("Backward")
1143
            loss.backward()
1144

1145
        self.assertExpectedInline(
1146
            self._run_and_format_categories(step_fn),
1147
            """\
1148
            aten::ones                                                                             -> 1 (INPUT)
1149
            aten::ones                                                                             -> 2 (INPUT)
1150

1151
            -- Forward & loss ---------------------------------------------------------------------------------------
1152
            aten::mul.Tensor                         1 (INPUT), 3 (INPUT)                          -> 4 (INPUT)
1153
            aten::mul.Tensor                         1 (INPUT), 5 (INPUT)                          -> 6 (INPUT)
1154
            aten::cat                                4 (INPUT), 6 (INPUT)                          -> 7 (INPUT)
1155
            aten::binary_cross_entropy_with_logits   7 (INPUT), 2 (INPUT)                          -> 11 (INPUT)
1156

1157
            -- Backward ---------------------------------------------------------------------------------------------
1158
            aten::ones_like                          11 (INPUT)                                    -> 14 (INPUT)
1159
            aten::sigmoid                            7 (INPUT)                                     -> 15 (TEMPORARY)
1160
            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)                     -> 16 (TEMPORARY)
1161
            aten::mul.Tensor                         16 (TEMPORARY), 14 (INPUT)                    -> 17 (AUTOGRAD_DETAIL)
1162
            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1163
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1164
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1165
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 20 (AUTOGRAD_DETAIL)
1166
            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)                          -> 21 (GRADIENT)
1167
            aten::view                               21 (GRADIENT)                                 -> 21 (GRADIENT)
1168
            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1169
            aten::detach                             21 (GRADIENT)                                 -> ???
1170
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 22 (AUTOGRAD_DETAIL)
1171
            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)                          -> 23 (GRADIENT)
1172
            aten::view                               23 (GRADIENT)                                 -> 23 (GRADIENT)
1173
            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1174
            aten::detach                             23 (GRADIENT)                                 -> ???""",
1175
        )
1176

1177
    def test_categories_e2e_simple_fwd_bwd_step(self) -> None:
1178
        w0 = torch.ones((1,), requires_grad=True)
1179
        w1 = torch.ones((1,), requires_grad=True)
1180
        optimizer = torch.optim.SGD([w0, w1], lr=0.1)
1181

1182
        def step_fn(mark_region):
1183
            x = torch.ones((2, 2))
1184
            targets = torch.ones((2, 4))
1185

1186
            mark_region("Forward & loss")
1187
            y = torch.cat([x * w0, x * w1], dim=1)
1188
            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)
1189

1190
            mark_region("Backward")
1191
            loss.backward()
1192

1193
            mark_region("Optimizer")
1194
            optimizer.step()
1195
            optimizer.zero_grad()
1196

1197
        self.assertExpectedInline(
1198
            self._run_and_format_categories(step_fn),
1199
            """\
1200
            aten::ones                                                                             -> 1 (INPUT)
1201
            aten::ones                                                                             -> 2 (INPUT)
1202

1203
            -- Forward & loss ---------------------------------------------------------------------------------------
1204
            aten::mul.Tensor                         1 (INPUT), 3 (PARAMETER)                      -> 4 (ACTIVATION)
1205
            aten::mul.Tensor                         1 (INPUT), 5 (PARAMETER)                      -> 6 (ACTIVATION)
1206
            aten::cat                                4 (ACTIVATION), 6 (ACTIVATION)                -> 7 (ACTIVATION)
1207
            aten::binary_cross_entropy_with_logits   7 (ACTIVATION), 2 (INPUT)                     -> 11 (ACTIVATION)
1208

1209
            -- Backward ---------------------------------------------------------------------------------------------
1210
            aten::ones_like                          11 (ACTIVATION)                               -> 14 (ACTIVATION)
1211
            aten::sigmoid                            7 (ACTIVATION)                                -> 15 (TEMPORARY)
1212
            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)                     -> 16 (TEMPORARY)
1213
            aten::mul.Tensor                         16 (TEMPORARY), 14 (ACTIVATION)               -> 17 (AUTOGRAD_DETAIL)
1214
            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1215
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1216
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)                          -> 17 (AUTOGRAD_DETAIL)
1217
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 20 (AUTOGRAD_DETAIL)
1218
            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)                          -> 21 (GRADIENT)
1219
            aten::view                               21 (GRADIENT)                                 -> 21 (GRADIENT)
1220
            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1221
            aten::detach                             21 (GRADIENT)                                 -> 21 (GRADIENT)
1222
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 22 (AUTOGRAD_DETAIL)
1223
            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)                          -> 23 (GRADIENT)
1224
            aten::view                               23 (GRADIENT)                                 -> 23 (GRADIENT)
1225
            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1226
            aten::detach                             23 (GRADIENT)                                 -> 23 (GRADIENT)
1227

1228
            -- Optimizer --------------------------------------------------------------------------------------------
1229
            aten::add_.Tensor                        3 (PARAMETER), 23 (GRADIENT)                  -> 3 (PARAMETER)
1230
            aten::add_.Tensor                        5 (PARAMETER), 21 (GRADIENT)                  -> 5 (PARAMETER)""",
1231
        )
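        # Compared with the fwd/bwd-only variant above, the trailing
        # aten::detach outputs here are labeled GRADIENT rather than "???",
        # and an Optimizer region follows the Backward region.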
1232

1233
    def test_categories_e2e_simple_module_fwd(self) -> None:
1234
        model = torch.nn.Linear(2, 4, bias=True)
1235
        self.assertExpectedInline(
1236
            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
1237
            """\
1238
            aten::ones                                                                             -> 1 (INPUT)
1239
            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1240
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)""",
1241
        )
1242

1243
    def test_categories_e2e_simple_module_fwd_bwd(self) -> None:
1244
        model = torch.nn.Linear(2, 1, bias=True)
1245

1246
        def step_fn(mark_region):
1247
            mark_region("Forward & loss")
1248
            loss = model(torch.ones((2, 2))).sum()
1249

1250
            mark_region("Backward")
1251
            loss.backward()
1252

1253
        self.assertExpectedInline(
1254
            self._run_and_format_categories(step_fn),
1255
            """\
1256

1257
            -- Forward & loss ---------------------------------------------------------------------------------------
1258
            aten::ones                                                                             -> 1 (INPUT)
1259
            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1260
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1261
            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
1262

1263
            -- Backward ---------------------------------------------------------------------------------------------
1264
            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
1265
            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1266
            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
1267
            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
1268
            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1269
            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
1270
            aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
1271
            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1272
            aten::detach                             9 (GRADIENT)                                  -> ???
1273
            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1274
            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1275
            aten::detach                             7 (GRADIENT)                                  -> ???""",
1276
        )
1277

1278
    def test_categories_e2e_simple_module_fwd_bwd_step(self) -> None:
1279
        model = torch.nn.Linear(2, 1, bias=True)
1280
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
1281

1282
        def step_fn(mark_region):
1283
            mark_region("Forward & loss")
1284
            loss = model(torch.ones((2, 2))).sum()
1285

1286
            mark_region("Backward")
1287
            loss.backward()
1288

1289
            mark_region("Optimizer")
1290
            optimizer.step()
1291
            optimizer.zero_grad()
1292

1293
        self.assertExpectedInline(
1294
            self._run_and_format_categories(step_fn),
1295
            """\
1296

1297
            -- Forward & loss ---------------------------------------------------------------------------------------
1298
            aten::ones                                                                             -> 1 (INPUT)
1299
            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1300
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1301
            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
1302

1303
            -- Backward ---------------------------------------------------------------------------------------------
1304
            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
1305
            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1306
            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
1307
            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
1308
            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1309
            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
1310
            aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
1311
            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1312
            aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
1313
            aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
1314
            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1315
            aten::detach                             7 (GRADIENT)                                  -> 7 (GRADIENT)
1316

1317
            -- Optimizer --------------------------------------------------------------------------------------------
1318
            aten::clone                              7 (GRADIENT)                                  -> 10 (OPTIMIZER_STATE)
1319
            aten::detach                             10 (OPTIMIZER_STATE)                          -> 10 (OPTIMIZER_STATE)
1320
            aten::detach                             10 (OPTIMIZER_STATE)                          -> 10 (OPTIMIZER_STATE)
1321
            aten::add_.Tensor                        2 (PARAMETER), 10 (OPTIMIZER_STATE)           -> 2 (PARAMETER)
1322
            aten::clone                              9 (GRADIENT)                                  -> 11 (OPTIMIZER_STATE)
1323
            aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
1324
            aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
1325
            aten::add_.Tensor                        3 (PARAMETER), 11 (OPTIMIZER_STATE)           -> 3 (PARAMETER)""",
1326
        )
1327

1328
    def test_categories_e2e_sequential_fwd(self) -> None:
1329
        model = torch.nn.Sequential(
1330
            torch.nn.Linear(2, 4, bias=True),
1331
            torch.nn.ReLU(),
1332
            torch.nn.Linear(4, 4, bias=False),
1333
            torch.nn.Softmax(dim=1),
1334
        )
1335
        self.assertExpectedInline(
1336
            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
1337
            """\
1338
            aten::ones                                                                             -> 1 (INPUT)
1339
            aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
1340
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
1341
            aten::relu                               4 (ACTIVATION)                                -> 5 (ACTIVATION)
1342
            aten::detach                             5 (ACTIVATION)                                -> ???
1343
            aten::t                                  6 (PARAMETER)                                 -> 6 (PARAMETER)
1344
            aten::mm                                 5 (ACTIVATION), 6 (PARAMETER)                 -> 7 (ACTIVATION)
1345
            aten::_softmax                           7 (ACTIVATION)                                -> 8 (ACTIVATION)
1346
            aten::detach                             8 (ACTIVATION)                                -> ???""",
1347
        )
1348

1349
    def test_categories_e2e_sequential_fwd_bwd(self) -> None:
1350
        model = torch.nn.Sequential(
1351
            torch.nn.Linear(2, 4, bias=True),
1352
            torch.nn.ReLU(),
1353
            torch.nn.Linear(4, 4, bias=False),
1354
            torch.nn.Softmax(dim=1),
1355
        )
1356

1357
        def step_fn(mark_region):
1358
            x = torch.ones((2, 2))
1359
            targets = torch.ones((2, 4))
1360

1361
            mark_region("Forward")
1362
            y = model(x)
1363

1364
            mark_region("Loss")
1365
            loss = torch.sum((y - targets) ** 2).mean()
1366

1367
            mark_region("Backward")
1368
            loss.backward()
1369

1370
        self.assertExpectedInline(
1371
            self._run_and_format_categories(step_fn),
1372
            """\
1373
            aten::ones                                                                             -> 1 (INPUT)
1374
            aten::ones                                                                             -> 2 (INPUT)
1375

1376
            -- Forward ----------------------------------------------------------------------------------------------
1377
            aten::t                                  3 (PARAMETER)                                 -> 3 (PARAMETER)
1378
            aten::addmm                              4 (PARAMETER), 1 (INPUT), 3 (PARAMETER)       -> 5 (ACTIVATION)
1379
            aten::relu                               5 (ACTIVATION)                                -> 6 (ACTIVATION)
1380
            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1381
            aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
1382
            aten::mm                                 6 (ACTIVATION), 7 (PARAMETER)                 -> 8 (ACTIVATION)
1383
            aten::_softmax                           8 (ACTIVATION)                                -> 9 (ACTIVATION)
1384
            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
1385

1386
            -- Loss -------------------------------------------------------------------------------------------------
1387
            aten::sub.Tensor                         9 (ACTIVATION), 2 (INPUT)                     -> 10 (ACTIVATION)
1388
            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 11 (ACTIVATION)
1389
            aten::sum                                11 (ACTIVATION)                               -> 12 (ACTIVATION)
1390
            aten::mean                               12 (ACTIVATION)                               -> 13 (ACTIVATION)
1391

1392
            -- Backward ---------------------------------------------------------------------------------------------
1393
            aten::ones_like                          13 (ACTIVATION)                               -> 16 (ACTIVATION)
1394
            aten::expand                             16 (ACTIVATION)                               -> 16 (ACTIVATION)
1395
            aten::div.Scalar                         16 (ACTIVATION)                               -> 19 (AUTOGRAD_DETAIL)
1396
            aten::expand                             19 (AUTOGRAD_DETAIL)                          -> 19 (AUTOGRAD_DETAIL)
1397
            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 20 (TEMPORARY)
1398
            aten::mul.Scalar                         20 (TEMPORARY)                                -> 23 (TEMPORARY)
1399
            aten::mul.Tensor                         19 (AUTOGRAD_DETAIL), 23 (TEMPORARY)          -> 24 (AUTOGRAD_DETAIL)
1400
            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
1401
            aten::_softmax_backward_data             24 (AUTOGRAD_DETAIL), 9 (ACTIVATION)          -> 25 (AUTOGRAD_DETAIL)
1402
            aten::t                                  25 (AUTOGRAD_DETAIL)                          -> 25 (AUTOGRAD_DETAIL)
1403
            aten::mm                                 25 (AUTOGRAD_DETAIL), 6 (ACTIVATION)          -> 26 (GRADIENT)
1404
            aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
1405
            aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
1406
            aten::mm                                 25 (AUTOGRAD_DETAIL), 7 (PARAMETER)           -> 27 (AUTOGRAD_DETAIL)
1407
            aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
1408
            aten::detach                             26 (GRADIENT)                                 -> 26 (GRADIENT)
1409
            aten::detach                             26 (GRADIENT)                                 -> ???
1410
            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
1411
            aten::threshold_backward                 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION)          -> 28 (AUTOGRAD_DETAIL)
1412
            aten::t                                  28 (AUTOGRAD_DETAIL)                          -> 28 (AUTOGRAD_DETAIL)
1413
            aten::mm                                 28 (AUTOGRAD_DETAIL), 1 (INPUT)               -> 29 (GRADIENT)
1414
            aten::t                                  29 (GRADIENT)                                 -> 29 (GRADIENT)
1415
            aten::sum.dim_IntList                    28 (AUTOGRAD_DETAIL)                          -> 30 (GRADIENT)
1416
            aten::view                               30 (GRADIENT)                                 -> 30 (GRADIENT)
1417
            aten::detach                             30 (GRADIENT)                                 -> 30 (GRADIENT)
1418
            aten::detach                             30 (GRADIENT)                                 -> ???
1419
            aten::t                                  29 (GRADIENT)                                 -> 29 (GRADIENT)
1420
            aten::detach                             29 (GRADIENT)                                 -> 29 (GRADIENT)
1421
            aten::detach                             29 (GRADIENT)                                 -> ???""",
1422
        )
1423

1424
    def test_memory_timeline(self) -> None:
1425
        model = torch.nn.Sequential(
1426
            torch.nn.Linear(64, 512, bias=True),
1427
            torch.nn.ReLU(),
1428
            torch.nn.Linear(512, 512, bias=False),
1429
            torch.nn.Softmax(dim=1),
1430
        )
1431
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
1432

1433
        with profile() as prof:
1434
            x = torch.ones((1024, 64))
1435
            targets = torch.ones((1024, 512))
1436
            y = model(x)
1437
            loss = torch.nn.functional.mse_loss(y, targets)
1438
            loss.backward()
1439
            optimizer.step()
1440
            optimizer.zero_grad()
1441

1442
        memory_profile = prof._memory_profile()
1443
        timeline = memory_profile.timeline
1444
        times = tuple(t for t, _, _, _ in timeline)
1445
        self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
1446
        self.assertTrue(
1447
            all(
1448
                (t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
1449
                for t, action, _, _ in timeline
1450
            )
1451
        )
1452

1453
        def category_name(category):
1454
            return category.name if category else "???"
1455

1456
        def format_action(action, key, version):
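            # Return the category label for (key, version); for version bumps
            # that also change the category, render the transition as
            # "OLD -> NEW" so it shows up in the timeline below.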
1457
            category = memory_profile._categories.get(key, version)
1458
            if action == _memory_profiler.Action.INCREMENT_VERSION:
1459
                new_category = memory_profile._categories.get(key, version + 1)
1460
                if category != new_category:
1461
                    return f"{category_name(category)} -> {category_name(new_category)}"
1462
            return category_name(category)
1463

1464
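        # Sizes are rendered in kB: sub-kB allocations keep one decimal place
        # (e.g. 512 B -> "0.5 kB"), larger ones are floor-divided
        # (e.g. 2048 B -> "2 kB", 1536 B -> "1 kB").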
        def format_size(size: int):
1465
            if size < 1024:
1466
                return f"{size / 1024:3.1f} kB"
1467
            return f"{size // 1024} kB"
1468

1469
        # We generate sequential IDs for Tensors; however, platforms vary
1470
        # slightly in the exact computation executed. If that results in extra
1471
        # tensor creation, the IDs shift and the unit test fails even though
1472
        # the behavior we're testing is unchanged. To correct for this we
1473
        # assign sequential numbers to the tensors that are actually tested,
1474
        # suppressing the extraneous implementation details.
1475
        id_map = {}
1476

1477
        def id_for_testing(key):
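            # Map each storage's allocation_id to a stable index in first-seen
            # order, so the expected output below is insensitive to extra
            # platform-dependent allocations.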
1478
            return id_map.setdefault(key.storage.allocation_id, len(id_map))
1479

1480
        lines = [
1481
            f"{action.name.lower():<25}  {format_action(action, key, version):<25}  "
1482
            f"{id_for_testing(key):>3}(v{version}) {format_size(size):>15}"
1483
            for _, action, (key, version), size in prof._memory_profile().timeline
1484
            # We generally don't care about tiny allocations during memory
1485
            # profiling and they add a lot of noise to the unit test.
1486
            if size > 1024
1487
        ]
1488

1489
        self.assertExpectedInline(
1490
            textwrap.indent("\n".join(lines), " " * 12),
1491
            """\
1492
            preexisting                PARAMETER                    0(v0)          128 kB
1493
            preexisting                PARAMETER                    1(v0)            2 kB
1494
            preexisting                PARAMETER                    2(v0)         1024 kB
1495
            create                     INPUT                        3(v0)          256 kB
1496
            create                     INPUT                        4(v0)         2048 kB
1497
            create                     ACTIVATION                   5(v0)         2048 kB
1498
            create                     ACTIVATION                   6(v0)         2048 kB
1499
            destroy                    ACTIVATION                   5(v0)         2048 kB
1500
            create                     ACTIVATION                   7(v0)         2048 kB
1501
            create                     ACTIVATION                   8(v0)         2048 kB
1502
            destroy                    ACTIVATION                   7(v0)         2048 kB
1503
            create                     ACTIVATION                   9(v0)         2048 kB
1504
            create                     TEMPORARY                   10(v0)         2048 kB
1505
            destroy                    TEMPORARY                   10(v0)         2048 kB
1506
            create                     AUTOGRAD_DETAIL             11(v0)         2048 kB
1507
            create                     AUTOGRAD_DETAIL             12(v0)         2048 kB
1508
            destroy                    AUTOGRAD_DETAIL             11(v0)         2048 kB
1509
            create                     GRADIENT                    13(v0)         1024 kB
1510
            create                     AUTOGRAD_DETAIL             14(v0)         2048 kB
1511
            destroy                    AUTOGRAD_DETAIL             12(v0)         2048 kB
1512
            create                     AUTOGRAD_DETAIL             15(v0)         2048 kB
1513
            destroy                    AUTOGRAD_DETAIL             14(v0)         2048 kB
1514
            destroy                    ACTIVATION                   6(v0)         2048 kB
1515
            create                     GRADIENT                    16(v0)          128 kB
1516
            create                     GRADIENT                    17(v0)            2 kB
1517
            destroy                    AUTOGRAD_DETAIL             15(v0)         2048 kB
1518
            create                     OPTIMIZER_STATE             18(v0)          128 kB
1519
            create                     OPTIMIZER_STATE             19(v0)          128 kB
1520
            create                     OPTIMIZER_STATE             20(v0)            2 kB
1521
            create                     OPTIMIZER_STATE             21(v0)            2 kB
1522
            create                     OPTIMIZER_STATE             22(v0)         1024 kB
1523
            create                     OPTIMIZER_STATE             23(v0)         1024 kB
1524
            increment_version          OPTIMIZER_STATE             18(v0)          128 kB
1525
            increment_version          OPTIMIZER_STATE             19(v0)          128 kB
1526
            increment_version          OPTIMIZER_STATE             19(v1)          128 kB
1527
            create                     ???                         24(v0)          128 kB
1528
            create                     ???                         25(v0)          128 kB
1529
            destroy                    ???                         24(v0)          128 kB
1530
            increment_version          ???                         25(v0)          128 kB
1531
            increment_version          PARAMETER                    0(v0)          128 kB
1532
            increment_version          OPTIMIZER_STATE             20(v0)            2 kB
1533
            increment_version          OPTIMIZER_STATE             21(v0)            2 kB
1534
            increment_version          OPTIMIZER_STATE             21(v1)            2 kB
1535
            create                     ???                         26(v0)            2 kB
1536
            create                     ???                         27(v0)            2 kB
1537
            destroy                    ???                         26(v0)            2 kB
1538
            increment_version          ???                         27(v0)            2 kB
1539
            destroy                    ???                         25(v1)          128 kB
1540
            increment_version          PARAMETER                    1(v0)            2 kB
1541
            increment_version          OPTIMIZER_STATE             22(v0)         1024 kB
1542
            increment_version          OPTIMIZER_STATE             23(v0)         1024 kB
1543
            increment_version          OPTIMIZER_STATE             23(v1)         1024 kB
1544
            create                     ???                         28(v0)         1024 kB
1545
            create                     ???                         29(v0)         1024 kB
1546
            destroy                    ???                         28(v0)         1024 kB
1547
            increment_version          ???                         29(v0)         1024 kB
1548
            destroy                    ???                         27(v1)            2 kB
1549
            increment_version          PARAMETER                    2(v0)         1024 kB
1550
            destroy                    ???                         29(v1)         1024 kB
1551
            destroy                    GRADIENT                    16(v0)          128 kB
1552
            destroy                    GRADIENT                    17(v0)            2 kB
1553
            destroy                    GRADIENT                    13(v0)         1024 kB""",
1554
        )
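        # A minimal sketch (not part of the test) of inspecting the same data
        # without assertExpectedInline, using only the private APIs exercised
        # above; any profiled workload works in place of the model here:
        #
        #   with profile() as prof:
        #       loss = model(torch.ones((1024, 64))).sum()
        #       loss.backward()
        #   for _, action, (key, version), size in prof._memory_profile().timeline:
        #       print(action.name, key.storage.allocation_id, version, size)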
1555

1556
    def test_memory_timeline_no_id(self) -> None:
1557
        # On CPU the default behavior is to simply forward to malloc. That
1558
        # means that when we free `x` the allocator doesn't actually know how
1559
        # many bytes are in the allocation, and thus there's no point in
1560
        # calling `c10::reportMemoryUsageToProfiler`. So in order to test that
1561
        # the memory profiler handles this case correctly we need to use CUDA,
1562
        # where we always keep a record.
1563
        x = torch.ones((1024,), device="cuda" if torch.cuda.is_available() else "cpu")
1564

1565
        with profile() as prof:
1566
            # We never see `x` used so we don't know the storage is for a
1567
            # Tensor, but we do still see the free event.
1568
            del x
1569

1570
            # For empty we see the allocation and free, but not any use.
1571
            # So this also cannot be identified as a Tensor.
1572
            y = torch.empty((64,))
1573
            del y
1574

1575
            z = torch.empty((256,))
1576
            z.view_as(z)  # Show `z` to the profiler
1577
            del z
1578

1579
        memory_profile = prof._memory_profile()
1580

1581
        expected = [
1582
            # x
1583
            (_memory_profiler.Action.PREEXISTING, 4096),
1584
            (_memory_profiler.Action.DESTROY, 4096),
1585
            #
1586
            # y
1587
            (_memory_profiler.Action.CREATE, 256),
1588
            (_memory_profiler.Action.DESTROY, 256),
1589
            #
1590
            # z
1591
            (_memory_profiler.Action.CREATE, 1024),
1592
            (_memory_profiler.Action.DESTROY, 1024),
1593
        ]
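        # Sizes follow from the default float32 dtype (4 bytes per element):
        # x: 1024 * 4 = 4096, y: 64 * 4 = 256, z: 256 * 4 = 1024 bytes.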
1594

1595
        actual = [(action, size) for _, action, _, size in memory_profile.timeline]
1596

1597
        # See above.
1598
        if not torch.cuda.is_available():
1599
            expected = expected[2:]
1600
            for event in expected:
1601
                self.assertTrue(
1602
                    event in actual, f"event: {event} was not found in actual."
1603
                )
1604
        else:
1605
            self.assertEqual(
1606
                actual,
1607
                expected,
1608
                f"expected does not match actual: {actual}",
1609
            )
1610

1611

1612
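# Hypothetical local invocation (exact path and flags depend on your checkout
# and on how run_tests forwards arguments):
#   python test/profiler/test_memory_profiler.py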
if __name__ == "__main__":
1613
    run_tests()
1614
