# Source: pytorch / profiler_util.py (forks: 0; 1100 lines, 38.8 KB)
1
# mypy: allow-untyped-defs
2
import bisect
3
import itertools
4
import math
5
from collections import defaultdict, namedtuple
6
from operator import attrgetter
7
from typing import Any, Dict, List, Optional, Tuple
8
from typing_extensions import deprecated
9

10
import torch
11
from torch.autograd import DeviceType
12

13

14
# Public names exported by this module.
__all__ = [
    "EventList",
    "FormattedTimesMixin",
    "Interval",
    "Kernel",
    "FunctionEvent",
    "FunctionEventAvg",
    "StringTable",
    "MemRecordsAcc",
]
24

25

26
class EventList(list):
    """A list of Events (for pretty printing).

    Besides the list contents, an instance remembers the options it was
    created with (``use_device``, ``profile_memory``, ``with_flops``) and
    whether the CPU parent/child tree has been built (``_tree_built``).
    """

    def __init__(self, *args, **kwargs):
        """Create the list.

        The keyword arguments ``use_device`` (default ``None``),
        ``profile_memory`` (default ``False``) and ``with_flops`` (default
        ``False``) are consumed here; everything else is forwarded to
        ``list.__init__``.
        """
        use_device = kwargs.pop("use_device", None)
        profile_memory = kwargs.pop("profile_memory", False)
        with_flops = kwargs.pop("with_flops", False)
        super().__init__(*args, **kwargs)
        self._use_device = use_device
        self._profile_memory = profile_memory
        self._tree_built = False
        self._with_flops = with_flops

    def _build_tree(self):
        """Link events into a CPU call tree and mark the tree as built."""
        self._populate_cpu_children()
        self._remove_dup_nodes()
        self._set_backward_stacktraces()
        self._tree_built = True

    def __str__(self):
        return self.table()

    def _remove_dup_nodes(self):
        """Collapse an event into its parent when both describe the same call.

        An event is removed when its CPU parent has the same name and the
        event is that parent's only child. The parent then takes over the
        event's children and kernels. Passes are repeated until a pass
        removes nothing.
        """
        while True:
            to_delete = set()
            for idx, evt in enumerate(self):
                # Read the parent at iteration time: an earlier removal in
                # this same pass may have re-parented this event.
                parent = evt.cpu_parent
                if (
                    parent is not None
                    and parent.name == evt.name
                    and len(parent.cpu_children) == 1
                ):
                    parent.cpu_children = evt.cpu_children
                    parent.kernels = evt.kernels  # lift kernels up
                    for ch in evt.cpu_children:
                        ch.cpu_parent = parent
                    to_delete.add(idx)
            if not to_delete:
                break
            self[:] = [ev for ind, ev in enumerate(self) if ind not in to_delete]

    def _populate_cpu_children(self):
        """Populate child events into each underlying FunctionEvent object.

        One event is a child of another if [s1, e1) is inside [s2, e2), where
        s1 and e1 are the start and end of the child event's interval and
        s2 and e2 are the start and end of the parent event's interval.

        Example: the event list [[0, 10], [1, 3], [3, 4]] makes [0, 10] the
        parent of the two other intervals.

        If for any reason two intervals intersect only partially, this
        function will not record a parent/child relationship between them.
        """
        # Some events can be async (i.e. start and end on different threads),
        # since it's generally undefined how to attribute children ranges to
        # async ranges, we do not use them when calculating nested ranges and stats
        sync_events = [
            evt
            for evt in self
            if not evt.is_async and evt.device_type == DeviceType.CPU
        ]
        # Group by both thread and node_id, so that events that happen to have
        # the same thread_id but are from different nodes aren't incorrectly
        # grouped together. itertools.groupby only merges *adjacent* items, so
        # the list must be sorted by the very same key; sorting by thread alone
        # could split one (thread, node_id) group into several pieces.
        group_key = attrgetter("thread", "node_id")
        events = sorted(sync_events, key=group_key)
        threads = itertools.groupby(events, key=group_key)

        # For each thread we keep a stack of current nested parents.
        # We maintain the invariant that each interval is a subset of all other
        # intervals lower in the stack.
        #
        # First we sort the intervals by their start time. Then we iterate over them.
        # Every time we see a new interval we remove several parents from
        # the top until we restore the invariant. Then the parent/child
        # relationship is recorded if the stack is not empty.
        # Finally we add the new interval to the stack.
        #
        # Algorithm has O(N * log(N)) complexity where N is number of
        # intervals
        for _, thread_events in threads:
            thread_events_ = sorted(
                thread_events,
                key=lambda event: [event.time_range.start, -event.time_range.end],
            )
            current_events: List[FunctionEvent] = []
            for event in thread_events_:
                while len(current_events) > 0:
                    parent = current_events[-1]
                    if (
                        event.time_range.start >= parent.time_range.end
                        or event.time_range.end > parent.time_range.end
                    ):
                        # this can't be a parent
                        current_events.pop()
                    else:
                        parent.append_cpu_child(event)
                        assert (
                            event.cpu_parent is None
                        ), f"There is already a CPU parent event for {event.key}"
                        event.set_cpu_parent(parent)
                        break

                current_events.append(event)

    def _set_backward_stacktraces(self):
        """Give backward-pass events the stack of the matching forward event.

        Forward stacks are indexed by ``(sequence_nr, thread)``. Every event
        that has an ancestor (or is itself) of scope 1 gets the stack stored
        under ``(sequence_nr, fwd_thread)`` of that ancestor, or an empty list
        when there is none.
        """

        def bw_parent(evt):
            # Walk up the CPU parents until a BACKWARD_FUNCTION (scope == 1)
            # event is found; iterative so deep trees cannot hit the
            # recursion limit.
            while evt is not None and evt.scope != 1:
                evt = evt.cpu_parent
            return evt

        fwd_stacks = {}
        for evt in self:
            if bw_parent(evt) is None and evt.stack is not None:
                # The first stack seen for a key wins.
                fwd_stacks.setdefault((evt.sequence_nr, evt.thread), evt.stack)

        for evt in self:
            p = bw_parent(evt)
            if p is not None:
                assert p.fwd_thread is not None
                evt.stack = fwd_stacks.get((p.sequence_nr, p.fwd_thread), [])

    @property
    def self_cpu_time_total(self):
        """Sum of ``self_cpu_time_total`` over all events in the list."""
        return sum(event.self_cpu_time_total for event in self)

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        """Print an EventList as a nicely formatted table.

        Args:
            sort_by (str, optional): Attribute used to sort entries. By default
                they are printed in the same order as they were registered.
                Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``,
                ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``,
                ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``,
                ``self_cpu_memory_usage``, ``self_cuda_memory_usage``,
                ``self_xpu_memory_usage``, ``count``.
            top_level_events_only(bool, optional): Boolean flag to determine the
                selection of events to display. If true, the profiler will only
                display events at top level like top-level invocation of python
                `lstm`, python `add` or other functions, nested events like low-level
                cpu/cuda/xpu ops events are omitted for profiler result readability.

        The remaining arguments (``row_limit``, the ``max_*_column_width``
        values and ``header``) are forwarded unchanged to ``_build_table``.

        Returns:
            A string containing the table.
        """
        return _build_table(
            self,
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            profile_memory=self._profile_memory,
            with_flops=self._with_flops,
            top_level_events_only=top_level_events_only,
        )

    def export_chrome_trace(self, path):
        """Export an EventList as a Chrome tracing tools file.

        The file can be later loaded and inspected under ``chrome://tracing`` URL.
        Events whose ``trace_name`` is ``None`` are skipped; if every event is
        skipped an empty JSON array is written.

        Args:
            path (str): Path where the trace will be written.
        """
        device_name = "cuda" if not self._use_device else self._use_device
        with open(path, "w") as f:
            next_id = 0
            # Written in front of every entry except the first one. This keeps
            # the output identical to "write a trailing comma, then trim it"
            # without seeking backwards, which failed when nothing was written.
            separator = ""
            # Use file IO over using json.dump since JSON dumping is very slow and
            # this technique is proven to give a 4x speedup.
            f.write("[")
            for evt in self:
                if evt.trace_name is None:
                    continue
                entry = (
                    '{{"name": "{}", '
                    '"ph": "X", '
                    '"ts": {}, '
                    '"dur": {}, '
                    '"tid": {}, '
                    '"pid": "CPU functions", '
                    '"args": {{}}}}'.format(
                        evt.trace_name,
                        evt.time_range.start,
                        evt.time_range.elapsed_us(),
                        evt.thread
                        if not evt.is_remote
                        else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
                    )
                )
                f.write(separator + entry)
                separator = ", "
                for _ in evt.kernels:
                    # 's' and 'f' draw Flow arrows from
                    # the CPU launch to the GPU kernel
                    f.write(
                        separator
                        + f'{{"name": "{evt.trace_name}", '
                        '"ph": "s", '
                        f'"ts": {evt.time_range.start}, '
                        f'"tid": {evt.thread}, '
                        '"pid": "CPU functions", '
                        f'"id": {next_id}, '
                        f'"cat": "cpu_to_{device_name}", '
                        '"args": {}}'
                    )
                    # Note: use torch.profiler to get device kernel trace
                    next_id += 1
            f.write("]")

    def supported_export_stacks_metrics(self):
        """Return the metric names accepted by :meth:`export_stacks`."""
        return [
            "self_cpu_time_total",
            "self_cuda_time_total",
            "self_xpu_time_total",
            "self_privateuse1_time_total",
        ]

    def export_stacks(self, path: str, metric: str):
        """Write one ``frame;frame;... value`` line per event with a stack.

        Stack entries are written outermost-last-reversed (``reversed(stack)``)
        with spaces, semicolons, tabs and newlines replaced by underscores.
        Events without a stack, or whose integer metric value is not positive,
        are skipped.

        Args:
            path (str): Output file path.
            metric (str): One of :meth:`supported_export_stacks_metrics`.

        Raises:
            ValueError: If ``metric`` is not supported.
        """
        if metric not in self.supported_export_stacks_metrics():
            raise ValueError(
                "metric should be one of: "
                + str(self.supported_export_stacks_metrics())
            )
        translate_table = str.maketrans(" ;\t\n", "____")
        # Device-specific metric names all map onto the generic "device"
        # attribute of the event; resolve the attribute name once.
        attr_name = (
            metric.replace("cuda", "device")
            .replace("xpu", "device")
            .replace("privateuse1", "device")
        )
        with open(path, "w") as f:
            for evt in self:
                if not evt.stack:
                    continue
                metric_value = getattr(evt, attr_name)
                if int(metric_value) > 0:
                    frames = ";".join(
                        entry.translate(translate_table)
                        for entry in reversed(evt.stack)
                    )
                    f.write(frames + " " + str(int(metric_value)) + "\n")

    def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0):
        """Averages all function events over their keys.

        Args:
            group_by_input_shapes: group entries by
                (event name, input shapes) rather than just event name.
                This is useful to see which input shapes contribute to the runtime
                the most and may help with size-specific optimizations or
                choosing the best candidates for quantization (aka fitting a roof line)

            group_by_stack_n: group by top n stack trace entries

        Returns:
            An EventList containing FunctionEventAvg objects.
        """
        assert self._tree_built
        stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)

        def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]:
            key = [
                str(event.key),
                str(event.node_id),
                str(event.device_type),
                str(event.is_legacy),
                str(event.is_user_annotation),
            ]
            if group_by_input_shapes:
                key.append(str(event.input_shapes))
            if group_by_stack_n > 0:
                key += event.stack[:group_by_stack_n]
            return tuple(key)

        for evt in self:
            stats[get_key(evt, group_by_input_shapes, group_by_stack_n)].add(evt)

        avg_list = EventList(
            stats.values(),
            use_device=self._use_device,
            profile_memory=self._profile_memory,
            with_flops=self._with_flops,
        )
        for evt in avg_list:
            # Keep only the part of the stack that was used for grouping.
            evt.stack = evt.stack[:group_by_stack_n]
            if not group_by_input_shapes:
                evt.input_shapes = ""
        return avg_list

    def total_average(self):
        """Averages all events.

        Returns:
            A FunctionEventAvg object.
        """
        total_stat = FunctionEventAvg()
        for evt in self:
            total_stat += evt
            # Reset the key so the next event passes the key-equality check
            # in FunctionEventAvg.add regardless of its own key.
            total_stat.key = None
        total_stat.key = "Total"
        return total_stat
358

359

360
def _format_time(time_us):
361
    """Define how to format time in FunctionEvent."""
362
    US_IN_SECOND = 1000.0 * 1000.0
363
    US_IN_MS = 1000.0
364
    if time_us >= US_IN_SECOND:
365
        return f"{time_us / US_IN_SECOND:.3f}s"
366
    if time_us >= US_IN_MS:
367
        return f"{time_us / US_IN_MS:.3f}ms"
368
    return f"{time_us:.3f}us"
369

370

371
def _format_time_share(time_us, total_time_us):
372
    """Define how to format time in FunctionEvent."""
373
    if total_time_us == 0:
374
        assert time_us == 0, f"Expected time_us == 0 but got {time_us}"
375
        return "NaN"
376
    return f"{time_us * 100.0 / total_time_us:.2f}%"
377

378

379
def _format_memory(nbytes):
380
    """Return a formatted memory size string."""
381
    KB = 1024
382
    MB = 1024 * KB
383
    GB = 1024 * MB
384
    if abs(nbytes) >= GB:
385
        return f"{nbytes * 1.0 / GB:.2f} Gb"
386
    elif abs(nbytes) >= MB:
387
        return f"{nbytes * 1.0 / MB:.2f} Mb"
388
    elif abs(nbytes) >= KB:
389
        return f"{nbytes * 1.0 / KB:.2f} Kb"
390
    else:
391
        return str(nbytes) + " b"
392

393

394
def _attr_formatter(name):
    """Return a read-only property that formats attribute ``name`` with ``_format_time``."""
    return property(lambda self: _format_time(getattr(self, name)))
396

397

398
class FormattedTimesMixin:
    """Helpers for FunctionEvent and FunctionEventAvg.

    The subclass should define `*_time_total` and `count` attributes.
    """

    # String-formatted views of the numeric timing attributes.
    cpu_time_str = _attr_formatter("cpu_time")
    device_time_str = _attr_formatter("device_time")
    cpu_time_total_str = _attr_formatter("cpu_time_total")
    device_time_total_str = _attr_formatter("device_time_total")
    self_cpu_time_total_str = _attr_formatter("self_cpu_time_total")
    self_device_time_total_str = _attr_formatter("self_device_time_total")

    @property
    def cpu_time(self):
        """Average CPU time per call (``cpu_time_total / count``), 0.0 when ``count`` is 0."""
        return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count  # type: ignore[attr-defined]

    @property
    def device_time(self):
        """Average device time per call (``device_time_total / count``), 0.0 when ``count`` is 0."""
        return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count  # type: ignore[attr-defined]

    @property
    @deprecated(
        "`cuda_time` is deprecated, please use `device_time` instead.",
        category=FutureWarning,
    )
    def cuda_time(self):  # To be deprecated
        return self.device_time
426

427

428
class Interval:
    """A ``[start, end]`` pair of time stamps."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def elapsed_us(self):
        r"""
        Returns the length of the interval
        """
        return self.end - self.start
438

439

440
# Record of one kernel attached to a CPU event: its name, device and duration.
Kernel = namedtuple("Kernel", ("name", "device", "duration"))
441

442

443
class FunctionEvent(FormattedTimesMixin):
    """Profiling information about a single function."""

    def __init__(
        self,
        id,
        name,
        thread,
        start_us,
        end_us,
        fwd_thread=None,
        input_shapes=None,
        stack=None,
        scope=0,
        use_device=None,
        cpu_memory_usage=0,
        device_memory_usage=0,
        is_async=False,
        is_remote=False,
        sequence_nr=-1,
        node_id=-1,
        device_type=DeviceType.CPU,
        device_index=0,
        device_resource_id=None,
        is_legacy=False,
        flops=None,
        trace_name=None,
        concrete_inputs=None,
        kwinputs=None,
        is_user_annotation=False,
    ):
        """Store the event's identity, timing and bookkeeping fields.

        ``start_us`` and ``end_us`` are wrapped in an :class:`Interval` kept
        as ``time_range``. ``device_resource_id`` falls back to ``thread``
        when it is not given. ``count`` starts at 1 and the kernel and CPU
        children lists start empty; all other arguments are stored under the
        attribute of the same name.
        """
        self.id: int = id
        self.node_id: int = node_id
        self.name: str = name
        self.trace_name: str = trace_name
        self.time_range: Interval = Interval(start_us, end_us)
        self.thread: int = thread
        self.fwd_thread: Optional[int] = fwd_thread
        self.kernels: List[Kernel] = []
        self.count: int = 1
        self.cpu_children: List[FunctionEvent] = []
        self.cpu_parent: Optional[FunctionEvent] = None
        self.input_shapes: Tuple[int, ...] = input_shapes
        self.concrete_inputs: List[Any] = concrete_inputs
        self.kwinputs: Dict[str, Any] = kwinputs
        self.stack: List = stack
        self.scope: int = scope
        self.use_device: Optional[str] = use_device
        self.cpu_memory_usage: int = cpu_memory_usage
        self.device_memory_usage: int = device_memory_usage
        self.is_async: bool = is_async
        self.is_remote: bool = is_remote
        self.sequence_nr: int = sequence_nr
        self.device_type: DeviceType = device_type
        self.device_index: int = device_index
        self.device_resource_id: int = (
            thread if device_resource_id is None else device_resource_id
        )
        self.is_legacy: bool = is_legacy
        self.flops: Optional[int] = flops
        self.is_user_annotation: Optional[bool] = is_user_annotation

    def append_kernel(self, name, device, duration):
        """Attach a :class:`Kernel` record to this (CPU) event."""
        assert self.device_type == DeviceType.CPU
        self.kernels.append(Kernel(name, device, duration))

    def append_cpu_child(self, child):
        """Append a CPU child of type FunctionEvent.

        One is supposed to append only direct children to the event to have
        correct self cpu time being reported.
        """
        assert self.device_type == DeviceType.CPU
        assert isinstance(child, FunctionEvent)
        assert child.device_type == DeviceType.CPU
        self.cpu_children.append(child)

    def set_cpu_parent(self, parent):
        """Set the immediate CPU parent of type FunctionEvent.

        One profiling FunctionEvent should have only one CPU parent such that
        the child's range interval is completely inside the parent's. We use
        this connection to determine the event is from top-level op or not.
        """
        assert self.device_type == DeviceType.CPU
        assert isinstance(parent, FunctionEvent)
        assert parent.device_type == DeviceType.CPU
        self.cpu_parent = parent

    # Note: async events don't have children, are not used when computing 'self'
    # metrics of other events, have only total cpu time
    @property
    def self_cpu_memory_usage(self):
        """CPU memory usage minus that of the direct CPU children (0 for async or non-CPU events)."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_memory_usage - sum(
            child.cpu_memory_usage for child in self.cpu_children
        )

    @property
    def self_device_memory_usage(self):
        """Device memory usage minus that of the direct CPU children (0 for async or non-CPU events)."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.device_memory_usage - sum(
            child.device_memory_usage for child in self.cpu_children
        )

    @property
    @deprecated(
        "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.",
        category=FutureWarning,
    )
    def self_cuda_memory_usage(self):  # To be deprecated
        return self.self_device_memory_usage

    @property
    def cpu_time_total(self):
        """Length of ``time_range`` for CPU events, 0 for every other device type."""
        if self.device_type == DeviceType.CPU:
            return self.time_range.elapsed_us()
        else:
            return 0

    @property
    def self_cpu_time_total(self):
        """CPU time minus the CPU time of the direct children (0 for async or non-CPU events)."""
        if self.is_async or self.device_type != DeviceType.CPU:
            return 0
        return self.cpu_time_total - sum(
            child.cpu_time_total for child in self.cpu_children
        )

    @property
    def device_time_total(self):
        """Device time attributed to this event.

        Returns 0 for async events or when ``use_device`` is unset. For CPU
        events it is the sum of the attached kernel durations, plus the
        children's device time unless the event is legacy. For device events
        it is the length of ``time_range``.
        """
        if self.is_async or not self.use_device:
            return 0
        if self.device_type == DeviceType.CPU:
            if not self.is_legacy:
                # account for the kernels in the children ops
                return sum(kinfo.duration for kinfo in self.kernels) + sum(
                    ch.device_time_total for ch in self.cpu_children
                )
            else:
                # each legacy cpu events has a single (fake) kernel
                return sum(kinfo.duration for kinfo in self.kernels)
        else:
            assert self.device_type in [
                DeviceType.CUDA,
                DeviceType.PrivateUse1,
                DeviceType.MTIA,
            ]
            return self.time_range.elapsed_us()

    @property
    @deprecated(
        "`cuda_time_total` is deprecated. Use `device_time_total` instead.",
        category=FutureWarning,
    )
    def cuda_time_total(self):  # To be deprecated
        return self.device_time_total

    @property
    def self_device_time_total(self):
        """Device time excluding the direct CPU children's device time.

        Returns 0 for async events or when ``use_device`` is unset; device
        events report their full ``device_time_total``.
        """
        if self.is_async or not self.use_device:
            return 0
        if self.device_type == DeviceType.CPU:
            return self.device_time_total - sum(
                child.device_time_total for child in self.cpu_children
            )
        else:
            assert self.device_type in [
                DeviceType.CUDA,
                DeviceType.PrivateUse1,
                DeviceType.MTIA,
            ]
            return self.device_time_total

    @property
    @deprecated(
        "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.",
        category=FutureWarning,
    )
    def self_cuda_time_total(self):  # To be deprecated
        return self.self_device_time_total

    @property
    def key(self):
        """Grouping key of the event; this is its ``name``."""
        return self.name

    def __repr__(self):
        device_name = self.use_device
        device_time = self.device_time_str
        device_memory_usage = self.device_memory_usage
        return (
            f"<FunctionEvent id={self.id} name={self.name} device_type={self.device_type} node_id={self.node_id} "
            f"cpu_time={self.cpu_time_str} start_us={self.time_range.start} end_us={self.time_range.end} "
            f"cpu_children={str([child.id for child in self.cpu_children])} {device_name}_time={device_time} "
            f"name={self.name} thread={self.thread} input_shapes={str(self.input_shapes)} "
            f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory_usage} "
            f"is_async={self.is_async} is_remote={self.is_remote} seq_nr={self.sequence_nr} is_legacy={self.is_legacy}>"
        )
642

643

644
class FunctionEventAvg(FormattedTimesMixin):
    """Used to average stats over multiple FunctionEvent objects."""

    def __init__(self) -> None:
        """Create an empty accumulator: zero counters and no key."""
        self.key: Optional[str] = None
        self.count: int = 0
        self.node_id: int = 0
        self.is_async: bool = False
        self.is_remote: bool = False
        self.use_device: Optional[str] = None
        self.cpu_time_total: int = 0
        self.device_time_total: int = 0
        self.self_cpu_time_total: int = 0
        self.self_device_time_total: int = 0
        self.input_shapes: Optional[List[List[int]]] = None
        self.stack: Optional[List] = None
        self.scope: Optional[int] = None
        self.cpu_memory_usage: int = 0
        self.device_memory_usage: int = 0
        self.self_cpu_memory_usage: int = 0
        self.self_device_memory_usage: int = 0
        self.cpu_children: Optional[List[FunctionEvent]] = None
        self.cpu_parent: Optional[FunctionEvent] = None
        self.device_type: DeviceType = DeviceType.CPU
        self.is_legacy: bool = False
        self.flops: int = 0
        # Defined up front so the attribute exists before the first add();
        # add() overwrites it with the first event's value.
        self.is_user_annotation: Optional[bool] = False

    def add(self, other):
        """Accumulate the totals of ``other`` into this object and return ``self``.

        When this object has no key yet, the descriptive fields (key, node id,
        flags, parent/children, shapes, stack, scope, device info) are copied
        from ``other`` first. ``other`` must be a FunctionEvent or
        FunctionEventAvg with the same key.
        """
        # Check the type before touching any attribute of ``other``.
        assert isinstance(other, (FunctionEvent, FunctionEventAvg))
        if self.key is None:
            # First function being recorded as part of FunctionEventAvg, propagate
            # fields.
            self.key = other.key
            self.node_id = other.node_id
            self.is_async = other.is_async
            self.is_remote = other.is_remote
            self.cpu_parent = other.cpu_parent
            self.cpu_children = other.cpu_children

            self.input_shapes = other.input_shapes
            self.stack = other.stack
            self.scope = other.scope
            self.device_type = other.device_type
            self.is_legacy = other.is_legacy
            self.use_device = other.use_device
            self.is_user_annotation = other.is_user_annotation

        assert other.key == self.key
        self.cpu_time_total += other.cpu_time_total
        self.device_time_total += other.device_time_total
        self.self_cpu_time_total += other.self_cpu_time_total
        self.self_device_time_total += other.self_device_time_total
        self.cpu_memory_usage += other.cpu_memory_usage
        self.device_memory_usage += other.device_memory_usage
        self.self_cpu_memory_usage += other.self_cpu_memory_usage
        self.self_device_memory_usage += other.self_device_memory_usage
        self.count += other.count
        # ``other.flops`` may be None (no FLOPs recorded); skip it in that case.
        if self.flops is None:
            self.flops = other.flops
        elif other.flops is not None:
            self.flops += other.flops
        return self

    def __iadd__(self, other):
        return self.add(other)

    def __repr__(self):
        device_name = "cuda" if not self.use_device else self.use_device
        self_device_time = self.self_device_time_total_str
        device_time = self.device_time_str
        device_memory = self.device_memory_usage
        return (
            f"<FunctionEventAvg key={self.key} self_cpu_time={self.self_cpu_time_total_str} cpu_time={self.cpu_time_str} "
            f" self_{device_name}_time={self_device_time} {device_name}_time={device_time} input_shapes={str(self.input_shapes)} "
            f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory}>"
        )
720

721

722
class StringTable(defaultdict):
    """Dictionary that fills in missing keys with their demangled form."""

    def __missing__(self, key):
        # manage cases like 't' (demangled to 'unsigned short') separately,
        # for now simply check the length to avoid unexpected results for
        # the short sequences
        demangled = torch._C._demangle(key) if len(key) > 1 else key
        # Store the result so later lookups of the same key skip this method.
        self[key] = demangled
        return demangled
729

730

731
class MemRecordsAcc:
    """Acceleration structure for accessing mem_records in interval."""

    def __init__(self, mem_records):
        """Index ``mem_records`` by the ``start_ns()`` of each record's first element."""
        self._mem_records = mem_records
        self._start_nses: List[int] = []
        self._indices: List[int] = []
        if len(mem_records) > 0:
            # Sorted (start_ns, original index) pairs, split into two
            # parallel sequences so the start times can be bisected.
            tmp = sorted((r[0].start_ns(), i) for i, r in enumerate(mem_records))
            self._start_nses, self._indices = zip(*tmp)  # type: ignore[assignment]

    def in_interval(self, start_us, end_us):
        r"""
        Return all records in the given interval
        To maintain backward compatibility, convert us to ns in function

        Both bounds are inclusive; records are yielded in order of their
        start time.
        """
        start_idx = bisect.bisect_left(self._start_nses, start_us * 1000)
        end_idx = bisect.bisect_right(self._start_nses, end_us * 1000)
        for i in range(start_idx, end_idx):
            yield self._mem_records[self._indices[i]]
751

752

753
def _filter_stack_entry(entry):
754
    filtered_entries = [
755
        ("autograd/__init__", "_make_grads"),
756
        ("autograd/__init__", "backward"),
757
        ("torch/tensor", "backward"),
758
        ("_internal/common_utils", "prof_callable"),
759
        ("_internal/common_utils", "prof_func_call"),
760
        ("_internal/common_utils", "prof_meth_call"),
761
    ]
762
    return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries)
763

764

765
MEMORY_EVENT_NAME = "[memory]"
766
OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]"
767

768

769
def _filter_name(name):
770
    # ignoring the following utility ops
771
    filtered_out_names = [
772
        MEMORY_EVENT_NAME,  # used only for the top-level memory events
773
        OUT_OF_MEMORY_EVENT_NAME,
774
        "profiler::_record_function_enter",
775
        "profiler::_record_function_enter_new",
776
        "profiler::_record_function_exit",
777
        "aten::is_leaf",
778
        "aten::output_nr",
779
        "aten::_version",
780
    ]
781
    return name in filtered_out_names
782

783

784
def _rewrite_name(name, with_wildcard=False):
    """Demangle and optionally rewrite the provided event name.

    Args:
        name: The event name to look up in a ``StringTable``.
        with_wildcard: Whether to replace certain numbered event names with a
            wildcard name to aggregate them together in the profiler table
            output. Currently names starting with ``ProfilerStep#`` become
            ``ProfilerStep*``.

    Returns:
        The demangled (and possibly rewritten) name.
    """
    string_table = StringTable()
    name = string_table[name]
    if with_wildcard and name.startswith("ProfilerStep#"):
        name = "ProfilerStep*"
    return name
795

796

797
def _build_table(
798
    events,
799
    sort_by=None,
800
    header=None,
801
    row_limit=100,
802
    max_src_column_width=75,
803
    max_name_column_width=55,
804
    max_shapes_column_width=80,
805
    with_flops=False,
806
    profile_memory=False,
807
    top_level_events_only=False,
808
):
809
    """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg)."""
810
    if len(events) == 0:
811
        return ""
812

813
    has_device_time = any(event.self_device_time_total > 0 for event in events)
814
    has_device_mem = any(event.self_device_memory_usage > 0 for event in events)
815
    use_device = events[0].use_device
816
    # Running on PrivateUse1 device with profiler but not enable
817
    # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage.
818
    # Here only need to check has_privateuse1_time if not use_device.
819
    if not use_device and has_device_time:
820
        raise RuntimeError("use_device is None, but there is device performance data.")
821

822
    has_input_shapes = any(
823
        (event.input_shapes is not None and len(event.input_shapes) > 0)
824
        for event in events
825
    )
826

827
    if sort_by is not None:
828
        events = EventList(
829
            sorted(
830
                events,
831
                key=lambda evt: getattr(
832
                    evt,
833
                    sort_by.replace("cuda", "device")
834
                    .replace("xpu", "device")
835
                    .replace("privateuse1", "device"),
836
                ),
837
                reverse=True,
838
            ),
839
            use_device=use_device,
840
            profile_memory=profile_memory,
841
            with_flops=with_flops,
842
        )
843

844
    name_column_width = max(len(evt.key) for evt in events) + 4
845
    if max_name_column_width is not None:
846
        name_column_width = min(name_column_width, max_name_column_width)
847

848
    shapes_column_width = max(len(str(evt.input_shapes)) for evt in events) + 4
849
    if max_shapes_column_width is not None:
850
        shapes_column_width = min(shapes_column_width, max_shapes_column_width)
851

852
    DEFAULT_COLUMN_WIDTH = 12
853
    flops_column_width = DEFAULT_COLUMN_WIDTH
854

855
    src_column_width = None
856
    stacks = []
857
    for evt in events:
858
        if evt.stack is not None and len(evt.stack) > 0:
859
            stacks.append(evt.stack)
860
    has_stack = len(stacks) > 0
861
    if has_stack:
862
        src_column_width = (
863
            max(max(len(entry) for entry in stack) for stack in stacks) + 4
864
        )
865
        if max_src_column_width is not None:
866
            src_column_width = min(src_column_width, max_src_column_width)
867

868
    headers = [
869
        "Name",
870
        "Self CPU %",
871
        "Self CPU",
872
        "CPU total %",
873
        "CPU total",
874
        "CPU time avg",
875
    ]
876
    device_name = use_device.upper() if use_device is not None else "None"
877
    if has_device_time:
878
        headers.extend(
879
            [
880
                f"Self {device_name}",
881
                f"Self {device_name} %",
882
                f"{device_name} total",
883
                f"{device_name} time avg",
884
            ]
885
        )
886
    if profile_memory:
887
        headers.extend(
888
            [
889
                "CPU Mem",
890
                "Self CPU Mem",
891
            ]
892
        )
893
        if use_device and has_device_mem:
894
            headers.extend(
895
                [
896
                    f"{device_name} Mem",
897
                    f"Self {device_name} Mem",
898
                ]
899
            )
900
    headers.append("# of Calls")
901
    # Only append Node ID if any event has a valid (>= 0) Node ID
902
    append_node_id = any(evt.node_id != -1 for evt in events)
903
    if append_node_id:
904
        headers.append("Node ID")
905

906
    # Have to use a list because nonlocal is Py3 only...
907
    SPACING_SIZE = 2
908
    row_format_lst = [""]
909
    header_sep_lst = [""]
910
    line_length_lst = [-SPACING_SIZE]
911

912
    def add_column(padding, text_dir=">"):
913
        row_format_lst[0] += (
914
            "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE)
915
        )
916
        header_sep_lst[0] += "-" * padding + (" " * SPACING_SIZE)
917
        line_length_lst[0] += padding + SPACING_SIZE
918

919
    def auto_scale_flops(flops):
920
        flop_headers = [
921
            "FLOPs",
922
            "KFLOPs",
923
            "MFLOPs",
924
            "GFLOPs",
925
            "TFLOPs",
926
            "PFLOPs",
927
        ]
928
        assert flops > 0
929
        log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
930
        assert log_flops >= 0 and log_flops < len(flop_headers)
931
        return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])
932

933
    add_column(name_column_width)
934
    for _ in headers[1:]:
935
        add_column(DEFAULT_COLUMN_WIDTH)
936

937
    if has_input_shapes:
938
        headers.append("Input Shapes")
939
        add_column(shapes_column_width)
940

941
    if has_stack:
942
        headers.append("Source Location")
943
        add_column(src_column_width, text_dir="<")
944

945
    if with_flops:
946
        # Auto-scaling of flops header
947
        raw_flops = []
948
        for evt in events:
949
            if evt.flops > 0:
950
                raw_flops.append(evt.flops)
951
        if len(raw_flops) != 0:
952
            (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
953
            headers.append(f"Total {flops_header}")
954
            add_column(flops_column_width)
955
        else:
956
            with_flops = False  # can't find any valid flops
957

958
    row_format = row_format_lst[0]
959
    header_sep = header_sep_lst[0]
960
    line_length = line_length_lst[0]
961
    add_column = None  # type: ignore[assignment]
962

963
    # Have to use a list because nonlocal is Py3 only...
964
    result = []
965

966
    def append(s):
967
        result.append(s)
968
        result.append("\n")  # Yes, newline after the end as well
969

970
    sum_self_cpu_time_total = 0
971
    sum_self_device_time_total = 0
972
    for evt in events:
973
        sum_self_cpu_time_total += evt.self_cpu_time_total
974
        if evt.device_type == DeviceType.CPU and evt.is_legacy:
975
            # in legacy profiler, kernel info is stored in cpu events
976
            sum_self_device_time_total += evt.self_device_time_total
977
        elif (
978
            evt.device_type
979
            in [
980
                DeviceType.CUDA,
981
                DeviceType.PrivateUse1,
982
                DeviceType.MTIA,
983
            ]
984
            and not evt.is_user_annotation
985
        ):
986
            # in kineto profiler, there're events with the correct device type (e.g. CUDA)
987
            sum_self_device_time_total += evt.self_device_time_total
988

989
    # Actual printing
990
    if header is not None:
991
        append("=" * line_length)
992
        append(header)
993
    if top_level_events_only:
994
        append("=" * line_length)
995
        append("This report only display top-level ops statistics")
996
    append(header_sep)
997
    append(row_format.format(*headers))
998

999
    append(header_sep)
1000

1001
    def trim_path(path, src_column_width):
1002
        if len(path) > src_column_width:
1003
            offset = len(path) - src_column_width
1004
            path = path[offset:]
1005
            if len(path) > 3:
1006
                path = "..." + path[3:]
1007
        return path
1008

1009
    event_limit = 0
1010
    for evt in events:
1011
        if event_limit == row_limit:
1012
            break
1013
        if top_level_events_only and evt.cpu_parent is not None:
1014
            continue
1015
        else:
1016
            event_limit += 1
1017
        name = evt.key
1018
        if max_name_column_width is not None and len(name) >= max_name_column_width - 3:
1019
            name = name[: (max_name_column_width - 3)] + "..."
1020
        row_values = [
1021
            name,
1022
            # Self CPU total %, 0 for async events.
1023
            _format_time_share(evt.self_cpu_time_total, sum_self_cpu_time_total),
1024
            evt.self_cpu_time_total_str,  # Self CPU total
1025
            # CPU total %, 0 for async events.
1026
            _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
1027
            if not evt.is_async
1028
            else 0,
1029
            evt.cpu_time_total_str,  # CPU total
1030
            evt.cpu_time_str,  # CPU time avg
1031
        ]
1032
        if has_device_time:
1033
            row_values.extend(
1034
                [
1035
                    evt.self_device_time_total_str,
1036
                    # device time total %
1037
                    _format_time_share(
1038
                        evt.self_device_time_total, sum_self_device_time_total
1039
                    ),
1040
                    evt.device_time_total_str,
1041
                    evt.device_time_str,  # device time avg
1042
                ]
1043
            )
1044
        if profile_memory:
1045
            row_values.extend(
1046
                [
1047
                    # CPU Mem Total
1048
                    _format_memory(evt.cpu_memory_usage),
1049
                    # Self CPU Mem Total
1050
                    _format_memory(evt.self_cpu_memory_usage),
1051
                ]
1052
            )
1053
            if use_device and has_device_mem:
1054
                row_values.extend(
1055
                    [
1056
                        # Device Mem Total
1057
                        _format_memory(evt.device_memory_usage),
1058
                        # Self Device Mem Total
1059
                        _format_memory(evt.self_device_memory_usage),
1060
                    ]
1061
                )
1062
        row_values.append(
1063
            evt.count,  # Number of calls
1064
        )
1065

1066
        if append_node_id:
1067
            row_values.append(evt.node_id)
1068
        if has_input_shapes:
1069
            row_values.append(str(evt.input_shapes)[:shapes_column_width])
1070
        if with_flops:
1071
            if evt.flops <= 0:
1072
                row_values.append("--")
1073
            else:
1074
                row_values.append(f"{evt.flops * flops_scale:8.3f}")  # type: ignore[possibly-undefined]
1075
        if has_stack:
1076
            src_field = ""
1077
            if len(evt.stack) > 0:
1078
                src_field = trim_path(evt.stack[0], src_column_width)
1079
            row_values.append(src_field)
1080
        append(row_format.format(*row_values))
1081

1082
        if has_stack:
1083
            empty_headers = [""] * (len(headers) - 1)
1084
            for entry in evt.stack[1:]:
1085
                append(
1086
                    row_format.format(
1087
                        *(empty_headers + [trim_path(entry, src_column_width)])
1088
                    )
1089
                )
1090
            empty_headers.append("")
1091
            append(row_format.format(*empty_headers))
1092

1093
    append(header_sep)
1094
    append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}")
1095
    if has_device_time:
1096
        append(
1097
            f"Self {use_device.upper() if use_device is not None else 'None'} "
1098
            f"time total: {_format_time(sum_self_device_time_total)}"
1099
        )
1100
    return "".join(result)
1101

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.