# pytorch / _memory_profiler.py — 1202 lines · 47.0 KB (fork count: 0)
1
import collections
2
import dataclasses
3
import enum
4
import itertools as it
5
import logging
6
from typing import (
7
    Any,
8
    cast,
9
    DefaultDict,
10
    Dict,
11
    Iterator,
12
    List,
13
    Optional,
14
    Set,
15
    Tuple,
16
    Union,
17
)
18

19
from typing_extensions import Literal
20

21
import torch
22
from torch._C import FunctionSchema
23
from torch._C._autograd import _ProfilerResult
24
from torch._C._profiler import (
25
    _EventType,
26
    _ExtraFields_Allocation,
27
    _ExtraFields_TorchOp,
28
    _ProfilerEvent,
29
    _TensorMetadata,
30
    RecordScope,
31
)
32
from torch._utils import _element_size
33
from torch.profiler import _utils
34

35
# A `Key` (or `TensorKey`) paired with an integer such as a tensor version.
KeyAndID = Tuple["Key", int]
TensorAndID = Tuple["TensorKey", int]

log = logging.getLogger(__name__)
39

40

41
class Category(enum.Enum):
    """Semantic role that the memory profiler assigns to a tensor."""

    INPUT = enum.auto()
    TEMPORARY = enum.auto()
    ACTIVATION = enum.auto()
    GRADIENT = enum.auto()
    AUTOGRAD_DETAIL = enum.auto()
    PARAMETER = enum.auto()
    OPTIMIZER_STATE = enum.auto()


# Color name associated with each category. The `None` entry covers tensors
# for which no category could be determined.
_CATEGORY_TO_COLORS = {
    Category.PARAMETER: "darkgreen",
    Category.OPTIMIZER_STATE: "goldenrod",
    Category.INPUT: "black",
    Category.TEMPORARY: "mediumpurple",
    Category.ACTIVATION: "red",
    Category.GRADIENT: "mediumblue",
    Category.AUTOGRAD_DETAIL: "royalblue",
    None: "grey",
}

# 0-based position of each category (including `None`) in the insertion order
# of `_CATEGORY_TO_COLORS`.
_CATEGORY_TO_INDEX = {c: i for i, c in enumerate(_CATEGORY_TO_COLORS)}
63

64

65
class Action(enum.Enum):
    """Kind of change recorded for an allocation in the memory timeline."""

    PREEXISTING = enum.auto()
    CREATE = enum.auto()
    INCREMENT_VERSION = enum.auto()
    DESTROY = enum.auto()


# Maps each action to its enum value. Because `enum.auto()` numbers from 1,
# these indices are 1-based (unlike `_CATEGORY_TO_INDEX`, which is 0-based).
_ACTION_TO_INDEX = {i: i.value for i in Action}
73

74

75
@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
class Key:
    """Immutable, hashable identifier that carries only a device.

    Used directly for allocations that cannot be tied to a tensor;
    `TensorKey` extends it with tensor and storage identity.
    """

    device: torch.device
78

79

80
@dataclasses.dataclass
class _Storage:
    """Bundle storage pointer and id.

    All profiling logic should use `allocation_id`, however it is useful to
    print storage pointers for debugging and unit tests sometimes look up
    values using the storage data pointer of a live Tensor.

    Equality and hashing deliberately consider only `allocation_id`; `ptr`
    is informational.
    """

    ptr: int
    allocation_id: int

    def __repr__(self) -> str:
        return f"{hex(self.ptr):>18} ({self.allocation_id})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _Storage) and self.allocation_id == other.allocation_id

    # Defined explicitly so the dataclass decorator (eq=True, not frozen)
    # keeps the object hashable, consistent with `__eq__`.
    def __hash__(self) -> int:
        return hash(self.allocation_id)
99

100

101
@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
class TensorKey(Key):
    """Hashable identifier for a storage which has been assigned an ID.

    A detailed description of Tensor IDs and why they are needed is given in
    `torch/csrc/profiler/collection.h` when `TensorID` is declared. To
    summarize, multiple Storage buffers can map to the same logical Tensor.
    This dataclass is used to refer to a concrete in-memory StorageImpl of
    a Tensor.
    """

    id: int
    storage: _Storage

    def __repr__(self) -> str:
        return f"id={self.id}: {repr(self.storage):<24} ({self.device})"

    def __lt__(self, other: "TensorKey") -> bool:
        return self._as_sortable < other._as_sortable

    @staticmethod
    def _make(
        tensor_id: Optional[int],
        storage_ptr: Optional[int],
        allocation_id: Optional[int],
        device: torch.device,
    ) -> Optional["TensorKey"]:
        """Build a key, or return None if any identifying field is missing."""
        if (
            tensor_id is not None
            and storage_ptr is not None
            and allocation_id is not None
        ):
            return TensorKey(device, tensor_id, _Storage(storage_ptr, allocation_id))
        return None

    @classmethod
    def from_allocation(cls, alloc: _ExtraFields_Allocation) -> Optional["TensorKey"]:
        """Key for the storage of an allocation/free event (None if unidentified)."""
        return cls._make(alloc.id, alloc.ptr, alloc.allocation_id, alloc.device)

    @classmethod
    def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
        """Key for the storage backing `t`; None if `t` is None or lacks IDs."""
        if t is not None:
            return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
        return None

    @property
    def _as_sortable(self) -> Tuple[int, int, str, int]:
        # `torch.device.index` is None when no index was given (e.g. "cpu").
        # Map it to -1 so that ordering never compares None with an int,
        # which would raise TypeError.
        index = self.device.index
        return (
            self.id,
            self.storage.allocation_id,
            self.device.type,
            -1 if index is None else index,
        )
149

150

151
def _extract_parameters_and_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], Optional[TensorKey]]]:
    """Yield ``(parameter, gradient)`` key pairs that can be inferred from `node`.

    Either element may be None when it is not known from this event or when
    `TensorKey.from_tensor` cannot build a key for it.
    """
    children = node.children

    # AccumulateGrad is used in the Autograd engine to handle gradient updates.
    # There are two possible cases:
    # 1) This is a newly created gradient Tensor. In that case there is nothing
    #    to accumulate, so autograd simply detaches the Tensor.
    #
    # 2) There is a preexisting gradient Tensor and we need to add the newly
    #    computed update. This is done with an in-place add (aten::add_) op.
    #    (The underscore suffix denotes "in-place".)
    if (
        node.typed[0] == _EventType.TorchOp
        and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
        # TODO(robieta): Move away from load bearing names
        and node.name == "torch::autograd::AccumulateGrad"
        and children
        and children[0].typed[0] == _EventType.TorchOp
        and children[0].name in ("aten::detach", "aten::add_")
        and children[0].typed[1].inputs
        and isinstance(children[0].typed[1].inputs[0], _TensorMetadata)
    ):
        # Only the gradient is observable here; the parameter is not recorded.
        yield None, TensorKey.from_tensor(children[0].typed[1].inputs[0])

    # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer`
    # NOTE: The values captured by the python tracer are cached; they can be
    #       used to build up labels but do not imply that a Tensor was live at
    #       a particular time.
    elif node.typed[0] == _EventType.PyCall:
        typed_fields = node.typed[1]
        assert typed_fields.module is None or typed_fields.optimizer is None
        if typed_fields.module is not None:
            for _, p, p_grad in typed_fields.module.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)

        if typed_fields.optimizer is not None:
            for p, p_grad, _ in typed_fields.optimizer.parameters:
                yield TensorKey.from_tensor(p), TensorKey.from_tensor(p_grad)
191

192

193
def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]:
    """Yield the key of every parameter that can be identified from `node`."""
    for p, _ in _extract_parameters_and_gradients(node):
        if p is not None:
            yield p
197

198

199
def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    """Yield ``(parameter, gradient)`` pairs from `node` whose gradient is known.

    The parameter key may still be None, e.g. for AccumulateGrad events, which
    only reveal the gradient.
    """
    for p, p_grad in _extract_parameters_and_gradients(node):
        if p_grad is not None:
            yield p, p_grad
205

206

207
def get_scopes(event: Optional[_ProfilerEvent]) -> Tuple[RecordScope, ...]:
    """Return the RecordScope of `event` and each ancestor, innermost first.

    Only TorchOp events contribute a scope; other events in the parent chain
    are skipped. Returns an empty tuple when `event` is None.
    """
    scopes: List[RecordScope] = []
    while event:
        if event.typed[0] == _EventType.TorchOp:
            scopes.append(event.typed[1].scope)
        event = event.parent
    return tuple(scopes)
214

215

216
class SchemaMatcher:
    """Lookup operator schema based on profiled name.

    When profiling we record the operator's name but not the schema. However
    some analysis requires that information. Fortunately we can look up
    registered schema from the recorded name. We do not, however, record the
    overload and so we must compare the profiled arguments with all overloads
    to determine viable matches.

    Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed
    this code will be obsolete.
    """

    @classmethod
    def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[Optional[bool], ...]:
        """Determine which inputs may have mutated based on function schema.

        Note that we don't need to resolve down to a single schema to perform
        this analysis. An input is mutable if it is mutable in any overload. In
        practice, however, it is overwhelmingly common to match a single
        overload. If we cannot find any valid schema, every entry of the
        result is None ("unknown") and callers must be conservative.
        """
        mutable: Optional[List[bool]] = None
        for schema in cls.match_schemas(t):
            # Every matching schema has exactly one argument per recorded
            # input (enforced by `match_schemas`), so one list can accumulate
            # results across overloads.
            mutable = mutable or [False for _ in schema.arguments]
            for i, arg in enumerate(schema.arguments):
                # `alias_info` may be None, hence the getattr default.
                mutable[i] |= getattr(arg.alias_info, "is_write", False)

        return tuple(mutable or (None for _ in t.inputs))

    @classmethod
    def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]:
        """Return every registered overload of `t.name` consistent with the
        recorded inputs of `t`."""

        def observed(i):
            # Tensor
            if isinstance(i, _TensorMetadata):
                return TensorKey.from_tensor(i)
            # TensorList
            if isinstance(i, list):
                return [TensorKey.from_tensor(j) for j in i]
            # Scalar and uncaptured inputs.
            return i

        signature = tuple(observed(i) for i in t.inputs)

        def matches(schema) -> bool:
            return len(schema.arguments) == len(signature) and all(
                cls._types_match(observed_arg, schema_arg.type)
                for observed_arg, schema_arg in zip(signature, schema.arguments)
            )

        return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))

    @classmethod
    def _types_match(cls, observed, schema_type) -> bool:
        """Return True if a recorded value is compatible with a JIT schema type."""
        if isinstance(schema_type, torch._C.OptionalType):
            schema_type = schema_type.getElementType()
            return observed is None or cls._types_match(observed, schema_type)

        if isinstance(schema_type, torch._C.AnyType):
            return True

        if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()):
            return isinstance(observed, list) and all(
                isinstance(i, TensorKey) for i in observed
            )

        type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = (
            (torch._C.TensorType, TensorKey),
            (torch._C.NoneType, type(None)),
            (torch._C.BoolType, bool),
            (torch._C.IntType, int),
            (torch._C.FloatType, float),
            (torch._C.ComplexType, complex),
            (torch._C.NumberType, (bool, int, float, complex)),
        )

        for jit_type, py_types in type_map:
            if isinstance(schema_type, jit_type):
                return isinstance(observed, py_types)

        # Profiler only records a subset of possible argument types. If we
        # reach this point then the schema must call for a type that profiler
        # does not record. Thus, the schema can only be a match if `observed`
        # is also None.
        return observed is None

    @staticmethod
    def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]:
        """Return every registered schema for operator `name`.

        Returns None when `name` cannot be an operator name. Note that
        record_function annotations also go through this path, so it is
        expected that some names will not correspond to PyTorch operators.
        """
        # TODO(robieta):
        #   _jit_get_schemas_for_operator is quite expensive. (~100us / call)
        #   Consider adding `functools.lru_cache` if that becomes an issue.

        # Schemas must be namespaced, so a name without "::" is rejected
        # without paying for a lookup.
        if "::" not in name:
            return None

        try:
            return tuple(torch._C._jit_get_schemas_for_operator(name))
        except RuntimeError:
            # Schema lookup throws if `name` is malformed; treat that as
            # "not an operator name".
            return None
323

324

325
class OpTree:
    """Wrapper around a profiler result's event tree.

    Provides depth-first traversal over all root events and a tuple of every
    event sorted by `start_time_ns`, computed once at construction.
    """

    def __init__(self, result: _ProfilerResult) -> None:
        self._root_nodes = result.experimental_event_tree()
        self._sorted_nodes = tuple(sorted(self.dfs(), key=lambda x: x.start_time_ns))

    def dfs(self, *args, **kwargs) -> Iterator[_ProfilerEvent]:
        """Depth-first traversal of the tree.

        Extra arguments (e.g. `children_fn`) are forwarded to
        `torch.profiler._utils.traverse_dfs`.
        """
        yield from _utils.traverse_dfs(self._root_nodes, *args, **kwargs)

    @property
    def sorted_nodes(self) -> Tuple[_ProfilerEvent, ...]:
        """All events in the tree, ordered by `start_time_ns`."""
        return self._sorted_nodes
336

337

338
class SizeMap:
    """Size in bytes of every identified tensor storage in an `OpTree`.

    Sizes come from two sources. The first is tensor metadata: TorchOp
    inputs, module parameters and gradients, and optimizer parameters,
    gradients and state. The second is allocation events. When both are
    present, the allocation size wins.
    """

    def __init__(self, op_tree: OpTree) -> None:
        self._values: Dict[TensorKey, int] = {}

        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.TorchOp:
                for t in self._flat_tensor_inputs(node.typed[1]):
                    self._update_values(t)

            elif node.typed[0] == _EventType.PyCall:
                typed_fields = node.typed[1]
                assert typed_fields.module is None or typed_fields.optimizer is None
                if typed_fields.module is not None:
                    for _, p, p_grad in typed_fields.module.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)

                if typed_fields.optimizer is not None:
                    for p, p_grad, state in typed_fields.optimizer.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)
                        for _, t in state:
                            self._update_values(t)

        allocations: Dict[TensorKey, int] = {}
        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.Allocation:
                alloc_fields = node.typed[1]
                key = TensorKey.from_allocation(alloc_fields)
                if key is not None:
                    # Frees are recorded with a negative size; compare magnitudes.
                    new_size = abs(alloc_fields.alloc_size)
                    prior_size = allocations.setdefault(key, new_size)

                    # It is possible to resize Storage in PyTorch, however we
                    # key on data pointer so most resizes will be treated as a
                    # change in storage. The one corner case that cannot be
                    # handled is `realloc` which successfully resizes the
                    # storage. At time of writing this is not done anywhere in
                    # the core PyTorch codebase.
                    if prior_size != new_size:
                        log.warning(
                            "Mismatch between allocation and free: %s vs. %s",
                            prior_size,
                            new_size,
                        )

        # Allocation sizes override the metadata-derived estimates.
        self._values.update(allocations)

    def _update_values(self, t: Optional[_TensorMetadata]) -> None:
        """Record the byte size implied by `t`'s metadata, keeping the max seen.

        Tensors without a key, and non-strided layouts, are ignored.
        """
        key = TensorKey.from_tensor(t)
        if key is not None and t is not None and t.layout == torch.strided:
            # Element extent is estimated as the largest size * stride over all
            # dimensions. Scalars are represented as zero dim Tensors, hence
            # the `[1]` fallbacks.
            n = max(
                size * stride
                for size, stride in zip(t.sizes or [1], t.strides or [1])
            )

            num_bytes = n * _element_size(t.dtype)
            assert num_bytes >= 0, f"{num_bytes}"
            self._values[key] = max(self._values.get(key, 0), num_bytes)

    @staticmethod
    def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
        """Yield tensor inputs of `op`, flattening TensorList arguments."""
        for i in op.inputs:
            if isinstance(i, _TensorMetadata):
                yield i
            elif isinstance(i, list):
                yield from i

    def __getitem__(self, key: TensorKey) -> int:
        """Return the size in bytes for `key`; raises KeyError if unknown."""
        return self._values[key]
403

404

405
@dataclasses.dataclass()
class DataFlowEdge:
    """How a single tensor participates in one `DataFlowNode`.

    `input_version` is the tensor version read by the node, or None if the
    node allocates the tensor. `mutated` is a bool for op inputs and None
    when the node frees the tensor.
    """

    input_version: Optional[int] = None
    mutated: Optional[bool] = False

    @property
    def is_allocation(self) -> bool:
        """True if the tensor is created by the node (no version is read)."""
        return self.input_version is None

    @property
    def is_deletion(self) -> bool:
        """True if the tensor is freed by the node."""
        return self.mutated is None
417

418

419
class DataFlowNode:
    """One logical op (or bare allocation) in the data flow graph.

    On construction the node computes its tensor edges from the event's
    subtree. It then bumps the graph's version of every mutated,
    non-allocated input and asserts that the versions it reports as outputs
    agree with the graph.
    """

    def __init__(self, event: _ProfilerEvent, graph: "DataFlowGraph") -> None:
        self._event = event
        self._graph = graph
        self._edges: Dict[TensorKey, DataFlowEdge] = self._determine_edges()

        for key, edge in self._edges.items():
            if edge.mutated and not edge.is_allocation:
                self._graph.bump(key)

        # Make sure the version bumping behavior matches what we expect.
        versions = {k: (v, self._graph.lookup(k)) for k, v in self.outputs.items()}
        assert all(i == j for i, j in versions.values()), f"{versions}, {self._edges}"

    def _determine_edges(self) -> Dict[TensorKey, DataFlowEdge]:
        """Compute tensor edges for every event in this node's subtree.

        Three passes run in a fixed order: op inputs, then frees, then
        allocations. Allocations go last because the first two passes
        optimistically create input edges.
        """
        subtree = tuple(_utils.traverse_dfs([self._event]))

        # Start by populating edges from op inputs.
        mutable_by_key: Dict[Optional[TensorKey], Set[Optional[bool]]] = {}
        for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
            for op_input, mutable in zip(
                op.inputs, SchemaMatcher.inputs_are_mutable(op)
            ):
                # Tensor
                if isinstance(op_input, _TensorMetadata):
                    key = TensorKey.from_tensor(op_input)
                    mutable_by_key.setdefault(key, set()).add(mutable)

                # TensorList
                elif isinstance(op_input, list):
                    for op_input_i in op_input:
                        key = TensorKey.from_tensor(op_input_i)
                        mutable_by_key.setdefault(key, set()).add(mutable)

        edges: DefaultDict[Optional[TensorKey], DataFlowEdge]
        edges = collections.defaultdict(DataFlowEdge)
        for key, mutable_set in mutable_by_key.items():
            if key is not None:
                edges[key].input_version = self._graph.lookup(key)

                # We consider an op to be mutated if we encounter a schema where it
                # is a mutable argument OR if it is ambiguous. (We never explicitly
                # see it in any schema.)
                mutated = (True in mutable_set) or (tuple(mutable_set) == (None,))
                edges[key].mutated = mutated

        # Then handle deletions. Note that deleting a Tensor implicitly adds
        # it as an input edge.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size < 0:
                key = TensorKey.from_allocation(i.typed[1])
                edge = edges[key]
                assert key is None or edge.mutated is not None, f"Double delete: {key}"
                edge.mutated = None
                # Frees without a key land under `None` and are dropped
                # below; -1 is only a placeholder for them.
                edge.input_version = self._graph.lookup(key) if key else -1

        # And finally handle allocations. This step must be last, because the
        # previous two steps optimistically add input edges.
        for i in subtree:
            if i.typed[0] == _EventType.Allocation and i.typed[1].alloc_size > 0:
                edges[TensorKey.from_allocation(i.typed[1])].input_version = None

        # We don't need to sort the inputs, but it makes debugging and unit
        # tests nicer. Sort on the key alone so `DataFlowEdge` values (which
        # define no ordering) are never compared.
        return dict(
            sorted(
                ((k, v) for k, v in edges.items() if k is not None),
                key=lambda kv: kv[0],
            )
        )

    @property
    def inputs(self) -> Dict[TensorKey, Tuple[bool, int]]:
        """Tensors read by this node, mapped to ``(mutated, version read)``.

        Frees count as inputs (with `mutated` reported as False); allocations
        do not.
        """
        return {
            # MyPy can't see through `is_allocation` to know that
            # `v.input_version` is not None.
            k: (bool(v.mutated), cast(int, v.input_version))
            for k, v in self._edges.items()
            if not v.is_allocation
        }

    @property
    def outputs(self) -> Dict[TensorKey, int]:
        """Tensor versions produced by this node.

        Surviving allocations produce version 0; mutated inputs produce
        ``input_version + 1``.
        """
        return {
            k: 0 if v.input_version is None else v.input_version + 1
            for k, v in self._edges.items()
            if (v.is_allocation and not v.is_deletion) or v.mutated
        }

    @property
    def intermediates(self) -> Tuple[TensorKey, ...]:
        """Tensors that are both allocated and freed within this node."""
        return tuple(
            k for k, v in self._edges.items() if v.is_allocation and v.is_deletion
        )

    @property
    def start_time(self) -> int:
        """Start time (ns) of the underlying profiler event."""
        return self._event.start_time_ns
511

512

513
class DataFlowGraph:
    """Data flow graph over the top-level ops and allocations of an `OpTree`.

    `_active_version` tracks the current version of each tensor. A node bumps
    the version when it mutates the tensor, and `delete` sets it to None.
    """

    def __init__(self, op_tree: OpTree) -> None:
        self._op_tree = op_tree
        self._leaf_events = self._extract_leaf_events(op_tree)
        self._active_version: Dict[TensorKey, Optional[int]] = {}
        # Leaf events are already in start-time order. That matters because
        # each DataFlowNode reads and bumps `_active_version` while it is
        # being constructed.
        self._flow_nodes = [DataFlowNode(e, self) for e in self.leaf_events]
        self._flow_nodes.sort(key=lambda x: x.start_time)
        self.validate()

    @property
    def flow_nodes(self) -> Tuple[DataFlowNode, ...]:
        """Graph nodes ordered by start time."""
        return tuple(self._flow_nodes)

    def validate(self):
        """Assert graph invariants (raises AssertionError on violation)."""
        # Check that each (Tensor, version) pair has a unique creation node
        outputs: Set[Tuple[TensorKey, int]] = set()
        for node in self.flow_nodes:
            node_outputs = set(node.outputs.items())
            duplicates = outputs & node_outputs
            assert not duplicates, f"{node._event.name} {node._edges} {duplicates}"
            outputs |= node_outputs

        # And check that `self._nodes` forms a valid topologically sorted DAG.
        tensor_versions: Dict[TensorKey, int] = {}
        for node in self.flow_nodes:
            for key, (_, version) in node.inputs.items():
                expected = tensor_versions.get(key, 0)
                assert expected == version, (expected, version)

            for key, version in node.outputs.items():
                prior_version = tensor_versions.get(key, version)
                assert version >= prior_version, (version, prior_version)
                tensor_versions[key] = version

    @property
    def leaf_events(self) -> Tuple[_ProfilerEvent, ...]:
        """Events selected as graph nodes, ordered by start time."""
        return self._leaf_events

    @staticmethod
    def _extract_leaf_events(op_tree: OpTree) -> Tuple[_ProfilerEvent, ...]:
        """Partially traverse the op tree and extract top level ops.

        Consider the following code:
        ```
        with record_function("My annotation"):
            x.zero_()
            y.zero_()
        ```

        The op tree (assuming no Autograd) will look like:
          <Python context>
            TorchOp: "My annotation"
              TorchOp: zero_
                TorchOp: fill_
              TorchOp: zero_
                TorchOp: fill_

        The recursive structure of operator calls makes data flow unwieldy.
        In order to simplify analysis we would like to select the highest level
        ops to represent in the graph. In this case those are the `zero_` ops;
        the fact that `fill_` is called is an implementation detail. We also
        do not want to group everything under "My annotation" as this could
        create overly coarse bundles and lose critical semantics.

        To address this issue we walk over the graph and select the topmost
        torch ops ** which match at least one operator schema ** (or which are
        backward functions). These form the leaves of the first pass through
        the op tree. (As well as any allocations or frees which are not part
        of a kernel.) These events form the logical nodes in our data flow
        graph.
        """

        leaf_events: List[_ProfilerEvent] = []

        def leaf_op(e: _ProfilerEvent) -> bool:
            return e.typed[0] == _EventType.TorchOp and (
                e.typed[1].scope == RecordScope.BACKWARD_FUNCTION
                or bool(SchemaMatcher.match_schemas(e.typed[1]))
            )

        def children_fn(e: _ProfilerEvent):
            # Record leaves and stop descending into them.
            if leaf_op(e) or e.tag == _EventType.Allocation:
                leaf_events.append(e)
                return []

            return e.children

        # Exhaust the traversal; leaves are collected by `children_fn`.
        for _ in op_tree.dfs(children_fn=children_fn):
            pass

        return tuple(sorted(leaf_events, key=lambda x: x.start_time_ns))

    def lookup(self, key: TensorKey) -> int:
        """Return the current version of `key`, registering it at 0 if unseen.

        Asserts that `key` has not been deleted.
        """
        version = self._active_version.setdefault(key, 0)
        assert version is not None
        return version

    def bump(self, key: TensorKey) -> None:
        """Increment the version of `key`, which must be registered and live."""
        prior_version = self._active_version.get(key, None)
        assert prior_version is not None
        self._active_version[key] = prior_version + 1

    def delete(self, key: TensorKey) -> None:
        """Mark `key` as deleted; asserts it was not already deleted."""
        # Keep the side effect outside `assert` so it is not stripped by -O.
        prior_version = self._active_version.setdefault(key, 0)
        assert prior_version is not None
        self._active_version[key] = None
617

618

619
@dataclasses.dataclass
class CategoryElement:
    """Category assignments for all storages that share one tensor ID.

    `CategoryDict.get` consults `by_id`, then `by_key`, then `by_version`.
    """

    # Category applied to every storage and version with this tensor ID.
    by_id: Optional[Category] = None
    # Category for a specific storage (TensorKey), regardless of version.
    by_key: Dict[TensorKey, Category] = dataclasses.field(default_factory=dict)
    # Category for a specific (TensorKey, version) pair.
    by_version: Dict[TensorAndID, Category] = dataclasses.field(default_factory=dict)

    # Used by unit tests to check internals. (And consequently by
    # MemoryProfile.lookup) This should not be used in any other capacity.
    _by_id_keyset: Set[TensorKey] = dataclasses.field(default_factory=set)
628

629

630
@dataclasses.dataclass
class CategoryDict:
    """Category assignments for tensors, grouped by tensor ID."""

    _values: DefaultDict[int, CategoryElement] = dataclasses.field(
        default_factory=lambda: collections.defaultdict(CategoryElement)
    )

    def set_by_id(self, key: TensorKey, category: Category) -> None:
        """Assign `category` to every storage and version sharing `key.id`."""
        self._values[key.id].by_id = category
        self._values[key.id]._by_id_keyset.add(key)

    def set_by_key(self, key: TensorKey, category: Category) -> None:
        """Assign `category` to every version of the storage `key`."""
        self._values[key.id].by_key[key] = category

    def set_by_version(self, key: TensorKey, version: int, category: Category) -> None:
        """Assign `category` to one (key, version) pair, overwriting any prior value."""
        self._values[key.id].by_version[(key, version)] = category

    def setdefault_by_version(
        self, key: TensorKey, version: int, category: Category
    ) -> None:
        """Assign `category` to (key, version) only if nothing is set yet."""
        self._values[key.id].by_version.setdefault((key, version), category)

    def get(self, key: Key, version: int) -> Optional[Category]:
        """Return the category of `key` at `version`, or None if unknown.

        Plain `Key`s (allocations not tied to a tensor) are never categorized.
        Precedence: by ID, then by key, then by (key, version).
        """
        if isinstance(key, Key) and not isinstance(key, TensorKey):
            return None
        # Non-inserting lookup: indexing the defaultdict would create an empty
        # CategoryElement for every unknown tensor ID as a side effect of a read.
        element = self._values.get(key.id)
        if element is None:
            return None
        return (
            element.by_id
            or element.by_key.get(key, None)
            or element.by_version.get((key, version), None)
        )
660

661

662
class MemoryProfile:
    """Assign a memory `Category` to each Tensor version seen while profiling.

    The profiler result is converted into an `OpTree`, a `DataFlowGraph`, and
    a `SizeMap`. The `_set_*` passes run by `__init__` then populate
    `self._categories`.
    """

    def __init__(self, result: _ProfilerResult) -> None:
        self._op_tree = OpTree(result)
        self._data_flow_graph = DataFlowGraph(self._op_tree)
        self._size_map = SizeMap(self._op_tree)
        self._categories = CategoryDict()

        # Pass order matters: later passes read categories assigned by earlier
        # ones (e.g. `_set_inputs` keys off GRADIENT / PARAMETER), and the
        # `setdefault_by_version` calls never overwrite an existing
        # by-version assignment.
        self._set_gradients_and_temporaries()
        self._set_parameters_using_python_tracer()
        self._set_inputs()
        self._set_parameters_using_data_flow()
        self._set_activations()
        self._set_optimizer_state()
        self._set_autograd_detail()

    @property
    def timeline(self) -> Tuple[Tuple[int, Action, KeyAndID, int], ...]:
        """Return memory events as (time_ns, action, (key, version), nbytes).

        Allocations that map to a `TensorKey` are resolved through the data
        flow graph. Those that do not are reported under a plain device `Key`
        with version 0. Memory already live before profiling is reported with
        a time of -1. The result is sorted by time, then by `Action` value.
        """
        output: List[Tuple[int, Action, KeyAndID, int]] = []
        # (tensor key, is_allocation) -> time of that alloc / free event.
        allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
        # (ptr, device) of currently-live allocations without a TensorKey.
        live_unknown: Dict[Tuple[int, torch.device], Literal[True]] = {}
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.Allocation:
                alloc_fields = event.typed[1]
                alloc_size = alloc_fields.alloc_size
                # A non-positive `alloc_size` marks a free; its negation is
                # used as the size of the freed block below.
                is_allocation = alloc_size > 0
                t = event.start_time_ns

                tkey = TensorKey.from_allocation(alloc_fields)
                if tkey is not None:
                    allocation_times[(tkey, is_allocation)] = t

                else:
                    key = Key(alloc_fields.device)
                    ptr_and_device = (alloc_fields.ptr, key.device)
                    if is_allocation:
                        if ptr_and_device in live_unknown:
                            output.append(
                                (t, Action.INCREMENT_VERSION, (key, 0), alloc_size)
                            )
                        else:
                            live_unknown[ptr_and_device] = True
                            output.append((t, Action.CREATE, (key, 0), alloc_size))
                    else:
                        output.append((t, Action.DESTROY, (key, 0), -alloc_size))
                        # Freeing a block whose allocation we never saw means
                        # it was already live when profiling started.
                        if not live_unknown.pop(ptr_and_device, False):
                            output.append(
                                (-1, Action.PREEXISTING, (key, 0), -alloc_size)
                            )

        snapshot = self._category_snapshot()
        # Sorting (key, version) pairs means the last entry kept per key is
        # its highest version.
        last_version = dict(sorted(snapshot.keys()))

        # Version 0 of a Tensor whose allocation was never observed predates
        # profiling.
        events: List[Tuple[int, Action, TensorAndID]] = [
            (-1, Action.PREEXISTING, (key, version))
            for key, version in snapshot.keys()
            if (key, True) not in allocation_times and version == 0
        ]

        for node in self._data_flow_graph.flow_nodes:
            for key, edge in node._edges.items():
                if edge.is_allocation:
                    t = allocation_times[(key, True)]
                    events.append((t, Action.CREATE, (key, 0)))

                elif edge.mutated:
                    t = node._event.start_time_ns
                    version = edge.input_version
                    assert version is not None
                    events.append((t, Action.INCREMENT_VERSION, (key, version)))

                if edge.is_deletion:
                    t = allocation_times[(key, False)]
                    events.append((t, Action.DESTROY, (key, last_version[key])))

        output.extend(
            (time, action, (key, version), self._size_map[key])
            for time, action, (key, version) in events
        )

        output.sort(key=lambda x: (x[0], x[1].value))
        return tuple(output)

    def _is_gradient(self, *args, **kwargs) -> bool:
        """Return True if `self._categories.get(*args, **kwargs)` is GRADIENT."""
        return self._categories.get(*args, **kwargs) == Category.GRADIENT

    def _category_snapshot(self) -> Dict[TensorAndID, Optional[Category]]:
        """Map every known (TensorKey, version) pair to its current category.

        Known pairs are node inputs, node intermediates (version 0), node
        outputs, and version 0 of every key registered through `set_by_id`.
        Entries are inserted in sorted (key, version) order.
        """
        all_tensor_versions: Set[TensorAndID] = set()

        for node in self._data_flow_graph.flow_nodes:
            all_tensor_versions.update((k, v) for k, (_, v) in node.inputs.items())
            all_tensor_versions.update((key, 0) for key in node.intermediates)
            all_tensor_versions.update(node.outputs.items())

        for element in self._categories._values.values():
            all_tensor_versions.update((key, 0) for key in element._by_id_keyset)

        return {
            (key, version): self._categories.get(key, version)
            for key, version in sorted(all_tensor_versions)
        }

    def _any_version_depends_on_gradient(self) -> Set[int]:
        """Extract IDs of Tensors which depend or will depend on a gradient.

        Note that this weakened definition of "depends" requires us to loop
        over the data flow graph multiple times because it allows dependency
        information to flow backward through edges and removes the guarantee
        that nodes are topologically sorted. (Or indeed, even that a valid
        topological order exists.) Put another way, we have converted an
        acyclic data flow graph into a cyclic graph and we are attempting to
        partition cycles involving a gradient from the rest of the graph.
        """
        depends_on_gradient: Set[int] = set()
        while True:
            start_size = len(depends_on_gradient)
            for node in self._data_flow_graph.flow_nodes:
                ids = tuple(
                    key.id
                    for key, (_, version) in node.inputs.items()
                    if self._categories.get(key, version)
                    in (Category.GRADIENT, Category.PARAMETER)
                    or key.id in depends_on_gradient
                )

                if ids:
                    depends_on_gradient.update(ids)
                    depends_on_gradient.update(key.id for key in node.outputs)

            # We are guaranteed to exit because the set only ever grows and
            # there is a finite set of Tensor IDs. In practice we do not
            # expect to loop more than three times: once to identify the core
            # parameter update loop, once to fold the first step into that
            # loop, and a third time where no new elements are added.
            if len(depends_on_gradient) == start_size:
                return depends_on_gradient

    def _set_gradients_and_temporaries(self) -> None:
        """Mark Tensors which are unambiguous and simple to reason about."""

        # Gradients are straightforward to detect. We directly check the
        # `.grad` property in the Python tracer, and we can detect any new
        # gradient Tensors from `AccumulateGrad` ops.
        for event in self._op_tree.dfs():
            for _, p_grad in extract_gradients(event):
                self._categories.set_by_id(p_grad, Category.GRADIENT)

        # Similarly, temporary Tensors are easy to identify and are useful to
        # flag since they can make memory use "spikier" than one would
        # otherwise expect.
        for node in self._data_flow_graph.flow_nodes:
            for i in node.intermediates:
                self._categories.set_by_key(i, Category.TEMPORARY)

    def _set_parameters_using_python_tracer(self) -> None:
        """Mark every Tensor reported by `extract_parameters` as a PARAMETER."""
        for event in self._op_tree.dfs():
            for p in extract_parameters(event):
                if p is not None:
                    self._categories.set_by_id(p, Category.PARAMETER)

    def _set_inputs(self) -> None:
        """Mark inputs based on which Tensors are updated using gradients.

        The process for differentiating between inputs and activations is more
        involved. Most Tensors in a training loop depend on at least one
        gradient: parameters depend on them through updates, and activations
        and optimizer state depend on them transitively through parameters.
        Critically, we do not need to know which Tensors are parameters to
        apply this method; we can simply walk the data flow graph to build the
        set of all values which depend on a gradient and then obtain the set
        of inputs from the complement of that set.

        There is, however, one hiccup. The first time we see a parameter is
        generally on the forward pass of the first step. We know from
        inspection of the data flow graph that v1 of that Tensor depends on
        a gradient (provided we profile an optimizer step), but not v0. To
        address this problem we weaken the definition of "depends on a
        gradient" to "any version of this Tensor depends on a gradient",
        which in turn strengthens the criteria for the input set enough to
        filter the activations in the forward pass of the first step.
        """

        # All of this analysis is predicated on using at least one training
        # step (or parameters from the python tracer) to partition the graph.
        # Absent that we cannot determine which Tensors are inputs and which
        # ones are part of the model.
        depends_on_gradient = self._any_version_depends_on_gradient()

        # We only want to annotate Tensors which actually contribute to the
        # model calculation. Iterating in reverse lets membership discovered
        # at later nodes propagate to earlier nodes sharing a (key, version).
        produces_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            tensors = {(key, version) for key, (_, version) in node.inputs.items()}
            tensors |= node.outputs.items()
            if any(
                self._categories.get(*i) in (Category.GRADIENT, Category.PARAMETER)
                or i in produces_gradient
                for i in tensors
            ):
                produces_gradient |= tensors

        # Don't include Tensors created in the backward pass, as these are
        # generally Autograd implementation details rather than proper inputs.
        input_candidates = produces_gradient.copy()
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                input_candidates -= set(node.outputs.items())

        for key, version in input_candidates:
            if key.id not in depends_on_gradient:
                self._categories.setdefault_by_version(key, version, Category.INPUT)

    def _set_parameters_using_data_flow(self) -> None:
        """Deduce which Tensors are parameters.

        Consider the following code for the step of SGD with momentum
        (nesterov=False), where `d_p` is the gradient of `param` and `buf` is
        the momentum buffer.
        ```
          buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
          d_p = buf
          param.add_(d_p, alpha=-lr)
        ```
        Both `param` and `buf` take a gradient and perform an in-place update.

        The python tracer will inspect calls to `nn.Module.forward` and
        `optim.Optimizer.step` to extract parameter and optimizer state
        respectively (including parameters), so this is generally a non-issue.

        However as a fallback we can also exploit several properties of
        parameters to distinguish them from other model state.

        First, they are directly used in the forward pass. (At this point we
        haven't established which parts of the graph correspond to the forward
        pass but we can deduce enough to suffice.) Some mutable state such as
        batch norm moving averages also contribute to the forward pass, but
        optimizer state does not.

        Second, a parameter is by definition used to compute at least one
        gradient and depends on at least one gradient.
        """
        snapshot = self._category_snapshot()

        # Determine which Tensors might be parameters based on forward pass
        # data flow. Note this these are only candidates; we filter nodes that
        # we know are part of the backward pass but that doesn't guarantee that
        # they are part of the forward pass.
        candidate_parameters: Set[TensorAndID] = set()
        candidate_fwd_tensors: Set[TensorAndID] = {
            i for i, category in snapshot.items() if category == Category.INPUT
        }

        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, version) for key, (_, version) in node.inputs.items()}
            if (
                # Don't check nodes in the backward pass.
                RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
                and not any(self._is_gradient(*i) for i in inputs)
                and not any(self._is_gradient(*i) for i in node.outputs.items())
                #
                # and only check nodes which depend on an input.
                and candidate_fwd_tensors.intersection(inputs)
            ):
                candidate_fwd_tensors |= node.outputs.items()
                candidate_parameters |= inputs.difference(candidate_fwd_tensors)

        # Require that each parameter eventually contributes to the value of a gradient
        used_for_gradient: Set[TensorAndID] = set()
        for node in reversed(self._data_flow_graph.flow_nodes):
            if any(
                self._is_gradient(*i) or i in used_for_gradient
                for i in node.outputs.items()
            ):
                for key, (_, version) in node.inputs.items():
                    used_for_gradient.add((key, version))
        candidate_parameters.intersection_update(used_for_gradient)

        # and depends on a gradient.
        parameter_keys = {key.id for key, _ in candidate_parameters}
        parameter_keys &= self._any_version_depends_on_gradient()

        for key, _ in snapshot.keys():
            if key.id in parameter_keys:
                self._categories.set_by_id(key, Category.PARAMETER)

    def _set_activations(self) -> None:
        """Flood the graph to identify activations.

        A node's outputs become ACTIVATION (unless already assigned by version)
        when at least one input is an INPUT or ACTIVATION, every input falls in
        {INPUT, ACTIVATION, PARAMETER, TEMPORARY}, and the node is not in the
        backward pass.
        """

        required = {Category.INPUT, Category.ACTIVATION}
        also_allowed = {Category.PARAMETER, Category.TEMPORARY}
        for node in self._data_flow_graph.flow_nodes:
            inputs = {(key, version) for key, (_, version) in node.inputs.items()}
            input_categories = {self._categories.get(*i) for i in inputs}

            if (
                (input_categories & required)
                and not (input_categories - (required | also_allowed))
                #
                # Stop filling when we reach the backward pass.
                and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
            ):
                for i in node.outputs.items():
                    self._categories.setdefault_by_version(*i, Category.ACTIVATION)

    def _set_optimizer_state(self) -> None:
        """Mark Tensors held in optimizer state (seen by the Python tracer)."""
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.PyCall and event.typed[1].optimizer:
                parameters = event.typed[1].optimizer.parameters
                # The last element of each entry is an iterable of pairs whose
                # second element is a state Tensor; flatten them lazily rather
                # than materializing an intermediate list.
                for _, t in it.chain.from_iterable(
                    state for _, _, state in parameters
                ):
                    key = TensorKey.from_tensor(t)
                    if key is not None:
                        self._categories.set_by_id(key, Category.OPTIMIZER_STATE)

    def _set_autograd_detail(self) -> None:
        """Mark backward-pass outputs as AUTOGRAD_DETAIL.

        An output version is marked when it is version 0, or when its previous
        version is uncategorized or already AUTOGRAD_DETAIL. Existing by-version
        assignments are not overwritten.
        """
        prior = {None, Category.AUTOGRAD_DETAIL}
        for node in self._data_flow_graph.flow_nodes:
            if RecordScope.BACKWARD_FUNCTION in get_scopes(node._event):
                for key, version in node.outputs.items():
                    if version == 0 or self._categories.get(key, version - 1) in prior:
                        self._categories.setdefault_by_version(
                            key, version, Category.AUTOGRAD_DETAIL
                        )
982

983

984
class MemoryProfileTimeline:
    def __init__(self, memory_profile):
        """The minimum representation of the memory profile timeline
        includes the memory timeline and categories. The timeline
        consists of [timestamp, action, (TensorKey, version), numbytes]
        elements, to denote any actions (pre-existing, create, destroy,
        or increment_version) that occurred to a specific Tensor for a
        chunk of memory. The categories help map each (TensorKey,
        version) pair into a category."""
        self.timeline = memory_profile.timeline
        self.categories = memory_profile._categories

    def _category_index(self, key, version):
        """Return the `_CATEGORY_TO_INDEX` slot for (key, version).

        Keys that are not `TensorKey`s (allocations the profiler could not
        attribute to a Tensor) map to the `None` ("Unknown") category.
        """
        category = (
            self.categories.get(key, version) if isinstance(key, TensorKey) else None
        )
        return _CATEGORY_TO_INDEX[category]

    @staticmethod
    def _size_deltas(action, version, numbytes):
        """Return the (version, byte delta) pairs implied by one timeline event.

        Raises:
            ValueError: if `action` is not a known `Action`.
        """
        if action in (Action.PREEXISTING, Action.CREATE):
            return ((version, numbytes),)
        if action == Action.INCREMENT_VERSION:
            # The bytes move from the old version's category to the new one's.
            return ((version, -numbytes), (version + 1, numbytes))
        if action == Action.DESTROY:
            return ((version, -numbytes),)
        raise ValueError(f"Unknown action: {action}")

    def _coalesce_timeline(self, device_str):
        """Convert the memory timeline and categories into a memory plot
        consisting of timestamps and their respective sizes by category
        for a given device.

        Input: device
        Output: [timestamps, sizes by category]

        Each row of sizes has one leading zero column followed by one column
        per entry of `_CATEGORY_TO_INDEX`. Timestamps are in microseconds;
        pre-existing (-1) timestamps are replaced by the smallest positive
        timestamp seen.
        """
        device = torch.device(device_str)
        times: List[int] = []
        sizes: List[List[int]] = []

        t_min = -1
        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            # Convert timestamps from ns to us, to match trace events. Integer
            # division keeps full precision for large ns values, which
            # `int(t / 1000)` would round through a float (doubles only hold
            # integers exactly up to 2**53).
            if t != -1:
                t = int(t) // 1000

            # Save the smallest timestamp to populate pre-existing allocs.
            if t_min == -1 or (t < t_min and t > 0):
                t_min = t

            # Handle timestep
            if not times:
                times.append(t)
                sizes.append([0] + [0 for _ in _CATEGORY_TO_INDEX])

            elif t != times[-1]:
                times.append(t)
                sizes.append(sizes[-1].copy())

            # Handle memory and categories
            for v, delta in self._size_deltas(action, version, numbytes):
                sizes[-1][self._category_index(key, v) + 1] += int(delta)

        times = [t_min if t < 0 else t for t in times]
        return times, sizes

    def export_memory_timeline(self, path, device_str) -> None:
        """Saves the memory timeline as [times, sizes by category]
        as a JSON formatted file to the given path for the given
        device."""
        times, sizes = self._coalesce_timeline(device_str)
        # TODO: Write a faster serialize (orjson not available in CI)
        import json

        with open(path, "w") as f:
            json.dump([times, sizes], f)

    def export_memory_timeline_raw(self, path, device_str) -> None:
        """Saves the memory timeline as raw memory event tuples in the
        form of (timestamp, action, numbytes, category)
        as a JSON formatted file to the given path for the given
        device.

        An INCREMENT_VERSION event produces two tuples: one removing the
        bytes from the old version's category and one adding them to the
        new version's category.
        """
        device = torch.device(device_str)
        raw_events: List[Tuple[int, int, int, int]] = []

        for t, action, (key, version), numbytes in self.timeline:
            if key.device != device:
                continue

            deltas = self._size_deltas(action, version, numbytes)
            action_index = _ACTION_TO_INDEX[action]
            raw_events.extend(
                (t, action_index, delta, self._category_index(key, v))
                for v, delta in deltas
            )

        import json

        with open(path, "w") as f:
            json.dump(raw_events, f)

    def export_memory_timeline_html(
        self, path, device_str, figsize=(20, 12), title=None
    ) -> None:
        """Exports the memory timeline as an HTML file which contains
        the memory timeline plot embedded as a PNG file.

        Peak allocated/reserved statistics are added to the title only for
        CUDA devices, since `torch.cuda.max_memory_*` rejects other devices.
        """
        # Check if user has matplotlib installed, return gracefully if not.
        import importlib.util

        matplotlib_spec = importlib.util.find_spec("matplotlib")
        if matplotlib_spec is None:
            print(
                "export_memory_timeline_html failed because matplotlib was not found."
            )
            return

        import io
        from base64 import b64encode

        import matplotlib.pyplot as plt
        import numpy as np

        mt = self._coalesce_timeline(device_str)
        times, sizes = np.array(mt[0]), np.array(mt[1])
        # For this timeline, start at 0 to match Chrome traces.
        t_min = min(times)
        times -= t_min
        stacked = np.cumsum(sizes, axis=1) / 1024**3
        device = torch.device(device_str)

        stats = []
        if device.type == "cuda":
            max_memory_allocated = torch.cuda.max_memory_allocated(device)
            max_memory_reserved = torch.cuda.max_memory_reserved(device)
            stats.append(
                f"Max memory allocated: {max_memory_allocated/(1024**3):.2f} GiB \n"
                f"Max memory reserved: {max_memory_reserved/(1024**3):.2f} GiB"
            )

        # Plot memory timeline as stacked data
        fig = plt.figure(figsize=figsize, dpi=80)
        try:
            axes = fig.gca()
            for category, color in _CATEGORY_TO_COLORS.items():
                i = _CATEGORY_TO_INDEX[category]
                axes.fill_between(
                    times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7
                )
            fig.legend(["Unknown" if i is None else i.name for i in _CATEGORY_TO_COLORS])
            # Usually training steps are in magnitude of ms.
            axes.set_xlabel("Time (ms)")
            axes.set_ylabel("Memory (GB)")
            axes.set_title("\n\n".join(([title] if title else []) + stats))

            # Render the PNG into memory; unlike a temporary file, nothing is
            # left on disk if rendering or encoding fails.
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            # Release the figure so repeated exports don't accumulate open
            # pyplot figures.
            plt.close(fig)

        # Embed the memory timeline image into the HTML file
        encoded = b64encode(buffer.getvalue()).decode("utf-8")
        html = f"""<html>
<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head>
<body>
  <img src='data:image/png;base64,{encoded}'>
</body>
</html>"""

        with open(path, "w") as f:
            f.write(html)
1203

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.