# Owner(s): ["oncall: profiler"]
import functools
import itertools as it
import textwrap
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch._C._profiler import _EventType, _TensorMetadata
from torch.profiler import _memory_profiler, _utils
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.utils import _pytree as pytree
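
# Shared profiler factory for these tests: memory profiling needs record_shapes,
# profile_memory, and with_stack all enabled (test_config_check below verifies
# the error messages produced when they are missing).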
profile = functools.partial(
    torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True
)


@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
class TestMemoryProfiler(TestCase):
    def test_config_check(self) -> None:
        with torch.profiler.profile() as prof:
            pass

        pattern = r"record_shapes=True, profile_memory=True, with_stack=True"
        with self.assertRaisesRegex(ValueError, pattern):
            prof._memory_profile()

        with torch.profiler.profile(record_shapes=True, with_stack=True) as prof:
            pass

        pattern = r"^profile_memory=True required for memory profiling\.$"
        with self.assertRaisesRegex(ValueError, pattern):
            prof._memory_profile()

        with profile() as prof:
            pass

        self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile)
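
# Lightweight helper modules used throughout the tests below: ScaleLayer holds a
# single learned scalar Parameter, and LazyLinear materializes its weight and bias
# on first use so tests can exercise parameters created mid-profile.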


class ScaleLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:


class LazyLinear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x) -> torch.Tensor:
        if getattr(self, "weight", None) is None:
            self.weight = torch.nn.Parameter(
                torch.empty((self.out_features, self.in_features))
            )
            self.bias = torch.nn.Parameter(torch.empty(self.out_features))

        return torch.nn.functional.linear(x, self.weight, self.bias)
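
# TorchDispatchMode helper that records one (op name, input storage ids, output
# storage ids) tuple per dispatched op. The end-to-end category tests use it as an
# independent ground truth to cross-check the memory profiler's annotations.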


class RecordInputOutputDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __init__(self) -> None:
        self.results = []

    def mark_region(self, name: str):
        self.results.append((name, (), ()))

    def flat_ids(self, args):
        flat_args = pytree.tree_leaves(args)
        return tuple(
            (t._cdata, t.storage().data_ptr())
            for t in flat_args
            if isinstance(t, torch.Tensor) and t.storage()
        )

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):
        flat_inputs = self.flat_ids(args) + self.flat_ids(kwargs)
        out = func(*args, **kwargs)
        flat_outputs = self.flat_ids(out)
        if (
            flat_inputs or flat_outputs
        ) and "_record_function_enter" not in func.name():
            self.results.append((func.name(), flat_inputs, flat_outputs))
        return out


@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
class TestIdentifyGradients(TestCase):
    def gradient_detected(
        self,
        prof: torch.profiler.profile,
        ctx: _EventType,
        grad_tensor: torch.Tensor,
        parameter: Optional[torch.Tensor] = None,
    ) -> bool:
        # This is not an exhaustive check, but for the purpose of unit testing
        def key_matches_tensor(key, tensor) -> bool:
            return tensor.storage().data_ptr() == key.storage.ptr
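
        # Walk the event tree and ask the memory profiler for (parameter, gradient)
        # pairs, then compare them against the Tensors supplied by the caller.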
        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for p_key, p_grad_key in _memory_profiler.extract_gradients(node):
                if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor):
                    if parameter is None:
                        return True  # Don't need to check parameter; we're done.

                    elif p_key is not None:
                        # For a complex workflow a gradient could correspond to
                        # different parameters at different points in a trace.
                        # However this will not happen in the relatively simple
                        # cases tested here, so if `extract_gradients` identifies
                        # the parameter corresponding to a particular gradient it
                        # must be the one we expect.
                        self.assertTrue(key_matches_tensor(p_key, parameter))

    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
        self.assertTrue(
            self.gradient_detected(*args, **kwargs),
            f"Failed to identify gradient `{name}` from profile.",
        )

    def assertOnlyGradients(
        self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor]
    ) -> None:
        allowed_set = {t.storage().data_ptr() for t in tensors}

        tree = prof.profiler.kineto_results.experimental_event_tree()
        for node in _utils.traverse_dfs(tree):
            for _, p_grad_key in _memory_profiler.extract_gradients(node):
                self.assertTrue(
                    p_grad_key.storage.ptr in allowed_set,
                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
                )

    def test_extract_gradients_low_level(self) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                (z * w1).sum().backward()

            # Gradient detection through op inspection does not provide a
            # reference to the parameter corresponding to the gradient.
            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_module(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        named_parameters = dict(model.named_parameters())
        self.assertEqual(len(named_parameters), 3)

        def assert_only_gradients(prof: torch.profiler.profile):
            gradients = tuple(i.grad for i in named_parameters.values())
            self.assertFalse(any(i is None for i in gradients))
            self.assertOnlyGradients(prof, gradients)

        def check(cold_start: bool):
            x = torch.ones((2, 2))
            with profile() as prof:
                model(x).sum().backward()

            for name, p in named_parameters.items():
                # The first time we run a module none of the `.grad` fields
                # have been initialized. This is fine; in that case we can
                # detect everything we need in the profiled section.
                self.gradient_detected(prof, _EventType.PyCall, p.grad, p),

                # Op based detection should still identify the gradients.
                self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad)
            assert_only_gradients(prof)

            # We can detect gradients even when `.backward()` is not called.
            with profile() as prof:
                model(torch.ones((2, 2)))

            for name, p in named_parameters.items():
                self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p)
                self.gradient_detected(prof, _EventType.TorchOp, p.grad), name
            assert_only_gradients(prof)

        check(cold_start=True)
        check(cold_start=False)

    def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9)

        def check(cold_start: bool):
            self.assertEqual(w0.grad is None, cold_start)
            self.assertEqual(w1.grad is None, cold_start)
            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                (z * w1).sum().backward()

            # Optimizer instrumentation runs late in the step, so we can detect
            # gradients for both cold and warm start.
            self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0)
            self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1)

            self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad)
            self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad)
            self.assertOnlyGradients(prof, (w0.grad, w1.grad))

            with profile() as prof:
                optimizer.zero_grad(set_to_none=set_to_none)
                (z * w1).sum().backward()

            # Inspected state is cached, so if we replace gradients (as is the
            # case for `set_to_none=True`) our python instrumentation will not
            # TODO(robieta): Should `.step()` be excluded from caching?
            self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0),
            self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1),

            with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"):
                self.assertOnlyGradients(prof, (w0.grad, w1.grad))

        check(cold_start=True)
        check(cold_start=False)

    def test_extract_gradients_from_optimizer(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=False)

    def test_extract_gradients_from_optimizer_set_to_none(self) -> None:
        self._test_extract_gradients_from_optimizer(set_to_none=True)

    def test_extract_gradients_from_module_and_optimizer(self) -> None:
        # Module and optimizer are thoroughly tested individually and should be
        # additive. Thus we can manage with a lightweight check that they don't
        # interact adversely.
        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer())
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        with profile() as prof:
            model(torch.ones((2, 2))).sum().backward()

        self.assertGradientDetected(
            "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight
        )


@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
class TestDataFlow(TestCase):
    def setUp(self) -> None:

    @staticmethod
    def formatSchemas(
        prof: torch.profiler.profile, indent: int = 12
    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
        tree = prof.profiler.kineto_results.experimental_event_tree()
        out: List[Tuple[str, Tuple[bool, ...]]] = []
        for node in _utils.traverse_dfs(tree):
            if node.tag == _EventType.TorchOp:
                e = node.extra_fields
                schemas = _memory_profiler.SchemaMatcher.match_schemas(e)
                if len(schemas) == 1:
                    name = f"{name}.{schemas[0].overload_name}"
                elif len(schemas) > 1:
                    name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}"

                out.append((name, _memory_profiler.SchemaMatcher.inputs_are_mutable(e)))

    def _run_and_format_data_flow(
        inputs: Dict[str, torch.Tensor],
        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
    ) -> str:
        with profile() as prof:
            outputs = f(**inputs) or {}

        memory_profile = prof._memory_profile()
        graph = memory_profile._data_flow_graph
        storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}

        lines: List[str] = []
        for name, t in it.chain(inputs.items(), outputs.items()):
            lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}")
            if t.grad is not None:
                grad_id = storage_to_id[t.grad.storage().data_ptr()]
                lines.append(f"{name + '.grad:':<9} T{grad_id}")

        for node in graph.flow_nodes:
            destroyed = {k for k, v in node._edges.items() if v.is_deletion}

            inputs: List[str] = []
            for key, (_, v) in node.inputs.items():
                inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")

            outputs = [f"T{key.id}(v{v})" for key, v in node.outputs.items()]
            if inputs or outputs:
                event_name = node._event.name.replace("torch::autograd::", "")
                lines.append(
                    f"{event_name:<25} {', '.join(inputs):<15} -> {', '.join(outputs)}"
                )

        return textwrap.indent("\n".join([l.rstrip() for l in lines]), " " * indent)
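
    # In the formatted data flow graphs below, `T<id>(v<version>)` names a Tensor
    # and its version; a trailing `*` marks an input whose last use (deletion)
    # happens in that node.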

    def test_match_schemas(self) -> None:
        with profile() as prof:
            x = torch.ones((1,)).mul(2).add_(2)
            _ = torch.sin(x, out=torch.empty_like(x))

        self.assertEqual(
            self.formatSchemas(prof),
            (
                ("aten::ones.", (False,) * 5),
                ("aten::empty.memory_format", (False,) * 6),
                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
                ("aten::fill_.Scalar", (True, False)),
                ("aten::mul.Tensor", (False, False)),
                ("aten::to.dtype", (False,) * 5),
                ("aten::_to_copy.", (False,) * 7),
                ("aten::empty_strided.", (False,) * 6),
                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
                ("aten::copy_.", (True, False, False)),
                # add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
                ("aten::add_.Tensor", (True, False, False)),
                ("aten::to.dtype", (False,) * 5),
                ("aten::_to_copy.", (False,) * 7),
                ("aten::empty_strided.", (False,) * 6),
                # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
                ("aten::copy_.", (True, False, False)),
                ("aten::empty_like.", (False,) * 6),
                ("aten::empty_strided.", (False,) * 6),
                # sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
                ("aten::sin.out", (False, True)),
            ),
        )

    def test_match_schemas_backward(self) -> None:
        w = torch.ones((1,), requires_grad=True)
        with profile() as prof:
            torch.mul(x, w).backward()

        self.assertEqual(
            self.formatSchemas(prof),
            (
                ("aten::mul.Tensor", (False, False)),
                ("aten::ones_like.", (False,) * 6),
                ("aten::empty_like.", (False,) * 6),
                ("aten::empty_strided.", (False,) * 6),
                # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
                ("aten::fill_.Scalar", (True, False)),
                ("autograd::engine::evaluate_function: MulBackward0", ()),
                ("MulBackward0", (None,)),
                ("aten::mul.Tensor", (False, False)),
                (
                    "autograd::engine::evaluate_function: torch::autograd::AccumulateGrad",
                    (),
                ),
                ("torch::autograd::AccumulateGrad", (None,)),
                ("aten::detach.", (False,)),
            ),
        )

    def test_match_schemas_tensorlist(self) -> None:
        with profile() as prof:
            torch.cat([x, y], axis=0)

        self.assertEqual(
            self.formatSchemas(prof),
            (("aten::cat.", (False, False)),),
        )

    def test_data_flow_graph_with_annotations(self) -> None:
            # torch._C._jit_get_schemas_for_operator will reject any name that
            # is missing a namespace (denoted by the presence of "::"). We want
            # to check that we skip both annotations which have no schema
            # (return empty tuple from SchemaMatcher.lookup_schemas) and
            # annotations which cannot have schema (return None from
            # SchemaMatcher.lookup_schemas).
            with torch.profiler.record_function("Namespaced::Annotation"):
                with torch.profiler.record_function("My Annotation"):
                    return {"x0": torch.ones_like(x), "y0": torch.zeros_like(y)}

        inputs = {"x": torch.ones((1,)), "y": torch.ones((1,))}
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f),
            aten::zero_               T0(v0)          -> T0(v1)
            aten::zero_               T1(v0)          -> T1(v1)
            aten::ones_like           T0(v1)          -> T2(v0)
            aten::zeros_like          T1(v1)          -> T3(v0)""",
        )

    def test_data_flow_graph_non_op_allocations(self) -> None:
        # The python arg parser will convert the python scalar `2` to a Tensor
        # to pass to `aten::mul`. As a result there is no op that "owns" the
        # allocation. The Tensor deletions also do not happen in an op; they
        # are collected as a result of the Python objects going out of scope.
        self.assertExpectedInline(
            self._run_and_format_data_flow({"x": torch.ones((1,))}, f),
            aten::mul                 T0(v0), T1(v0)  ->
            [memory]                  T0(v0*)         ->""",
        )

    def test_data_flow_graph_simple(self) -> None:
        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}
            return {"z": z.view_as(z)}

            with torch.no_grad():

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::view_as             T2(v0)          ->""",
        )

        # Out of place is identical regardless of Autograd.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::view_as             T2(v0)          ->""",
        )

    def test_data_flow_graph_simple_inplace(self) -> None:
        inputs = {"x": torch.ones((25,)), "y": torch.ones((25,), requires_grad=True)}

            with torch.no_grad():

        # When Autograd is enabled a second Tensor `T2` is created to store
        # the values of T0(v0) which are needed for backwards.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f0),
            aten::mul_                T0(v0), T1(v0)  -> T0(v1), T2(v0)""",
        )

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f1),
            aten::mul_                T0(v0), T1(v0)  -> T0(v1)""",
        )

    def test_data_flow_graph_simple_backward(self) -> None:
        inputs = {
            "x": torch.ones((1,)),
            "w": torch.ones((1,), requires_grad=True),
        }
        self.assertExpectedInline(
            self._run_and_format_data_flow(
                inputs, lambda x, w: (x * w).sin().backward()
            ),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::sin                 T2(v0)          -> T3(v0)
            aten::ones_like           T3(v0)          -> T4(v0)
            SinBackward0              T2(v0), T4(v0)  -> T6(v0)
            MulBackward0              T0(v0), T6(v0)  -> T7(v0)
            AccumulateGrad            T7(v0)          ->
            [memory]                  T3(v0*)         ->""",
        )

    def test_data_flow_graph_complicated(self) -> None:
            x = torch.ones((25,))
            z = torch.sin(y, out=torch.empty_like(y))
            return {"x": x, "y": y, "z": z}

        # T1 is the `2` in `.mul(2)`. The Python arg parser automatically
        # converts Scalar arguments to Tensors. The same is true for `T4`
        self.assertExpectedInline(
            self._run_and_format_data_flow({}, f),
            aten::mul                 T0(v0), T1(v0)  -> T3(v0)
            aten::add_                T3(v0), T4(v0)  -> T3(v1)
            aten::empty_like          T3(v1)          -> T6(v0)
            aten::sin                 T3(v1), T6(v0)  -> T6(v1)""",
        )

        with profile() as prof:

        # `aten::mul` creates a temporary Tensor (T2), which is why the output
        # has ID three rather than two.
        mul_node = prof._memory_profile()._data_flow_graph.flow_nodes[2]
        self.assertEqual(mul_node._event.name, "aten::mul")
        self.assertEqual(len(mul_node.intermediates), 1)
        self.assertEqual(mul_node.intermediates[0].id, 2)

    def test_data_flow_graph_stacked(self) -> None:
        inputs = {
            "x": torch.ones((25,)),
            "w0": torch.ones((1,), requires_grad=True),
            "w1": torch.ones((1,), requires_grad=True),
        }
            return x.mul(w0).relu().mul(w1).relu().sum()

            with torch.no_grad():
                return {"loss": f(**kwargs)}

        def f_fwd_bwd(**kwargs):
            return {"loss": loss}

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            [memory]                  T6(v0*)         ->""",
        )

        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd_bwd),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            aten::ones_like           T7(v0)          -> T8(v0)
            SumBackward0              T8(v0)          ->
            ReluBackward0             T6(v0), T8(v0)  -> T9(v0)
            MulBackward0              T3(v0), T4(v0), T9(v0) -> T10(v0), T11(v0)
            aten::sum                 T10(v0)         -> T12(v0)
            AccumulateGrad            T12(v0)         ->
            ReluBackward0             T3(v0), T11(v0) -> T13(v0)
            MulBackward0              T0(v0), T13(v0) -> T14(v0)
            aten::sum                 T14(v0)         -> T15(v0)
            AccumulateGrad            T15(v0)         ->
            [memory]                  T8(v0*)         ->""",
        )

        # Second time grads are already initialized.
        self.assertExpectedInline(
            self._run_and_format_data_flow(inputs, f_fwd_bwd),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            aten::ones_like           T7(v0)          -> T8(v0)
            SumBackward0              T8(v0)          ->
            ReluBackward0             T6(v0), T8(v0)  -> T9(v0)
            MulBackward0              T3(v0), T4(v0), T9(v0) -> T10(v0), T11(v0)
            aten::sum                 T10(v0)         -> T12(v0)
            AccumulateGrad            T12(v0*), T13(v0) -> T13(v1)
            ReluBackward0             T3(v0), T11(v0) -> T14(v0)
            MulBackward0              T0(v0), T14(v0) -> T15(v0)
            aten::sum                 T15(v0)         -> T16(v0)
            AccumulateGrad            T16(v0*), T17(v0) -> T17(v1)
            [memory]                  T8(v0*)         ->""",
        )

        x = torch.ones((25,))
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        with profile() as prof_no_grad:
            with torch.no_grad():
                x.mul(w0).relu().mul(w1).relu().sum()

        # TODO: one with `.logsumexp(dim=0)`

        self.assertExpectedInline(
            self._format_graph(prof_no_grad),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            [memory]                  T7(v0*)         ->""",
        )

        with profile() as prof_grad:
            loss = x.mul(w0).relu().mul(w1).relu().sum()

        self.assertExpectedInline(
            self._format_graph(prof_grad),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            aten::ones_like           T7(v0)          -> T8(v0)
            SumBackward0              T8(v0)          -> T8(v1)
            ReluBackward0             T6(v0), T8(v1)  -> T8(v2), T9(v0)
            MulBackward0              T3(v0), T4(v0), T9(v0) -> T9(v1), T10(v0), T11(v0)
            aten::sum                 T10(v0)         -> T12(v0)
            AccumulateGrad            T12(v0)         -> T12(v1)
            ReluBackward0             T3(v0), T11(v0) -> T11(v1), T13(v0)
            MulBackward0              T0(v0), T13(v0) -> T13(v1), T14(v0)
            aten::sum                 T14(v0)         -> T15(v0)
            AccumulateGrad            T15(v0)         -> T15(v1)
            [memory]                  T8(v2*)         ->""",
        )

        # Second time grads are already initialized.
        with profile() as prof_grad:
            loss = x.mul(w0).relu().mul(w1).relu().sum()

        self.assertExpectedInline(
            self._format_graph(prof_grad),
            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
            aten::relu                T2(v0)          -> T3(v0)
            aten::mul                 T3(v0), T4(v0)  -> T5(v0)
            aten::relu                T5(v0)          -> T6(v0)
            aten::sum                 T6(v0)          -> T7(v0)
            aten::ones_like           T7(v0)          -> T8(v0)
            SumBackward0              T8(v0)          -> T8(v1)
            ReluBackward0             T6(v0), T8(v1)  -> T8(v2), T9(v0)
            MulBackward0              T3(v0), T4(v0), T9(v0) -> T9(v1), T10(v0), T11(v0)
            aten::sum                 T10(v0)         -> T12(v0)
            AccumulateGrad            T12(v0*), T13(v0) -> T13(v1)
            ReluBackward0             T3(v0), T11(v0) -> T11(v1), T14(v0)
            MulBackward0              T0(v0), T14(v0) -> T14(v1), T15(v0)
            aten::sum                 T15(v0)         -> T16(v0)
            AccumulateGrad            T16(v0*), T17(v0) -> T17(v1)
            [memory]                  T8(v2*)         ->""",
        )
@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
821
class TestMemoryProfilerE2E(TestCase):
823
def _lookup_tensor_categories(
824
t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile
825
) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]:
826
storage = t.storage()
828
raise ValueError("Cannot look up uninitialized Tensor.")
830
snapshot = memory_profile._category_snapshot()
832
key.storage.allocation_id
833
for key, _ in snapshot
834
if key.storage.ptr == storage.data_ptr() and key.device == storage.device
838
(key, version): category
839
for (key, version), category in memory_profile._category_snapshot().items()
841
# If a Tensor is live we want the most recent ID
842
if key.storage.allocation_id == max(ids | {-1})
845
def _run_and_check_parameters_and_gradients(
846
self, inner_fn, model, grads_none: bool = False
848
with profile() as prof:
851
memory_profile = prof._memory_profile()
855
category: _memory_profiler.Category,
856
should_be_none: bool = False,
859
assert t is None, "tensor should be None but is not."
861
self.assertIsNotNone(t)
862
categories = self._lookup_tensor_categories(t, memory_profile)
863
self.assertGreater(len(categories), 0)
864
self.assertTrue(all(c == category for c in categories.values()), categories)
866
for p in model.parameters():
867
assert_category(p, _memory_profiler.Category.PARAMETER)
868
assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)
870
# Rely on internal asserts
871
_ = memory_profile.timeline
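
    # _run_and_format_categories (below) replays a step under a dispatch mode that
    # records every op's input/output storages, then renders the category that the
    # memory profiler assigned to each of those storages for expecttest comparison.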

    def _run_and_format_categories(self, fn, indent=12):
        """Generate summary of assigned categories for expecttest."""

        # Use `__torch_dispatch__` to collect ground truth.
        with RecordInputOutputDispatchMode() as record_ops, profile() as prof:
            fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-")))

        memory_profile = prof._memory_profile()
        ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {}
        snapshot = memory_profile._category_snapshot()

        # Build map from observed live Tensors to the memory profiler's
        # TensorKey representation.
        for op in memory_profile._op_tree.dfs():
            if op.typed[0] == _EventType.TorchOp:
                inputs = pytree.tree_leaves(op.typed[1].inputs)
                for t in (i for i in inputs if isinstance(i, _TensorMetadata)):
                    key = _memory_profiler.TensorKey.from_tensor(t)
                    ptr_pair_to_key[(t.impl_ptr, t.storage_data_ptr)] = key

        def format_categories(ptr_pair: int):
            target_key = ptr_pair_to_key.get(ptr_pair, None)
            if target_key is None:
                (version, category.name if category else "???")
                for (key, version), category in snapshot.items()
            assert matches, "Failed to lookup Tensor"

            # Deduplicate version bumps which don't change the category.
            categories = [matches[0][1]]
            for _, category in matches:
                if category != categories[-1]:
                    categories.append(category)

            return f"{target_key.storage.allocation_id} ({','.join(categories)})"

        for name, inputs, outputs in record_ops.results:
            if inputs or outputs:
                inputs_str = ", ".join(format_categories(i) for i in inputs)
                outputs_str = ", ".join(format_categories(i) for i in outputs)
                out.append(f"{name:<40} {inputs_str:<45} -> {outputs_str}")
                out.append(f"\n{name}")

        return textwrap.indent("\n".join(out), " " * indent)
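
    # Each expected line below reads `<op>  <input id (CATEGORY)>, ... -> <output id (CATEGORY)>`;
    # `???` means the memory profiler did not assign a category to that Tensor.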

    def test_parameters_and_gradients(self):
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 2), ScaleLayer(), torch.nn.Linear(2, 1), ScaleLayer()
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        def fwd_only():
            _ = model(torch.ones((2, 2)))

        def fwd_bwd_step():
            optimizer.zero_grad()
            y = model(torch.ones((2, 2)))
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()

        # If we profile the first step then gradients will not have been
        # created when we call `model.forward`, so if we don't call `.backward`
        # then gradients are never created.
        self._run_and_check_parameters_and_gradients(
            inner_fn=fwd_only, model=model, grads_none=True
        )

        # On the first step we must rely on `AccumulateGrad`, since gradients
        # did not exist when `model.forward` was called.
        self.assertTrue(all(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        # After one step the python tracer will also flag gradients.
        self.assertTrue(not any(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        # The parameter gradients are not used but we still detect them with
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)

    def test_parameters_and_gradients_set_to_none(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        def fwd_bwd_step():
            # zero grads at the start so gradients are still live to be
            optimizer.zero_grad(set_to_none=True)
            y = model(torch.ones((2, 2)))
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()

        self.assertTrue(not any(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

        optimizer.zero_grad(set_to_none=True)
        self.assertTrue(all(p.grad is None for p in model.parameters()))
        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)

    def test_inputs_fwd(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        inputs = [torch.ones((2, 2)) for _ in range(2)]

        with profile() as prof:
            # Inputs which were allocated before profiling began

            # Inputs which were allocated after profiling began
                x = torch.ones((2, 2))

        memory_profile = prof._memory_profile()
            categories = self._lookup_tensor_categories(x, memory_profile)
            self.assertGreater(len(categories), 0)
                all(i == _memory_profiler.Category.INPUT for i in categories.values()),

        snapshot = memory_profile._category_snapshot()
        self.assertTrue(_memory_profiler.Category.INPUT in snapshot.values())

    def test_inputs_fwd_lazy(self):
        model = torch.nn.Sequential(LazyLinear(2, 2), LazyLinear(2, 1))
        inputs = [torch.ones((2, 2)) for _ in range(2)]

        with profile() as prof:
            # Inputs which were allocated before profiling began

            # Inputs which were allocated after profiling began
                x = torch.ones((2, 2))

        # For now we can't make any meaningful statements without a backward
        # pass. Here we simply ensure that passes don't generate false positive
        # category classifications.
        memory_profile = prof._memory_profile()
            categories = self._lookup_tensor_categories(x, memory_profile)
            self.assertGreater(len(categories), 0)
            self.assertTrue(all(i is None for i in categories.values()), categories)

        snapshot = memory_profile._category_snapshot()
        self.assertFalse(_memory_profiler.Category.INPUT in snapshot.values())

    def test_inputs_fwd_bwd(self):
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        inputs_targets = [(torch.ones((2, 2)), torch.rand((2, 1))) for _ in range(2)]

        def fwd_bwd_step(x, targets):
            torch.nn.functional.mse_loss(y, targets).backward()
            optimizer.zero_grad()

        with profile() as prof:
            # Inputs which were allocated before profiling began
            for x, targets in inputs_targets:
                fwd_bwd_step(x, targets)

            # Inputs which were allocated after profiling began
                x = torch.ones((2, 2))
                targets = torch.rand((2, 1))
                inputs_targets.append((x, targets))
                fwd_bwd_step(x, targets)

        memory_profile = prof._memory_profile()
            categories = self._lookup_tensor_categories(t, memory_profile)
            self.assertGreater(len(categories), 0)
                all(i == _memory_profiler.Category.INPUT for i in categories.values())

        for x, targets in inputs_targets:

    def test_lazily_initialized(self) -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 2),
            torch.nn.Linear(2, 1),
        )
        self.assertEqual(len(list(model.parameters())), 4)

        def inner_fn():
            y = model(torch.ones((2, 2)))
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            optimizer.zero_grad()
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()

        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
        self.assertEqual(len(list(model.parameters())), 6)

    def test_manual_optimizer_step(self) -> None:
        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1))

        def inner_fn():
            y = model(torch.ones((2, 2)))
            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()

            with torch.no_grad():
                for p in model.parameters():
                    self.assertIsNotNone(grad)
                    p.add_(grad, alpha=-0.1)

        self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)

    def test_categories_e2e_simple_fwd(self) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

            x = torch.ones((2, 2))
            y = torch.cat([x * w0, x * w1], dim=1)

        # NOTE: We expect all categories to be unknown. This is simply a sanity
        # check to ensure that we do not over-label.
        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            aten::ones                               -> 1 (???)
            aten::mul.Tensor                         1 (???), 2 (???)   -> 3 (???)
            aten::mul.Tensor                         1 (???), 4 (???)   -> 5 (???)
            aten::cat                                3 (???), 5 (???)   -> ???""",
        )

    def test_categories_e2e_simple_fwd_bwd(self) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)

        def step_fn(mark_region):
            x = torch.ones((2, 2))
            targets = torch.ones((2, 4))

            mark_region("Forward & loss")
            y = torch.cat([x * w0, x * w1], dim=1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)

            mark_region("Backward")

        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            aten::ones                               -> 1 (INPUT)
            aten::ones                               -> 2 (INPUT)
            -- Forward & loss ---------------------------------------------------------------------------------------
            aten::mul.Tensor                         1 (INPUT), 3 (INPUT)        -> 4 (INPUT)
            aten::mul.Tensor                         1 (INPUT), 5 (INPUT)        -> 6 (INPUT)
            aten::cat                                4 (INPUT), 6 (INPUT)        -> 7 (INPUT)
            aten::binary_cross_entropy_with_logits   7 (INPUT), 2 (INPUT)        -> 11 (INPUT)
            -- Backward ---------------------------------------------------------------------------------------------
            aten::ones_like                          11 (INPUT)                  -> 14 (INPUT)
            aten::sigmoid                            7 (INPUT)                   -> 15 (TEMPORARY)
            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)   -> 16 (TEMPORARY)
            aten::mul.Tensor                         16 (TEMPORARY), 14 (INPUT)  -> 17 (AUTOGRAD_DETAIL)
            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)        -> 21 (GRADIENT)
            aten::view                               21 (GRADIENT)               -> 21 (GRADIENT)
            aten::detach                             21 (GRADIENT)               -> 21 (GRADIENT)
            aten::detach                             21 (GRADIENT)               -> ???
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)        -> 23 (GRADIENT)
            aten::view                               23 (GRADIENT)               -> 23 (GRADIENT)
            aten::detach                             23 (GRADIENT)               -> 23 (GRADIENT)
            aten::detach                             23 (GRADIENT)               -> ???""",
        )

    def test_categories_e2e_simple_fwd_bwd_step(self) -> None:
        w0 = torch.ones((1,), requires_grad=True)
        w1 = torch.ones((1,), requires_grad=True)
        optimizer = torch.optim.SGD([w0, w1], lr=0.1)

        def step_fn(mark_region):
            x = torch.ones((2, 2))
            targets = torch.ones((2, 4))

            mark_region("Forward & loss")
            y = torch.cat([x * w0, x * w1], dim=1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(y, targets)

            mark_region("Backward")

            mark_region("Optimizer")
            optimizer.zero_grad()

        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            aten::ones                               -> 1 (INPUT)
            aten::ones                               -> 2 (INPUT)
            -- Forward & loss ---------------------------------------------------------------------------------------
            aten::mul.Tensor                         1 (INPUT), 3 (PARAMETER)    -> 4 (ACTIVATION)
            aten::mul.Tensor                         1 (INPUT), 5 (PARAMETER)    -> 6 (ACTIVATION)
            aten::cat                                4 (ACTIVATION), 6 (ACTIVATION) -> 7 (ACTIVATION)
            aten::binary_cross_entropy_with_logits   7 (ACTIVATION), 2 (INPUT)   -> 11 (ACTIVATION)
            -- Backward ---------------------------------------------------------------------------------------------
            aten::ones_like                          11 (ACTIVATION)             -> 14 (ACTIVATION)
            aten::sigmoid                            7 (ACTIVATION)              -> 15 (TEMPORARY)
            aten::sub.Tensor                         15 (TEMPORARY), 2 (INPUT)   -> 16 (TEMPORARY)
            aten::mul.Tensor                         16 (TEMPORARY), 14 (ACTIVATION) -> 17 (AUTOGRAD_DETAIL)
            aten::div_.Scalar                        17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::slice.Tensor                       17 (AUTOGRAD_DETAIL)        -> 17 (AUTOGRAD_DETAIL)
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 20 (AUTOGRAD_DETAIL)
            aten::sum.dim_IntList                    20 (AUTOGRAD_DETAIL)        -> 21 (GRADIENT)
            aten::view                               21 (GRADIENT)               -> 21 (GRADIENT)
            aten::detach                             21 (GRADIENT)               -> 21 (GRADIENT)
            aten::detach                             21 (GRADIENT)               -> 21 (GRADIENT)
            aten::mul.Tensor                         17 (AUTOGRAD_DETAIL), 1 (INPUT) -> 22 (AUTOGRAD_DETAIL)
            aten::sum.dim_IntList                    22 (AUTOGRAD_DETAIL)        -> 23 (GRADIENT)
            aten::view                               23 (GRADIENT)               -> 23 (GRADIENT)
            aten::detach                             23 (GRADIENT)               -> 23 (GRADIENT)
            aten::detach                             23 (GRADIENT)               -> 23 (GRADIENT)
            -- Optimizer --------------------------------------------------------------------------------------------
            aten::add_.Tensor                        3 (PARAMETER), 23 (GRADIENT) -> 3 (PARAMETER)
            aten::add_.Tensor                        5 (PARAMETER), 21 (GRADIENT) -> 5 (PARAMETER)""",
        )

    def test_categories_e2e_simple_module_fwd(self) -> None:
        model = torch.nn.Linear(2, 4, bias=True)
        self.assertExpectedInline(
            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
            aten::ones                               -> 1 (INPUT)
            aten::t                                  2 (PARAMETER)               -> 2 (PARAMETER)
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)""",
        )

    def test_categories_e2e_simple_module_fwd_bwd(self) -> None:
        model = torch.nn.Linear(2, 1, bias=True)

        def step_fn(mark_region):
            mark_region("Forward & loss")
            loss = model(torch.ones((2, 2))).sum()

            mark_region("Backward")

        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            -- Forward & loss ---------------------------------------------------------------------------------------
            aten::ones                               -> 1 (INPUT)
            aten::t                                  2 (PARAMETER)               -> 2 (PARAMETER)
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
            aten::sum                                4 (ACTIVATION)              -> 5 (ACTIVATION)
            -- Backward ---------------------------------------------------------------------------------------------
            aten::ones_like                          5 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::expand                             6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::t                                  6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::mm                                 6 (ACTIVATION), 1 (INPUT)   -> 7 (GRADIENT)
            aten::t                                  7 (GRADIENT)                -> 7 (GRADIENT)
            aten::sum.dim_IntList                    6 (ACTIVATION)              -> 9 (GRADIENT)
            aten::view                               9 (GRADIENT)                -> 9 (GRADIENT)
            aten::detach                             9 (GRADIENT)                -> 9 (GRADIENT)
            aten::detach                             9 (GRADIENT)                -> ???
            aten::t                                  7 (GRADIENT)                -> 7 (GRADIENT)
            aten::detach                             7 (GRADIENT)                -> 7 (GRADIENT)
            aten::detach                             7 (GRADIENT)                -> ???""",
        )

    def test_categories_e2e_simple_module_fwd_bwd_step(self) -> None:
        model = torch.nn.Linear(2, 1, bias=True)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        def step_fn(mark_region):
            mark_region("Forward & loss")
            loss = model(torch.ones((2, 2))).sum()

            mark_region("Backward")

            mark_region("Optimizer")
            optimizer.zero_grad()

        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            -- Forward & loss ---------------------------------------------------------------------------------------
            aten::ones                               -> 1 (INPUT)
            aten::t                                  2 (PARAMETER)               -> 2 (PARAMETER)
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
            aten::sum                                4 (ACTIVATION)              -> 5 (ACTIVATION)
            -- Backward ---------------------------------------------------------------------------------------------
            aten::ones_like                          5 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::expand                             6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::t                                  6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::mm                                 6 (ACTIVATION), 1 (INPUT)   -> 7 (GRADIENT)
            aten::t                                  7 (GRADIENT)                -> 7 (GRADIENT)
            aten::sum.dim_IntList                    6 (ACTIVATION)              -> 9 (GRADIENT)
            aten::view                               9 (GRADIENT)                -> 9 (GRADIENT)
            aten::detach                             9 (GRADIENT)                -> 9 (GRADIENT)
            aten::detach                             9 (GRADIENT)                -> 9 (GRADIENT)
            aten::t                                  7 (GRADIENT)                -> 7 (GRADIENT)
            aten::detach                             7 (GRADIENT)                -> 7 (GRADIENT)
            aten::detach                             7 (GRADIENT)                -> 7 (GRADIENT)
            -- Optimizer --------------------------------------------------------------------------------------------
            aten::clone                              7 (GRADIENT)                -> 10 (OPTIMIZER_STATE)
            aten::detach                             10 (OPTIMIZER_STATE)        -> 10 (OPTIMIZER_STATE)
            aten::detach                             10 (OPTIMIZER_STATE)        -> 10 (OPTIMIZER_STATE)
            aten::add_.Tensor                        2 (PARAMETER), 10 (OPTIMIZER_STATE) -> 2 (PARAMETER)
            aten::clone                              9 (GRADIENT)                -> 11 (OPTIMIZER_STATE)
            aten::detach                             11 (OPTIMIZER_STATE)        -> 11 (OPTIMIZER_STATE)
            aten::detach                             11 (OPTIMIZER_STATE)        -> 11 (OPTIMIZER_STATE)
            aten::add_.Tensor                        3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
        )

    def test_categories_e2e_sequential_fwd(self) -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 4, bias=True),
            torch.nn.Linear(4, 4, bias=False),
            torch.nn.Softmax(dim=1),
        )
        self.assertExpectedInline(
            self._run_and_format_categories(lambda _: model(torch.ones((2, 2)))),
            aten::ones                               -> 1 (INPUT)
            aten::t                                  2 (PARAMETER)               -> 2 (PARAMETER)
            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
            aten::relu                               4 (ACTIVATION)              -> 5 (ACTIVATION)
            aten::detach                             5 (ACTIVATION)              -> ???
            aten::t                                  6 (PARAMETER)               -> 6 (PARAMETER)
            aten::mm                                 5 (ACTIVATION), 6 (PARAMETER) -> 7 (ACTIVATION)
            aten::_softmax                           7 (ACTIVATION)              -> 8 (ACTIVATION)
            aten::detach                             8 (ACTIVATION)              -> ???""",
        )

    def test_categories_e2e_sequential_fwd_bwd(self) -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(2, 4, bias=True),
            torch.nn.Linear(4, 4, bias=False),
            torch.nn.Softmax(dim=1),
        )

        def step_fn(mark_region):
            x = torch.ones((2, 2))
            targets = torch.ones((2, 4))

            mark_region("Forward")
            loss = torch.sum((y - targets) ** 2).mean()

            mark_region("Backward")

        self.assertExpectedInline(
            self._run_and_format_categories(step_fn),
            aten::ones                               -> 1 (INPUT)
            aten::ones                               -> 2 (INPUT)
            -- Forward ----------------------------------------------------------------------------------------------
            aten::t                                  3 (PARAMETER)               -> 3 (PARAMETER)
            aten::addmm                              4 (PARAMETER), 1 (INPUT), 3 (PARAMETER) -> 5 (ACTIVATION)
            aten::relu                               5 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::detach                             6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::t                                  7 (PARAMETER)               -> 7 (PARAMETER)
            aten::mm                                 6 (ACTIVATION), 7 (PARAMETER) -> 8 (ACTIVATION)
            aten::_softmax                           8 (ACTIVATION)              -> 9 (ACTIVATION)
            aten::detach                             9 (ACTIVATION)              -> 9 (ACTIVATION)
            -- Loss -------------------------------------------------------------------------------------------------
            aten::sub.Tensor                         9 (ACTIVATION), 2 (INPUT)   -> 10 (ACTIVATION)
            aten::pow.Tensor_Scalar                  10 (ACTIVATION)             -> 11 (ACTIVATION)
            aten::sum                                11 (ACTIVATION)             -> 12 (ACTIVATION)
            aten::mean                               12 (ACTIVATION)             -> 13 (ACTIVATION)
            -- Backward ---------------------------------------------------------------------------------------------
            aten::ones_like                          13 (ACTIVATION)             -> 16 (ACTIVATION)
            aten::expand                             16 (ACTIVATION)             -> 16 (ACTIVATION)
            aten::div.Scalar                         16 (ACTIVATION)             -> 19 (AUTOGRAD_DETAIL)
            aten::expand                             19 (AUTOGRAD_DETAIL)        -> 19 (AUTOGRAD_DETAIL)
            aten::pow.Tensor_Scalar                  10 (ACTIVATION)             -> 20 (TEMPORARY)
            aten::mul.Scalar                         20 (TEMPORARY)              -> 23 (TEMPORARY)
            aten::mul.Tensor                         19 (AUTOGRAD_DETAIL), 23 (TEMPORARY) -> 24 (AUTOGRAD_DETAIL)
            aten::detach                             9 (ACTIVATION)              -> 9 (ACTIVATION)
            aten::_softmax_backward_data             24 (AUTOGRAD_DETAIL), 9 (ACTIVATION) -> 25 (AUTOGRAD_DETAIL)
            aten::t                                  25 (AUTOGRAD_DETAIL)        -> 25 (AUTOGRAD_DETAIL)
            aten::mm                                 25 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 26 (GRADIENT)
            aten::t                                  26 (GRADIENT)               -> 26 (GRADIENT)
            aten::t                                  7 (PARAMETER)               -> 7 (PARAMETER)
            aten::mm                                 25 (AUTOGRAD_DETAIL), 7 (PARAMETER) -> 27 (AUTOGRAD_DETAIL)
            aten::t                                  26 (GRADIENT)               -> 26 (GRADIENT)
            aten::detach                             26 (GRADIENT)               -> 26 (GRADIENT)
            aten::detach                             26 (GRADIENT)               -> ???
            aten::detach                             6 (ACTIVATION)              -> 6 (ACTIVATION)
            aten::threshold_backward                 27 (AUTOGRAD_DETAIL), 6 (ACTIVATION) -> 28 (AUTOGRAD_DETAIL)
            aten::t                                  28 (AUTOGRAD_DETAIL)        -> 28 (AUTOGRAD_DETAIL)
            aten::mm                                 28 (AUTOGRAD_DETAIL), 1 (INPUT) -> 29 (GRADIENT)
            aten::t                                  29 (GRADIENT)               -> 29 (GRADIENT)
            aten::sum.dim_IntList                    28 (AUTOGRAD_DETAIL)        -> 30 (GRADIENT)
            aten::view                               30 (GRADIENT)               -> 30 (GRADIENT)
            aten::detach                             30 (GRADIENT)               -> 30 (GRADIENT)
            aten::detach                             30 (GRADIENT)               -> ???
            aten::t                                  29 (GRADIENT)               -> 29 (GRADIENT)
            aten::detach                             29 (GRADIENT)               -> 29 (GRADIENT)
            aten::detach                             29 (GRADIENT)               -> ???""",
        )

    def test_memory_timeline(self) -> None:
        model = torch.nn.Sequential(
            torch.nn.Linear(64, 512, bias=True),
            torch.nn.Linear(512, 512, bias=False),
            torch.nn.Softmax(dim=1),
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        with profile() as prof:
            x = torch.ones((1024, 64))
            targets = torch.ones((1024, 512))
            loss = torch.nn.functional.mse_loss(y, targets)
            optimizer.zero_grad()

        memory_profile = prof._memory_profile()
        timeline = memory_profile.timeline
        times = tuple(t for t, _, _, _ in timeline)
        self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
                (t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
                for t, action, _, _ in timeline

        def category_name(category):
            return category.name if category else "???"
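
        # format_action renders one timeline event; for INCREMENT_VERSION events it
        # also shows the category transition (`old -> new`) when the version bump
        # changes the Tensor's category.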

        def format_action(action, key, version):
            category = memory_profile._categories.get(key, version)
            if action == _memory_profiler.Action.INCREMENT_VERSION:
                new_category = memory_profile._categories.get(key, version + 1)
                if category != new_category:
                    return f"{category_name(category)} -> {category_name(new_category)}"
            return category_name(category)

        def format_size(size: int):
            return f"{size / 1024:3.1f} kB"
            return f"{size // 1024} kB"

        # We generate sequential IDs for Tensors; however platforms vary
        # slightly in the exact computation executed. If this results in
        # tensor creation the IDs will be shifted and the unit test will fail.
        # (Even though the behavior we're testing is unchanged.) To correct for
        # this we assign sequential numbers to the tensors which are actually
        # tested, effectively suppressing the extraneous implementation details.
        id_map = {}

        def id_for_testing(key):
            return id_map.setdefault(key.storage.allocation_id, len(id_map))

        lines = [
            f"{action.name.lower():<25} {format_action(action, key, version):<25} "
            f"{id_for_testing(key):>3}(v{version}) {format_size(size):>15}"
            for _, action, (key, version), size in prof._memory_profile().timeline
            # We generally don't care about tiny allocations during memory
            # profiling and they add a lot of noise to the unit test.
        ]

        self.assertExpectedInline(
            textwrap.indent("\n".join(lines), " " * 12),
            preexisting               PARAMETER                   0(v0)          128 kB
            preexisting               PARAMETER                   1(v0)            2 kB
            preexisting               PARAMETER                   2(v0)         1024 kB
            create                    INPUT                       3(v0)          256 kB
            create                    INPUT                       4(v0)         2048 kB
            create                    ACTIVATION                  5(v0)         2048 kB
            create                    ACTIVATION                  6(v0)         2048 kB
            destroy                   ACTIVATION                  5(v0)         2048 kB
            create                    ACTIVATION                  7(v0)         2048 kB
            create                    ACTIVATION                  8(v0)         2048 kB
            destroy                   ACTIVATION                  7(v0)         2048 kB
            create                    ACTIVATION                  9(v0)         2048 kB
            create                    TEMPORARY                  10(v0)         2048 kB
            destroy                   TEMPORARY                  10(v0)         2048 kB
            create                    AUTOGRAD_DETAIL            11(v0)         2048 kB
            create                    AUTOGRAD_DETAIL            12(v0)         2048 kB
            destroy                   AUTOGRAD_DETAIL            11(v0)         2048 kB
            create                    GRADIENT                   13(v0)         1024 kB
            create                    AUTOGRAD_DETAIL            14(v0)         2048 kB
            destroy                   AUTOGRAD_DETAIL            12(v0)         2048 kB
            create                    AUTOGRAD_DETAIL            15(v0)         2048 kB
            destroy                   AUTOGRAD_DETAIL            14(v0)         2048 kB
            destroy                   ACTIVATION                  6(v0)         2048 kB
            create                    GRADIENT                   16(v0)          128 kB
            create                    GRADIENT                   17(v0)            2 kB
            destroy                   AUTOGRAD_DETAIL            15(v0)         2048 kB
            create                    OPTIMIZER_STATE            18(v0)          128 kB
            create                    OPTIMIZER_STATE            19(v0)          128 kB
            create                    OPTIMIZER_STATE            20(v0)            2 kB
            create                    OPTIMIZER_STATE            21(v0)            2 kB
            create                    OPTIMIZER_STATE            22(v0)         1024 kB
            create                    OPTIMIZER_STATE            23(v0)         1024 kB
            increment_version         OPTIMIZER_STATE            18(v0)          128 kB
            increment_version         OPTIMIZER_STATE            19(v0)          128 kB
            increment_version         OPTIMIZER_STATE            19(v1)          128 kB
            create                    ???                        24(v0)          128 kB
            create                    ???                        25(v0)          128 kB
            destroy                   ???                        24(v0)          128 kB
            increment_version         ???                        25(v0)          128 kB
            increment_version         PARAMETER                   0(v0)          128 kB
            increment_version         OPTIMIZER_STATE            20(v0)            2 kB
            increment_version         OPTIMIZER_STATE            21(v0)            2 kB
            increment_version         OPTIMIZER_STATE            21(v1)            2 kB
            create                    ???                        26(v0)            2 kB
            create                    ???                        27(v0)            2 kB
            destroy                   ???                        26(v0)            2 kB
            increment_version         ???                        27(v0)            2 kB
            destroy                   ???                        25(v1)          128 kB
            increment_version         PARAMETER                   1(v0)            2 kB
            increment_version         OPTIMIZER_STATE            22(v0)         1024 kB
            increment_version         OPTIMIZER_STATE            23(v0)         1024 kB
            increment_version         OPTIMIZER_STATE            23(v1)         1024 kB
            create                    ???                        28(v0)         1024 kB
            create                    ???                        29(v0)         1024 kB
            destroy                   ???                        28(v0)         1024 kB
            increment_version         ???                        29(v0)         1024 kB
            destroy                   ???                        27(v1)            2 kB
            increment_version         PARAMETER                   2(v0)         1024 kB
            destroy                   ???                        29(v1)         1024 kB
            destroy                   GRADIENT                   16(v0)          128 kB
            destroy                   GRADIENT                   17(v0)            2 kB
            destroy                   GRADIENT                   13(v0)         1024 kB""",
        )

    def test_memory_timeline_no_id(self) -> None:
        # On CPU the default behavior is to simply forward to malloc. That
        # means that when we free `x` the allocator doesn't actually know how
        # many bytes are in the allocation, and thus there's no point to
        # calling `c10::reportMemoryUsageToProfiler`. So in order to test that
        # the memory profiler processes this case correctly we need to use CUDA,
        # where we do always keep a record.
        x = torch.ones((1024,), device="cuda" if torch.cuda.is_available() else "cpu")

        with profile() as prof:
            # We never see `x` used so we don't know the storage is for a
            # Tensor, but we do still see the free event.

            # For empty we see the allocation and free, but not any use.
            # So this also cannot be identified as a Tensor.
            y = torch.empty((64,))

            z = torch.empty((256,))
            z.view_as(z)  # Show `z` to the profiler

        memory_profile = prof._memory_profile()

        expected = [
            (_memory_profiler.Action.PREEXISTING, 4096),
            (_memory_profiler.Action.DESTROY, 4096),
            (_memory_profiler.Action.CREATE, 256),
            (_memory_profiler.Action.DESTROY, 256),
            (_memory_profiler.Action.CREATE, 1024),
            (_memory_profiler.Action.DESTROY, 1024),
        ]

        actual = [(action, size) for _, action, _, size in memory_profile.timeline]

        if not torch.cuda.is_available():
            expected = expected[2:]

        for event in expected:
            self.assertTrue(
                event in actual, f"event: {event} was not found in actual."
            )

            f"expected does not match actual: {actual}",
        )


if __name__ == "__main__":
    run_tests()