import collections
import dataclasses
import enum
import itertools as it
import logging
from typing import (
    Any,
    cast,
    DefaultDict,
    Dict,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

from typing_extensions import Literal

import torch
from torch._C import FunctionSchema
from torch._C._autograd import _ProfilerResult
from torch._C._profiler import (
    _EventType,
    _ExtraFields_Allocation,
    _ExtraFields_TorchOp,
    _ProfilerEvent,
    _TensorMetadata,
    RecordScope,
)
from torch._utils import _element_size
from torch.profiler import _utils

KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int]

log = logging.getLogger(__name__)

class Category(enum.Enum):
    INPUT = enum.auto()
    TEMPORARY = enum.auto()
    ACTIVATION = enum.auto()
    GRADIENT = enum.auto()
    AUTOGRAD_DETAIL = enum.auto()
    PARAMETER = enum.auto()
    OPTIMIZER_STATE = enum.auto()


_CATEGORY_TO_COLORS = {
    Category.PARAMETER: "darkgreen",
    Category.OPTIMIZER_STATE: "goldenrod",
    Category.INPUT: "black",
    Category.TEMPORARY: "mediumpurple",
    Category.ACTIVATION: "red",
    Category.GRADIENT: "mediumblue",
    Category.AUTOGRAD_DETAIL: "royalblue",
    None: "grey",
}

_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)}

class Action(enum.Enum):
    PREEXISTING = enum.auto()
    CREATE = enum.auto()
    INCREMENT_VERSION = enum.auto()
    DESTROY = enum.auto()


_ACTION_TO_INDEX = {i: i.value for i in Action}

@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
class Key:
    device: torch.device


@dataclasses.dataclass
class _Storage:
    """Bundle storage pointer and id.

    All profiling logic should use `allocation_id`, however it is useful to
    print storage pointers for debugging and unit tests sometimes look up
    values using the storage data pointer of a live Tensor."""

    ptr: int
    allocation_id: int

    def __repr__(self) -> str:
        return f"{hex(self.ptr):>18} ({self.allocation_id})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _Storage) and self.allocation_id == other.allocation_id

    def __hash__(self) -> int:
        return hash(self.allocation_id)

@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey(Key):
    """Hashable identifier for a storage which has been assigned an ID.

    A detailed description of Tensor IDs and why they are needed is given in
    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
    summarize, multiple Storage buffers can map to the same logical Tensor.
    This dataclass is used to refer to a concrete in-memory StorageImpl of
    a Tensor.
    """

    id: int
    storage: _Storage

    def __repr__(self) -> str:
        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"

    def __lt__(self, other: "TensorKey") -> bool:
        return self._as_sortable < other._as_sortable

    @staticmethod
    def _make(
        tensor_id: Optional[int],
        storage_ptr: Optional[int],
        allocation_id: Optional[int],
        device: torch.device,
    ) -> Optional["TensorKey"]:
        if (
            tensor_id is not None
            and storage_ptr is not None
            and allocation_id is not None
        ):
            return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id))
        return None

    @classmethod
    def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
        return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)

    @classmethod
    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
        if t is not None:
            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
        return None

    @property
    def _as_sortable(self) -> Tuple[int, int, str, int]:
        return self.id, self.storage.allocation_id, self.device.type, self.device.index
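
# Illustration (made-up values, not from a real trace): two storages that
# happen to reuse the same data pointer still compare as distinct because
# `_Storage` equality and hashing key on `allocation_id` rather than `ptr`:
#
#   _Storage(ptr=0x7F00, allocation_id=1) != _Storage(ptr=0x7F00, allocation_id=2)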

def _extract_parameters_and_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
    children = node.children

    # AccumulateGrad is used in the Autograd engine to handle gradient updates.
    # There are two possible cases:
    # 1) This is a newly created gradient Tensor. In that case there is nothing
    #    to accumulate, so autograd simply detaches the Tensor.
    #
    # 2) There is a preexisting gradient Tensor and we need to add the newly
    #    computed update. This is done with an in-place add (aten::add_) op.
    #    (The underscore suffix denotes "in-place".)
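    #
    # Schematically, the subtree matched below looks like (illustrative):
    #   torch::autograd::AccumulateGrad
    #     aten::detach        <- case 1: brand new gradient
    #   -or-
    #     aten::add_          <- case 2: update to an existing gradient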
    if (
        node.typed[0] == _EventType.TorchOp
        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
        # TODO(robieta): Move away from load bearing names
        and node.name == "torch::autograd::AccumulateGrad"
        and children
        and children[0].typed[0] == _EventType.TorchOp
        and children[0].name in ("aten::detach", "aten::add_")
        and children[0].typed[1].inputs
        and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
    ):
        yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0])

    # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
    # NOTE: The values captured by the python tracer are cached; they can be
    #       used to build up labels but do not imply that a Tensor was live at
    #       a particular time.
    elif node.typed[0] == _EventType.PyCall:
        typed_fields = node.typed[1]
        assert typed_fields.module is None or typed_fields.optimizer is None
        if typed_fields.module is not None:
            for _, p, p_grad in typed_fields.module.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)

        if typed_fields.optimizer is not None:
            for p, p_grad, _ in typed_fields.optimizer.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)

def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
    for p, p_grad in _extract_parameters_and_gradients(node):
        if p is not None:
            yield p


def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    for p, p_grad in _extract_parameters_and_gradients(node):
        if p_grad is not None:
            yield p, p_grad


def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
    scopes = []
    while event:
        if event.typed[0] == _EventType.TorchOp:
            scopes.append(event.typed[1].scope)
        event = event.parent
    return tuple(scopes)

class SchemaMatcher:
    """Lookup operator schema based on profiled name.

    When profiling we record the operator's name but not the schema. However
    some analysis requires that information. Fortunately we can look up
    registered schema from the recorded name. We do not, however, record the
    overload and so we must compare the profiled arguments with all overloads
    to determine viable matches.

    Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
    this code will be obsolete.
    """

    @classmethod
    def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
        """Determine which inputs may have mutated based on function schema.

        Note that we don't need to resolve down to a single schema to perform
        this analysis. An input is mutable if it is mutable in any overload. In
        practice, however, it is overwhelmingly common to match a single
        overload. If we cannot find any valid schema then we must be
        conservative and assume all inputs are mutable.
        """
        mutable: Optional[List[bool]] = None
        for schema in cls.match_schemas(t):
            mutable = mutable or [False for _ in schema.arguments]
            for i, arg in enumerate(schema.arguments):
                mutable[i] |= getattr(arg.alias_info, "is_write", False)

        return tuple(mutable or (None for _ in t.inputs))

    @classmethod
    def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
        signature = tuple(
            # Tensor
            TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata)
            #
            # TensorList
            else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list)
            #
            # Scalar and uncaptured inputs.
            else i
            for i in t.inputs
        )

        def matches(schema) -> bool:
            return len(schema.arguments) == len(signature) and all(
                cls._types_match(observed, schema_arg.type)
                for observed, schema_arg in zip(signature, schema.arguments)
            )

        return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))

    @classmethod
    def _types_match(cls, observed, schema_type) -> bool:
        if isinstance(schema_type, torch._C.OptionalType):
            schema_type = schema_type.getElementType()
            return observed is None or cls._types_match(observed, schema_type)

        if isinstance(schema_type, torch._C.AnyType):
            return True

        if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
            return isinstance(observed, list) and all(
                isinstance(i, TensorKey) for i in observed
            )

        type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
            (torch._C.TensorType, TensorKey),
            (torch._C.NoneType, type(None)),
            (torch._C.BoolType, bool),
            (torch._C.IntType, int),
            (torch._C.FloatType, float),
            (torch._C.ComplexType, complex),
            (torch._C.NumberType, (bool, int, float, complex)),
        )

        for jit_type, py_types in type_map:
            if isinstance(schema_type, jit_type):
                return isinstance(observed, py_types)

        # Profiler only records a subset of possible argument types. If we
        # reach this point then the schema must call for a type that profiler
        # does not record. Thus, the schema can only be a match if `observed`
        # is also None.
        return observed is None

    @staticmethod
    def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
        # TODO(robieta):
        #   _jit_get_schemas_for_operator is quite expensive. (~100us / call)
        #   Consider adding `functools.lru_cache` if that becomes an issue.

        try:
            # Schema lookup will throw if `name` is malformed. (For example,
            # schemas must be namespaced and schema lookup will fail if name
            # does not include "::".) We simply catch the exception and return
            # `None` to denote that `name` cannot be an operator name.
            #
            # Note that record_function annotations also go through this path,
            # so it is expected that some names will not correspond to PyTorch
            # operators.
            return tuple(torch._C._jit_get_schemas_for_operator(name))
        except RuntimeError:
            return None
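
# A quick illustration of the matching above (hypothetical REPL session):
#
#   >>> SchemaMatcher.lookup_schemas("aten::add_")     # registered op: overload tuple
#   >>> SchemaMatcher.lookup_schemas("My annotation")  # record_function label: None
#
# For an in-place op such as `aten::add_`, `inputs_are_mutable` reports True
# for the `self` argument because its schema marks that argument as a write.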

class OpTree:
    def __init__(self, result: _ProfilerResult) -> None:
        self._root_nodes = result.experimental_event_tree()
        self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns))

    def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
        yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs)

    @property
    def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
        return self._sorted_nodes

class SizeMap:
    def __init__(self, op_tree: OpTree) -> None:
        self._values: Dict[TensorKey, int] = {}

        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.TorchOp:
                for t in self._flat_tensor_inputs(node.typed[1]):
                    self._update_values(t)

            elif node.typed[0] == _EventType.PyCall:
                typed_fields = node.typed[1]
                assert typed_fields.module is None or typed_fields.optimizer is None
                if typed_fields.module is not None:
                    for _, p, p_grad in typed_fields.module.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)

                if typed_fields.optimizer is not None:
                    for p, p_grad, state in typed_fields.optimizer.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)
                        for _, t in state:
                            self._update_values(t)

        allocations: Dict[TensorKey, int] = {}
        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.Allocation:
                alloc_fields = node.typed[1]
                key = TensorKey.from_allocation(alloc_fields)
                if key:
                    new_size = abs(alloc_fields.alloc_size)
                    prior_size = allocations.setdefault(key, new_size)

                    # It is possible to resize Storage in PyTorch, however we
                    # key on data pointer so most resizes will be treated as a
                    # change in storage. The one corner case that cannot be
                    # handled is `realloc` which successfully resizes the
                    # storage. At time of writing this is not done anywhere in
                    # the core PyTorch codebase.
                    if prior_size != new_size:
                        delta = f"{prior_size} vs. {new_size}"
                        log.warning("Mismatch between allocation and free: %s", delta)

        self._values.update(allocations)

    def _update_values(self, t: Optional[_TensorMetadata]) -> None:
        key = TensorKey.from_tensor(t)
        if key is not None and t is not None and t.layout == torch.strided:
            # Scalars are represented as zero dim Tensors
            n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))

            num_bytes = n * _element_size(t.dtype)
            assert num_bytes >= 0, f"{num_bytes}"
            self._values[key] = max(self._values.get(key, 0), num_bytes)

    @staticmethod
    def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
        for i in op.inputs:
            if isinstance(i, _TensorMetadata):
                yield i
            elif isinstance(i, list):
                yield from i

    def __getitem__(self, key: TensorKey):
        return self._values[key]
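
    # Worked example for `_update_values` (illustrative numbers): a float32
    # Tensor with sizes=(2, 3) and strides=(3, 1) reaches at most
    # max(2 * 3, 3 * 1) = 6 elements, i.e. 6 * 4 = 24 bytes of storage.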

@dataclasses.dataclass()
class DataFlowEdge:
    input_version: Optional[int] = None
    mutated: Optional[bool] = False

    @property
    def is_allocation(self) -> bool:
        return self.input_version is None

    @property
    def is_deletion(self) -> bool:
        return self.mutated is None
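
    # The two sentinel values encode the edge type compactly:
    #   input_version is None -> the node allocated this Tensor
    #   mutated is None       -> the node deleted this Tensor
    # An edge with both sentinels set was allocated and deleted inside the
    # same node (an intermediate; see `DataFlowNode.intermediates`).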

class DataFlowNode:
    def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
        self._event = event
        self._graph = graph
        self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()

        for key, edge in self._edges.items():
            if edge.mutated and not edge.is_allocation:
                self._graph.bump(key)

        # Make sure the version bumping behavior matches what we expect.
        versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
        assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"

    def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
        subtree = tuple(_utils.traverse_dfs([self._event]))

        # Start by populating edges from op inputs and outputs.
        mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
        for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
            for op_input, mutable in zip(
                op.inputs, SchemaMatcher.inputs_are_mutable(op)
            ):
                # Tensor
                if isinstance(op_input, _TensorMetadata):
                    key = TensorKey.from_tensor(op_input)
                    mutable_by_key.setdefault(key, set()).add(mutable)

                # TensorList
                elif isinstance(op_input, list):
                    for op_input_i in op_input:
                        key = TensorKey.from_tensor(op_input_i)
                        mutable_by_key.setdefault(key, set()).add(mutable)

        edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
        edges = collections.defaultdict(DataFlowEdge)
        for key, mutable_set in mutable_by_key.items():
            edges[key].input_version = self._graph.lookup(key) if key else -1

            # We consider an op to be mutated if we encounter a schema where it
            # is a mutable argument OR if it is ambiguous. (We never explicitly
            # see it in any schema.)
            mutated = (True in mutable_set) or (tuple(mutable_set) == (None,))
            edges[key].mutated = mutated

        # Then handle deletions. Note that deleting a Tensor implicitly adds
        # it as an input edge.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0:
                key = TensorKey.from_allocation(i.typed[1])
                edge = edges[key]
                assert key is None or edge.mutated is not None, f"Double delete: {key}"
                edge.mutated = None
                edge.input_version = self._graph.lookup(key) if key else -1

        # And finally handle allocations. This step must be last, because the
        # previous two steps optimistically add input edges.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0:
                edges[TensorKey.from_allocation(i.typed[1])].input_version = None

        # We don't need to sort the inputs, but it makes debugging and unit tests nicer.
        return dict(sorted((k, v) for k, v in edges.items() if k is not None))

    @property
    def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]:
        return {
            # MyPy can't see through `is_allocation` to know that
            # `v.input_version` is not None.
            k: (bool(v.mutated), cast(int, v.input_version))
            for k, v in self._edges.items()
            if not v.is_allocation
        }

    @property
    def outputs(self) -> Dict[TensorKey, int]:
        return {
            k: 0 if v.input_version is None else v.input_version + 1
            for k, v in self._edges.items()
            if (v.is_allocation and not v.is_deletion) or v.mutated
        }

    @property
    def intermediates(self) -> Tuple[TensorKey, ...]:
        return tuple(
            k for k, v in self._edges.items() if v.is_allocation and v.is_deletion
        )

    @property
    def start_time(self) -> int:
        return self._event.start_time_ns
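
    # Version semantics in a nutshell: a Tensor enters a node at version v and,
    # if the node mutates it, leaves at v + 1. Hence `outputs` reports
    # `input_version + 1` for mutated inputs and version 0 for fresh allocations.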

class DataFlowGraph:
    def __init__(self, op_tree: OpTree) -> None:
        self._op_tree = op_tree
        self._leaf_events = self._extract_leaf_events(op_tree)
        self._active_version: Dict[TensorKey, Optional[int]] = {}
        self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
        self._flow_nodes.sort(key=lambda x: x.start_time)
        self.validate()

    @property
    def flow_nodes(self) -> Tuple[DataFlowNode, ...]:
        return tuple(self._flow_nodes)

    def validate(self):
        # Check that each (Tensor, version) pair has a unique creation node
        outputs: Set[Tuple[TensorKey, int]] = set()
        for node in self.flow_nodes:
            node_outputs = set(node.outputs.items())
            duplicates = outputs & node_outputs
            assert not duplicates, f"{node._event.name} {node._edges} {duplicates}"
            outputs |= node_outputs

        # And check that `self._nodes` forms a valid topologically sorted DAG.
        tensor_versions: Dict[TensorKey, int] = {}
        for node in self.flow_nodes:
            for key, (_, version) in node.inputs.items():
                expected = tensor_versions.get(key, 0)
                assert expected == version, (expected, version)

            for key, version in node.outputs.items():
                prior_version = tensor_versions.get(key, version)
                assert version >= prior_version, (version, prior_version)
                tensor_versions[key] = version

    @property
    def leaf_events(self) -> Tuple[_ProfilerEvent, ...]:
        return self._leaf_events

    @staticmethod
    def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]:
        """Partially traverse the op tree and extract top level ops.

        Consider the following code:
        ```
        with record_function("My annotation"):
            x.zero_()
            y.zero_()
        ```

        The op tree (assuming no Autograd) will look like:
          <Python context>
            TorchOp: "My annotation"
              TorchOp: zero_
                TorchOp: fill_
              TorchOp: zero_
                TorchOp: fill_

        The recursive structure of operator calls makes data flow unwieldy.
        In order to simplify analysis we would like to select the highest level
        ops to represent in the graph. In this case those are the `zero_` ops;
        the fact that `fill_` is called is an implementation detail. We also
        do not want to group everything under "My annotation" as this could
        create overly coarse bundles and lose critical semantics.

        To address this issue we walk over the graph and select the topmost
        torch ops ** which match at least one operator schema **. These form
        the leaves of the first pass through the op tree. (As well as any
        allocations or frees which are not part of a kernel.) These events
        form the logical nodes in our data flow graph.
        """

        leaf_events: List[_ProfilerEvent] = []

        def leaf_op(e: _ProfilerEvent) -> bool:
            return e.typed[0] == _EventType.TorchOp and (
                e.typed[1].scope == RecordScope.BACKWARD_FUNCTION
                or bool(SchemaMatcher.match_schemas(e.typed[1]))
            )

        def children_fn(e: _ProfilerEvent):
            if leaf_op(e) or e.tag == _EventType.Allocation:
                leaf_events.append(e)
                return []

            return e.children

        for _ in op_tree.dfs(children_fn=children_fn):
            pass

        return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns))

    def lookup(self, key: TensorKey) -> int:
        version = self._active_version.setdefault(key, 0)
        assert version is not None
        return version

    def bump(self, key: TensorKey) -> None:
        prior_version = self._active_version.get(key, None)
        assert prior_version is not None
        self._active_version[key] = prior_version + 1

    def delete(self, key: TensorKey) -> None:
        assert self._active_version.setdefault(key, 0) is not None
        self._active_version[key] = None
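
    # Lifecycle of `_active_version` (sketch): `lookup` lazily initializes a
    # Tensor's version to 0, `bump` advances it after a mutation, and `delete`
    # pins the entry to None so that any later access trips an assert (a use
    # after free in the recorded trace).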

@dataclasses.dataclass
class CategoryElement:
    by_id: Optional[Category] = None
    by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
    by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)

    # Used by unit tests to check internals. (And consequently by
    # MemoryProfile.lookup) This should not be used in any other capacity.
    _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)


@dataclasses.dataclass
class CategoryDict:
    _values: DefaultDict[int, CategoryElement] = dataclasses.field(
        default_factory=lambda: collections.defaultdict(CategoryElement)
    )

    def set_by_id(self, key: TensorKey, category: Category) -> None:
        self._values[key.id].by_id = category
        self._values[key.id]._by_id_keyset.add(key)

    def set_by_key(self, key: TensorKey, category: Category) -> None:
        self._values[key.id].by_key[key] = category

    def set_by_version(self, key: TensorKey, version: int, category: Category) -> None:
        self._values[key.id].by_version[(key, version)] = category

    def setdefault_by_version(
        self, key: TensorKey, version: int, category: Category
    ) -> None:
        self._values[key.id].by_version.setdefault((key, version), category)

    def get(self, key: Key, version: int) -> Optional[Category]:
        if isinstance(key, Key) and not isinstance(key, TensorKey):
            return None
        element = self._values[key.id]
        return (
            element.by_id
            or element.by_key.get(key, None)
            or element.by_version.get((key, version), None)
        )

class MemoryProfile:
    def __init__(self, result: _ProfilerResult) -> None:
        self._op_tree = OpTree(result)
        self._data_flow_graph = DataFlowGraph(self._op_tree)
        self._size_map = SizeMap(self._op_tree)
        self._categories = CategoryDict()

        self._set_gradients_and_temporaries()
        self._set_parameters_using_python_tracer()
        self._set_inputs()
        self._set_parameters_using_data_flow()
        self._set_activations()
        self._set_optimizer_state()
        self._set_autograd_detail()

    @property
    def timeline(self) -> Tuple[Tuple[int, Action, KeyAndID, int], ...]:
        output: List[Tuple[int, Action, KeyAndID, int]] = []
        allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
        live_unknown: Dict[Tuple[int, torch.device], Literal[True]] = {}
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.Allocation:
                alloc_fields = event.typed[1]
                alloc_size = alloc_fields.alloc_size
                is_allocation = alloc_size > 0
                t = event.start_time_ns

                tkey = TensorKey.from_allocation(alloc_fields)
                if tkey is not None:
                    allocation_times[(tkey, is_allocation)] = t

                else:
                    key = Key(alloc_fields.device)
                    ptr_and_device = (alloc_fields.ptr, key.device)
                    if is_allocation:
                        if ptr_and_device in live_unknown:
                            output.append(
                                (t, Action.INCREMENT_VERSION, (key, 0), alloc_size)
                            )
                        else:
                            live_unknown[ptr_and_device] = True
                            output.append((t, Action.CREATE, (key, 0), alloc_size))
                    else:
                        output.append((t, Action.DESTROY, (key, 0), -alloc_size))
                        if not live_unknown.pop(ptr_and_device, False):
                            output.append(
                                (-1, Action.PREEXISTING, (key, 0), -alloc_size)
                            )

        snapshot = self._category_snapshot()
        last_version = dict(sorted(snapshot.keys()))

        events: List[Tuple[int, Action, TensorAndID]] = [
            (-1, Action.PREEXISTING, (key, version))
            for key, version in snapshot.keys()
            if (key, True) not in allocation_times and version == 0
        ]

        for node in self._data_flow_graph.flow_nodes:
            for key, edge in node._edges.items():
                if edge.is_allocation:
                    t = allocation_times[(key, True)]
                    events.append((t, Action.CREATE, (key, 0)))

                elif edge.mutated:
                    t = node._event.start_time_ns
                    version = edge.input_version
                    assert version is not None
                    events.append((t, Action.INCREMENT_VERSION, (key, version)))

                if edge.is_deletion:
                    t = allocation_times[(key, False)]
                    events.append((t, Action.DESTROY, (key, last_version[key])))

        output.extend(
            (time, action, (key, version), self._size_map[key])
            for time, action, (key, version) in events
        )

        output.sort(key=lambda x: (x[0], x[1].value))
        return tuple(output)

    def _is_gradient(self, *args, **kwargs) -> bool:
        return self._categories.get(*args, **kwargs) == Category.GRADIENT

    def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
        all_tensor_versions: Set[TensorAndID] = set()

        for node in self._data_flow_graph.flow_nodes:
            all_tensor_versions.update(((k, v) for k, (_, v) in node.inputs.items()))
            all_tensor_versions.update((key, 0) for key in node.intermediates)
            all_tensor_versions.update(node.outputs.items())

        for i in self._categories._values.values():
            all_tensor_versions.update((key, 0) for key in i._by_id_keyset)

        return {
            (key, version): self._categories.get(key, version)
            for key, version in sorted(all_tensor_versions)
        }

    def _any_version_depends_on_gradient(self) -> Set[int]:
        """Extract IDs of Tensors which depend or will depend on a gradient.

        Note that this weakened definition of "depends" requires us to loop
        over the data flow graph multiple times because it allows dependency
        information to flow backward through edges and removes the guarantee
        that nodes are topologically sorted. (Or indeed, even that a valid
        topological order exists.) Put another way, we have converted an
        acyclic data flow graph into a cyclic graph and we are attempting to
        partition cycles involving a gradient from the rest of the graph.
        """
        depends_on_gradient: Set[int] = set()
        while True:
            start_size = len(depends_on_gradient)
            for node in self._data_flow_graph.flow_nodes:
                ids = tuple(
                    key.id
                    for key, (_, version) in node.inputs.items()
                    if self._categories.get(key, version)
                    in (Category.GRADIENT, Category.PARAMETER)
                    or key.id in depends_on_gradient
                )

                if ids:
                    depends_on_gradient.update(ids)
                    depends_on_gradient.update(key.id for key in node.outputs)

            # We are guaranteed to exit because there is a finite set of
            # TensorAndID pairs. In practice we do not expect to loop more than
            # three times: once to identify the core parameter update loop,
            # once to fold the first step into that loop, and a third time
            # where no new elements are added.
            if len(depends_on_gradient) == start_size:
                return depends_on_gradient
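
    # Schematic example: with flow edges grad -> param(v1) -> activation and
    # param(v0) -> activation, the first sweep marks `param` via the gradient;
    # because membership is keyed on `key.id`, param(v0) is covered as well,
    # which is exactly the "any version" weakening described above.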

    def _set_gradients_and_temporaries(self) -> None:
        """Mark Tensors which are unambiguous and simple to reason about."""

        # Gradients are straightforward to detect. We directly check the
        # `.grad` property in the Python tracer, and we can detect any new
        # gradient Tensors from `AccumulateGrad` ops.
        for event in self._op_tree.dfs():
            for _, p_grad in extract_gradients(event):
                self._categories.set_by_id(p_grad, Category.GRADIENT)

        # Similarly, temporary Tensors are easy to identify and are useful to
        # flag since they can make memory use "spikier" than one would
        # otherwise expect.
        for node in self._data_flow_graph.flow_nodes:
            for i in node.intermediates:
                self._categories.set_by_key(i, Category.TEMPORARY)

    def _set_parameters_using_python_tracer(self) -> None:
        for event in self._op_tree.dfs():
            for p in extract_parameters(event):
                if p is not None:
                    self._categories.set_by_id(p, Category.PARAMETER)

    def _set_inputs(self) -> None:
        """Mark inputs based on which Tensors are updated using gradients.

        The process for differentiating between inputs and activations is more
        involved. Most Tensors in a training loop depend on at least one
        gradient: parameters depend on them through updates, and activations
        and optimizer state depend on them transitively through parameters.
        Critically, we do not need to know which Tensors are parameters to
        apply this method; we can simply walk the data flow graph to build the
        set of all values which depend on a gradient and then obtain the set
        of inputs from the conjugate set.

        There is, however, one hiccup. The first time we see a parameter is
        generally on the forward pass of the first step. We know from
        inspection of the data flow graph that v1 of that Tensor depends on
        a gradient (provided we profile an optimizer step), but not v0. To
        address this problem we weaken the definition of "depends on a
        gradient" to "any version of this Tensor depends on a gradient",
        which in turn strengthens the criteria for the input set enough to
        filter the activations in the forward pass of the first step."""

        # All of this analysis is predicated on using at least one training
        # step (or parameters from the python tracer) to partition the graph.
        # Absent that we cannot determine which Tensors are inputs and which
        # ones are part of the model.
        depends_on_gradient = self._any_version_depends_on_gradient()

        # We only want to annotate Tensors which actually contribute to the
        # model calculation.
        produces_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            tensors = {(key, version) for key, (_, version) in node.inputs.items()}
            tensors |= node.outputs.items()
            if any(
                self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER)
                or i in produces_gradient
                for i in tensors
            ):
                produces_gradient |= tensors

        # Don't include Tensors created in the backward pass, as these are
        # generally Autograd implementation details rather than proper inputs.
        input_candidates = produces_gradient.copy()
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                input_candidates -= set(node.outputs.items())

        for key, version in input_candidates:
            if key.id not in depends_on_gradient:
                self._categories.setdefault_by_version(key, version, Category.INPUT)
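
    # Concretely (schematic): when profiling a full step of
    # `loss(x @ w).backward(); opt.step()`, every version of `w` depends on a
    # gradient, while no version of the data `x` does, so `x` falls in the
    # conjugate set and is labeled INPUT.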

    def _set_parameters_using_data_flow(self) -> None:
        """Deduce which Tensors are parameters.

        Consider the following code for the step of SGD with momentum
        (nesterov=False), where `d_p` is the gradient of `param` and `buf` is
        the momentum buffer.
        ```
        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
        d_p = buf
        param.add_(d_p, alpha=-lr)
        ```

        Both `param` and `buf` take a gradient and perform an in-place update.

        The python tracer will inspect calls to `nn.Module.forward` and
        `optim.Optimizer.step` to extract parameter and optimizer state
        respectively (including parameters), so this is generally a non-issue.

        However as a fallback we can also exploit several properties of
        parameters to distinguish them from other model state.

        First, they are directly used in the forward pass. (At this point we
        haven't established which parts of the graph correspond to the forward
        pass but we can deduce enough to suffice.) Some mutable state such as
        batch norm moving averages also contribute to the forward pass, but
        optimizer state does not.

        Second, a parameter is by definition used to compute at least one
        gradient and depends on at least one gradient.
        """
        snapshot = self._category_snapshot()

        # Determine which Tensors might be parameters based on forward pass
        # data flow. Note that these are only candidates; we filter nodes that
        # we know are part of the backward pass but that doesn't guarantee that
        # they are part of the forward pass.
        candidate_parameters: Set[TensorAndID] = set()
        candidate_fwd_tensors: Set[TensorAndID] = {
            i for i, category in snapshot.items() if category == Category.INPUT
        }

        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
            if (
                # Don't check nodes in the backward pass.
                RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
                and not any(self._is_gradient(*i) for i in inputs)
                and not any(self._is_gradient(*i) for i in node.outputs.items())
                #
                # and only check nodes which depend on an input.
                and candidate_fwd_tensors.intersection(inputs)
            ):
                candidate_fwd_tensors |= node.outputs.items()
                candidate_parameters |= inputs.difference(candidate_fwd_tensors)

        # Require that each parameter eventually contributes to the value of a gradient
        used_for_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            if any(
                self._is_gradient(*i) or i in used_for_gradient
                for i in node.outputs.items()
            ):
                for key, (_, version) in node.inputs.items():
                    used_for_gradient.add((key, version))
        candidate_parameters.intersection_update(used_for_gradient)

        # and depends on a gradient.
        parameter_keys = {key.id for key, _ in candidate_parameters}
        parameter_keys &= self._any_version_depends_on_gradient()

        for key, _ in snapshot.keys():
            if key.id in parameter_keys:
                self._categories.set_by_id(key, Category.PARAMETER)

    def _set_activations(self) -> None:
        """Flood the graph to identify activations."""

        required = {Category.INPUT, Category.ACTIVATION}
        also_allowed = {Category.PARAMETER, Category.TEMPORARY}
        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
            input_categories = {self._categories.get(*i) for i in inputs}

            if (
                (input_categories & required)
                and not (input_categories - (required | also_allowed))
                #
                # Stop filling when we reach the backward pass.
                and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
            ):
                for i in node.outputs.items():
                    self._categories.setdefault_by_version(*i, Category.ACTIVATION)

    def _set_optimizer_state(self) -> None:
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer:
                parameters = event.typed[1].optimizer.parameters
                for _, t in it.chain(*[state for _, _, state in parameters]):
                    key = TensorKey.from_tensor(t)
                    if key is not None:
                        self._categories.set_by_id(key, Category.OPTIMIZER_STATE)

    def _set_autograd_detail(self):
        prior = {None, Category.AUTOGRAD_DETAIL}
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                for key, version in node.outputs.items():
                    if version == 0 or self._categories.get(key, version - 1) in prior:
                        self._categories.setdefault_by_version(
                            key, version, Category.AUTOGRAD_DETAIL
                        )

class MemoryProfileTimeline:
    def __init__(self, memory_profile):
        """The minimum representation of the memory profile timeline
        includes the memory timeline and categories. The timeline
        consists of [timestamp, action, (TensorKey, version), numbytes]
        elements, to denote any actions (pre-existing, create, destroy,
        or increment_version) that occurred to a specific Tensor for a
        chunk of memory. The categories help map each (TensorKey,
        version) pair into a category."""
        self.timeline = memory_profile.timeline
        self.categories = memory_profile._categories
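
    # Typical construction (sketch; `_memory_profile` is the profiler's
    # internal helper and not a stable public API):
    #   mem_tl = MemoryProfileTimeline(prof._memory_profile())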

    def _coalesce_timeline(self, device_str):
        """Convert the memory timeline and categories into a memory plot
        consisting of timestamps and their respective sizes by category
        for the given device.

        Input: device
        Output: [timestamps, sizes by category]
        """
        device = torch.device(device_str)
        times: List[int] = []
        sizes: List[List[int]] = []

        def update(key, version, delta):
            category = (
                self.categories.get(key, version)
                if isinstance(key, TensorKey)
                else None
            )
            index = _CATEGORY_TO_INDEX[category] + 1
            sizes[-1][index] += int(delta)

        t_min = -1
        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            # Convert timestamps from ns to us, to match trace events.
            if t != -1:
                t = int(t / 1000)

            # Save the smallest timestamp to populate pre-existing allocs.
            if t_min == -1 or (t < t_min and t > 0):
                t_min = t

            # Handle timestep
            if len(times) == 0:
                times.append(t)
                sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX])

            elif t != times[-1]:
                times.append(t)
                sizes.append(sizes[-1].copy())

            # Handle memory and categories
            if action in (Action.PREEXISTING, Action.CREATE):
                update(key, version, numbytes)

            elif action == Action.INCREMENT_VERSION:
                update(key, version, -numbytes)
                update(key, version + 1, numbytes)

            elif action == Action.DESTROY:
                update(key, version, -numbytes)

            else:
                raise ValueError(f"Unknown action: {action}")

        times = [t_min if t < 0 else t for t in times]
        return times, sizes
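
    # Output shape (illustrative): `times` is a list of microsecond timestamps
    # and each row of `sizes` holds cumulative bytes per category, one slot per
    # entry of _CATEGORY_TO_INDEX plus a leading zero used as the stack base.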

    def export_memory_timeline(self, path, device_str) -> None:
        """Saves the memory timeline as [times, sizes by category]
        as a JSON formatted file to the given path for the given
        device."""
        times, sizes = self._coalesce_timeline(device_str)
        # TODO: Write a faster serialize (orjson not available in CI)
        import json

        with open(path, "w") as f:
            json.dump([times, sizes], f)
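
    # Note (point of reference, not part of this module): the public wrapper
    # `torch.profiler.profile.export_memory_timeline` chooses among these
    # exporters based on the output path's suffix, so most users never call
    # this class directly.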

    def export_memory_timeline_raw(self, path, device_str) -> None:
        """Saves the memory timeline as raw memory event tuples in the
        form of (timestamp, action, numbytes, category)
        as a JSON formatted file to the given path for the given
        device."""
        device = torch.device(device_str)
        raw_events: List[Tuple[int, int, int, int]] = []

        def get_category_index(key, version):
            category = (
                self.categories.get(key, version)
                if isinstance(key, TensorKey)
                else None
            )
            return _CATEGORY_TO_INDEX[category]

        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            if action in (Action.PREEXISTING, Action.CREATE):
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        numbytes,
                        get_category_index(key, version),
                    )
                )

            elif action == Action.INCREMENT_VERSION:
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        -numbytes,
                        get_category_index(key, version),
                    )
                )
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        numbytes,
                        get_category_index(key, version + 1),
                    )
                )

            elif action == Action.DESTROY:
                raw_events.append(
                    (
                        t,
                        _ACTION_TO_INDEX[action],
                        -numbytes,
                        get_category_index(key, version),
                    )
                )

            else:
                raise ValueError(f"Unknown action: {action}")

        import json

        with open(path, "w") as f:
            json.dump(raw_events, f)

    def export_memory_timeline_html(
        self, path, device_str, figsize=(20, 12), title=None
    ) -> None:
        """Exports the memory timeline as an HTML file which contains
        the memory timeline plot embedded as a PNG file."""
        # Check if user has matplotlib installed, return gracefully if not.
        import importlib.util

        matplotlib_spec = importlib.util.find_spec("matplotlib")
        if matplotlib_spec is None:
            print(
                "export_memory_timeline_html failed because matplotlib was not found."
            )
            return

        from base64 import b64encode
        from os import remove
        from tempfile import NamedTemporaryFile

        import matplotlib.pyplot as plt
        import numpy as np

        mt = self._coalesce_timeline(device_str)
        times, sizes = np.array(mt[0]), np.array(mt[1])
        # For this timeline, start at 0 to match Chrome traces.
        t_min = min(times)
        times -= t_min
        stacked = np.cumsum(sizes, axis=1) / 1024**3
        device = torch.device(device_str)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)

        # Plot memory timeline as stacked data
        fig = plt.figure(figsize=figsize, dpi=80)
        axes = fig.gca()
        for category, color in _CATEGORY_TO_COLORS.items():
            i = _CATEGORY_TO_INDEX[category]
            axes.fill_between(
                times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7
            )
        fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS])
        # Usually training steps are in magnitude of ms.
        axes.set_xlabel("Time (ms)")
        axes.set_ylabel("Memory (GB)")
        title = "\n\n".join(
            ([title] if title else [])
            + [
                f"Max memory allocated: {max_memory_allocated/(1024**3):.2f} GiB \n"
                f"Max memory reserved: {max_memory_reserved/(1024**3):.2f} GiB"
            ]
        )
        axes.set_title(title)

        # Embed the memory timeline image into the HTML file
        tmpfile = NamedTemporaryFile("wb", suffix=".png", delete=False)
        tmpfile.close()
        fig.savefig(tmpfile.name, format="png")

        with open(tmpfile.name, "rb") as tmp:
            encoded = b64encode(tmp.read()).decode("utf-8")
            html = f"""<html>
<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head>
<body>
  <img src='data:image/png;base64,{encoded}'>
</body>
</html>"""

            with open(path, "w") as f:
                f.write(html)
        remove(tmpfile.name)