# mypy: allow-untyped-defs
2
from collections import defaultdict
3
from dataclasses import dataclass
4
from time import perf_counter_ns
5
from typing import Any, Dict, Iterable, List, Optional
6
from warnings import warn
7

8
import torch
9
import torch.cuda
10
from torch._C import _get_privateuse1_backend_name
11
from torch._C._profiler import _ExperimentalConfig
12
from torch.autograd import (
13
    _disable_profiler,
14
    _enable_profiler,
15
    _kineto_step,
16
    _prepare_profiler,
17
    _ProfilerResult,
18
    _supported_activities,
19
    _toggle_collection_dynamic,
20
    DeviceType,
21
    kineto_available,
22
    ProfilerActivity,
23
    ProfilerConfig,
24
    ProfilerState,
25
)
26
from torch.autograd.profiler_util import (
27
    _filter_name,
28
    _filter_stack_entry,
29
    _rewrite_name,
30
    EventList,
31
    FunctionEvent,
32
    MEMORY_EVENT_NAME,
33
    MemRecordsAcc,
34
    OUT_OF_MEMORY_EVENT_NAME,
35
)
36
from torch.futures import Future
37

38

39
__all__ = [
40
    "profile",
41
    "record_function",
42
    "emit_itt",
43
    "emit_nvtx",
44
    "load_nvprof",
45
    "EnforceUnique",
46
    "parse_nvprof_trace",
47
    "KinetoStepTracker",
48
    "EventList",
49
    "FunctionEvent",
50
    "MemRecordsAcc",
51
]
52

53
try:
54
    # Available in Python >= 3.2
55
    from contextlib import ContextDecorator as _ContextDecorator
56
except ImportError:
57
    import functools
58

59
    class _ContextDecorator:  # type: ignore[no-redef]
60
        def __enter__(self):
61
            raise NotImplementedError
62

63
        def __exit__(self, exc_type, exc_val, exc_tb):
64
            raise NotImplementedError
65

66
        def __call__(self, func):
67
            @functools.wraps(func)
68
            def wrapped(*args, **kwargs):
69
                with self:
70
                    return func(*args, **kwargs)
71

72
            return wrapped
73

74

75
# global python state - whether profiler is currently enabled
76
# useful for fast python checks to reduce latency
77
_is_profiler_enabled: bool = False
78

79

80
def _set_is_profiler_enabled(enable: bool):
81
    global _is_profiler_enabled
82
    _is_profiler_enabled = enable
83

84

85
def _run_on_profiler_start():
86
    _set_is_profiler_enabled(True)
87

88

89
def _run_on_profiler_stop():
90
    _set_is_profiler_enabled(False)
91

92

93
@dataclass
94
class _ProfilerStats:
95
    "Profiler timing and stats used by developers to catch issues/regressions"
96
    profiling_window_duration_sec: float = 0
97
    number_of_events: int = 0
98
    profiler_prepare_call_duration_us: int = 0
99
    profiler_enable_call_duration_us: int = 0
100
    profiler_disable_call_duration_us: int = 0
101
    parse_kineto_call_duration_us: int = 0
102
    function_events_build_tree_call_duration_us: int = 0
103

104

105
class profile:
106
    """Context manager that manages autograd profiler state and holds a summary of results.
107

108
    Under the hood it just records events of functions being executed in C++ and
109
    exposes those events to Python. You can wrap any code into it and it will
110
    only report runtime of PyTorch functions.
111
    Note: the profiler is thread-local and is automatically propagated into asynchronous tasks.
112

113
    Args:
114
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
115

116
        use_cuda (bool, optional): Enables timing of CUDA events as well
117
            using the cudaEvent API. (will be deprecated)
118

119
        use_device (str, optional): Enables timing of device events.
            Adds approximately 4us of overhead to each tensor operation when using CUDA.
            The valid device options are 'cuda', 'xpu', 'mtia' and 'privateuseone'.
122

123
        record_shapes (bool, optional): If shapes recording is set, information
124
            about input dimensions will be collected. This allows one to see which
125
            dimensions have been used under the hood and further group by them
126
            using prof.key_averages(group_by_input_shape=True). Please note that
127
            shape recording might skew your profiling data. It is recommended to
128
            use separate runs with and without shape recording to validate the timing.
129
            Most likely the skew will be negligible for the bottommost events (in the case
            of nested function calls). But for higher-level functions the total
            self CPU time might be artificially increased because of the shape
            collection.
133

134
        with_flops (bool, optional): If with_flops is set, the profiler will estimate
135
            the FLOPs (floating point operations) value using the operator's input shape.
136
            This allows one to estimate the hardware performance. Currently,
137
            this option only works for the matrix multiplication and 2D convolution operators.
138

139
        profile_memory (bool, optional): track tensor memory allocation/deallocation.
140

141
        with_stack (bool, optional): record source information (file and line number) for the ops.
142

143
        with_modules (bool): record module hierarchy (including function names)
144
            corresponding to the callstack of the op. For example, if module A's forward calls
            module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
149

150
        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.
151

152
        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
153
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
154

155
        experimental_config (_ExperimentalConfig): A set of experimental options
156
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
157

158
        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
159

160

161
    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead.
164

165
    .. warning::
        This context manager should not be called recursively, i.e. no nested
        instances are allowed.
168

169
    .. warning::
170
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
171
        one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
172
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
173
        please use ``use_device = None`` or ``num_workers = 0``.
174

175
    Example:
176
        >>> # xdoctest: +SKIP
177
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
178
        >>> x = torch.randn((1, 1), requires_grad=True)
179
        >>> with torch.autograd.profiler.profile() as prof:
180
        >>>     for _ in range(100):  # any normal python code, really!
181
        >>>         y = x ** 2
182
        >>>         y.backward()
183
        >>> # NOTE: some columns were removed for brevity
184
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
185
        -----------------------------------  ---------------  ---------------  ---------------
186
        Name                                 Self CPU total   CPU time avg     Number of Calls
187
        -----------------------------------  ---------------  ---------------  ---------------
188
        mul                                  32.048ms         32.048ms         200
189
        pow                                  27.041ms         27.041ms         200
190
        PowBackward0                         9.727ms          55.483ms         100
191
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
192
        torch::autograd::GraphRoot           691.816us        691.816us        100
193
        -----------------------------------  ---------------  ---------------  ---------------
194

195
    """
196

197
    def __init__(
198
        self,
199
        enabled=True,
200
        *,
201
        use_cuda=False,  # Deprecated
202
        use_device=None,
203
        record_shapes=False,
204
        with_flops=False,
205
        profile_memory=False,
206
        with_stack=False,
207
        with_modules=False,
208
        use_kineto=False,
209
        use_cpu=True,
210
        experimental_config=None,
211
        acc_events=False,
212
    ):
213
        self.enabled: bool = enabled
214
        if not self.enabled:
215
            return
216
        self.use_cuda = use_cuda
217
        if self.use_cuda:
218
            warn(
219
                "The attribute `use_cuda` will be deprecated soon, "
220
                "please use ``use_device = 'cuda'`` instead.",
221
                FutureWarning,
222
                stacklevel=2,
223
            )
224
            self.use_device: Optional[str] = "cuda"
225
        else:
226
            self.use_device = use_device
227
        # TODO Consider changing _function_events into data structure with size cap
228
        self._function_events: Optional[EventList] = None
229
        self._old_function_events: Optional[EventList] = None
230
        # Function event processing is done lazily
231
        self._needs_processing = False
232
        self.entered = False
233
        self.record_shapes = record_shapes
234
        self.with_flops = with_flops
235
        self.record_shapes |= self.with_flops
236
        self.profile_memory = profile_memory
237
        self.with_stack = with_stack
238
        self.with_modules = with_modules
239
        self.use_cpu = use_cpu
240
        self.acc_events = acc_events
241
        if experimental_config is None:
242
            experimental_config = _ExperimentalConfig()
243
        self.experimental_config = experimental_config
244
        self.kineto_results: Optional[_ProfilerResult] = None
245
        self.profiling_start_time_ns = 0
246
        self.profiling_end_time_ns = 0
247
        self._stats = _ProfilerStats()
248

249
        if not self.use_cpu:
250
            assert (
251
                use_kineto
252
            ), "Device-only events supported only with Kineto (use_kineto=True)"
253

254
        if self.use_device is not None:
255
            VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"]
256
            if _get_privateuse1_backend_name() != "privateuseone":
257
                VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
258
            if self.use_device not in VALID_DEVICE_OPTIONS:
259
                warn(f"The {self.use_device} is not a valid device option.")
260
                self.use_device = None
261

262
            if self.use_device == "cuda" and not torch.cuda.is_available():
263
                warn("CUDA is not available, disabling CUDA profiling")
264
                self.use_cuda = False
265
                self.use_device = None
266

267
            if self.use_device == "xpu" and not torch.xpu.is_available():
268
                warn("XPU is not available, disabling XPU profiling")
269
                self.use_device = None
270

271
        self.kineto_activities = set()
272
        if self.use_cpu:
273
            self.kineto_activities.add(ProfilerActivity.CPU)
274

275
        self.profiler_kind = ProfilerState.KINETO
276
        if self.use_device == "cuda":
277
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
278
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
279
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
280
            else:
281
                self.kineto_activities.add(ProfilerActivity.CUDA)
282
        elif self.use_device == "xpu":
283
            assert (
284
                use_kineto and ProfilerActivity.XPU in _supported_activities()
285
            ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
286
            self.kineto_activities.add(ProfilerActivity.XPU)
287
        elif self.use_device == "mtia":
288
            assert (
289
                use_kineto and ProfilerActivity.MTIA in _supported_activities()
290
            ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
291
            self.kineto_activities.add(ProfilerActivity.MTIA)
292
        elif self.use_device is not None and self.use_device != "privateuseone":
293
            if (
294
                not use_kineto
295
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
296
            ):
297
                assert (
298
                    self.use_cpu
299
                ), "Legacy custom backend profiling requires use_cpu=True"
300
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
301
            else:
302
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)
303

304
        assert (
305
            len(self.kineto_activities) > 0
306
        ), "No activities specified for the profiler"
307

308
    def config(self):
309
        return ProfilerConfig(
310
            self.profiler_kind,
311
            self.record_shapes,
312
            self.profile_memory,
313
            self.with_stack,
314
            self.with_flops,
315
            self.with_modules,
316
            self.experimental_config,
317
        )
318

319
    def __enter__(self):
320
        if not self.enabled:
321
            return
322
        if self.entered:
323
            raise RuntimeError("Profiler context manager is not reentrant")
324
        self._prepare_trace()
325
        self._start_trace()
326
        return self
327

328
    def _prepare_trace(self):
329
        self.entered = True
330
        t0 = perf_counter_ns()
331
        _prepare_profiler(self.config(), self.kineto_activities)
332
        t1 = perf_counter_ns()
333
        self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000)
334

335
    def _start_trace(self):
336
        self.entered = True
337
        _run_on_profiler_start()
338
        t0 = perf_counter_ns()
339
        _enable_profiler(self.config(), self.kineto_activities)
340
        t1 = perf_counter_ns()
341
        self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000)
342
        self.profiling_start_time_ns = t1
343

344
    def __exit__(self, exc_type, exc_val, exc_tb):
345
        if not self.enabled:
346
            return
347
        if self.use_device and hasattr(torch, self.use_device):
348
            device_module = getattr(torch, self.use_device)
349
            if hasattr(device_module, "synchronize"):
350
                device_module.synchronize()
351

352
        if self._function_events and self.acc_events:
353
            self._old_function_events = self._function_events
354
        self._function_events = None
355
        self._needs_processing = True
356

357
        t0 = perf_counter_ns()
358

359
        self.kineto_results = _disable_profiler()
360
        t1 = perf_counter_ns()
361
        self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000)
362
        self.profiling_end_time_ns = t0
363

364
        _run_on_profiler_stop()
365

366
        self._stats.profiling_window_duration_sec = (
367
            (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9
368
        )
369

370
        # If we plan to accumulate events we should post-process the function events
        # right away to retain the state across multiple start/stop calls
372
        if self.acc_events:
373
            self._ensure_function_events()
374
        return False
375

376
    def __repr__(self):
377
        if self._needs_processing:
378
            self._ensure_function_events()
379
        if self._function_events is None:
380
            return "<unfinished torch.autograd.profile>"
381
        return repr(self._function_events)
382

383
    def __str__(self):
384
        if self._needs_processing:
385
            self._ensure_function_events()
386
        if self._function_events is None:
387
            return "<unfinished torch.autograd.profile>"
388
        return str(self._function_events)
389

390
    def _ensure_function_events(self):
391
        """Process function events lazily if required"""
392
        if self._function_events is not None:
393
            return
394
        self._needs_processing = False
395

396
        t0 = perf_counter_ns()
397
        parsed_results = []
398
        if self.kineto_results:
399
            parsed_results = self._parse_kineto_results(self.kineto_results)
400
        t1 = perf_counter_ns()
401
        self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000)
402

403
        self._function_events = EventList(
404
            parsed_results,
405
            use_device=self.use_device,
406
            profile_memory=self.profile_memory,
407
            with_flops=self.with_flops,
408
        )
409
        t0 = perf_counter_ns()
410
        self._function_events._build_tree()
411
        t1 = perf_counter_ns()
412
        self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000)
413
        self._stats.number_of_events = len(self._function_events)
414

415
        if self._old_function_events and self.acc_events:
416
            for evt in self._old_function_events:
417
                self._function_events.append(evt)
418
            self._old_function_events = None
419

420
        if self._function_events is None:
421
            raise RuntimeError("Profiler didn't finish running")
422

423
    @property
424
    def function_events(self):
425
        if self._function_events is None or self._needs_processing:
426
            self._ensure_function_events()
427
        return self._function_events
428

429
    def table(
430
        self,
431
        sort_by=None,
432
        row_limit=100,
433
        max_src_column_width=75,
434
        max_name_column_width=55,
435
        max_shapes_column_width=80,
436
        header=None,
437
        top_level_events_only=False,
438
    ):
439
        self._ensure_function_events()
440
        assert self._function_events is not None
441
        return self._function_events.table(
442
            sort_by=sort_by,
443
            row_limit=row_limit,
444
            max_src_column_width=max_src_column_width,
445
            max_name_column_width=max_name_column_width,
446
            max_shapes_column_width=max_shapes_column_width,
447
            header=header,
448
            top_level_events_only=top_level_events_only,
449
        )
450

451
    table.__doc__ = EventList.table.__doc__
452

453
    def export_chrome_trace(self, path):
454
        """
455
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        the last cycle in the schedule is exported.
457
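
        A minimal usage sketch (illustrative only; ``"trace.json"`` is an arbitrary
        output path, and the resulting file can be opened in a trace viewer such as
        chrome://tracing or Perfetto)::

            >>> # xdoctest: +SKIP
            >>> with torch.autograd.profiler.profile() as prof:
            ...     y = torch.randn(3, 3) @ torch.randn(3, 3)
            >>> prof.export_chrome_trace("trace.json")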
        """
458
        if kineto_available():
459
            self.kineto_results.save(path)  # type: ignore[union-attr]
460
        else:
461
            self._ensure_function_events()
462
            return self._function_events.export_chrome_trace(path)  # type: ignore[union-attr]
463

464
    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
465

466
    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
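        """Save recorded stack traces to a file in a format usable by flame graph tools.

        A minimal sketch (illustrative only; the output path and any downstream
        flame graph tooling are assumptions, not prescribed by this module)::

            >>> # xdoctest: +SKIP
            >>> with torch.autograd.profiler.profile(with_stack=True) as prof:
            ...     y = torch.randn(3, 3) @ torch.randn(3, 3)
            >>> prof.export_stacks("profiler_stacks.txt", "self_cpu_time_total")
        """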
467
        self._ensure_function_events()
468
        assert self._function_events is not None, "Expected profiling results"
469
        assert self.with_stack, "export_stacks() requires with_stack=True"
470
        return self._function_events.export_stacks(path, metric)
471

472
    def toggle_collection_dynamic(
473
        self, enabled: bool, activities: Iterable[ProfilerActivity]
474
    ):
475
        """
476
        Toggles the collection of activities for the current profiler instance.
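
        A minimal sketch (illustrative only; assumes Kineto CUDA profiling and a
        ``model`` and input ``x`` defined elsewhere) that pauses and resumes CUDA
        activity collection mid-run::

            >>> # xdoctest: +SKIP
            >>> with torch.autograd.profiler.profile(use_device="cuda", use_kineto=True) as prof:
            ...     model(x)
            ...     prof.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
            ...     model(x)  # activity in this region is not collected
            ...     prof.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
            ...     model(x)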
477
        """
478
        return _toggle_collection_dynamic(enabled, set(activities))
479

480
    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
481
        self._ensure_function_events()
482
        assert self._function_events is not None, "Expected profiling results"
483
        return self._function_events.key_averages(
484
            group_by_input_shape, group_by_stack_n
485
        )
486

487
    key_averages.__doc__ = EventList.key_averages.__doc__
488

489
    def total_average(self):
490
        self._ensure_function_events()
491
        assert self._function_events is not None, "Expected profiling results"
492
        return self._function_events.total_average()
493

494
    total_average.__doc__ = EventList.total_average.__doc__
495

496
    @property
497
    def self_cpu_time_total(self):
498
        """Returns total time spent on CPU.
499

500
        The total time is a sum of all self times across all the events.
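
        A minimal illustration (assuming ``prof`` is a completed profile)::

            >>> # xdoctest: +SKIP
            >>> total = prof.self_cpu_time_total  # aggregate self CPU time of all recorded events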
501
        """
502
        self._ensure_function_events()
503
        assert self._function_events is not None
504
        return self._function_events.self_cpu_time_total
505

506
    def _parse_kineto_results(self, result: _ProfilerResult):
507
        # result.events() has most of the events - PyTorch op-level and device-level events
508

509
        trace_start_ns = result.trace_start_ns()
510
        mem_records = [
511
            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
512
        ]
513
        oom_records = [
514
            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
515
        ]
516
        mem_records_acc = MemRecordsAcc(mem_records)
517

518
        def _cpu_memory_usage(mem_record):
519
            return (
520
                mem_record.nbytes()
521
                if mem_record.device_type()
522
                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
523
                else 0
524
            )
525

526
        def _device_memory_usage(mem_record):
527
            return (
528
                mem_record.nbytes()
529
                if mem_record.device_type()
530
                in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
531
                else 0
532
            )
533

534
        # Create and return FunctionEvent list, which contains all function events
535
        # Here 2 function event lists are created:
536
        # all_function_events contains all events associated with each kineto event from result
537
        all_function_events = []
538
        # frontend_function_events contains the events at the aten or torch frontend level,
539
        # whose correlation id is 0
540
        frontend_function_events = []
541
        device_corr_map: Dict[int, List[FunctionEvent]] = {}
542
        max_evt_id = 0
543
        for kineto_event in result.events():
544
            if _filter_name(kineto_event.name()):
545
                continue
546
            rel_start_ns = kineto_event.start_ns() - trace_start_ns
547
            rel_end_ns = kineto_event.end_ns() - trace_start_ns
548
            abs_end_ns = kineto_event.end_ns()
549

550
            cpu_memory_usage = 0
551
            device_memory_usage = 0
552
            if kineto_event.device_type() == DeviceType.CPU:
553
                # find the corresponding memory allocation events
554
                for mem_record in mem_records_acc.in_interval(
555
                    kineto_event.start_ns() / 1000, abs_end_ns / 1000
556
                ):
557
                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
558
                    device_memory_usage += _device_memory_usage(mem_record[0])
559
                    mem_record[1] = True
560

561
            is_async = kineto_event.is_async() or (
562
                kineto_event.start_thread_id() != kineto_event.end_thread_id()
563
            )
564

565
            fe = FunctionEvent(
566
                id=kineto_event.correlation_id(),
567
                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
568
                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
569
                thread=kineto_event.start_thread_id(),
570
                start_us=rel_start_ns / 1000,
571
                end_us=rel_end_ns / 1000,
572
                fwd_thread=kineto_event.fwd_thread_id(),
573
                input_shapes=kineto_event.shapes(),
574
                concrete_inputs=kineto_event.concrete_inputs(),
575
                kwinputs=kineto_event.kwinputs(),
576
                stack=[
577
                    entry
578
                    for entry in kineto_event.stack()
579
                    if _filter_stack_entry(entry)
580
                ],
581
                scope=kineto_event.scope(),
582
                use_device=self.use_device,
583
                cpu_memory_usage=cpu_memory_usage,
584
                device_memory_usage=device_memory_usage,
585
                is_async=is_async,
586
                sequence_nr=kineto_event.sequence_nr(),
587
                device_type=kineto_event.device_type(),
588
                device_index=kineto_event.device_index(),
589
                device_resource_id=kineto_event.device_resource_id(),
590
                flops=kineto_event.flops(),
591
                is_user_annotation=kineto_event.is_user_annotation(),
592
            )
593
            max_evt_id = max(max_evt_id, fe.id)
594
            if fe.device_type == DeviceType.CPU and not fe.is_async:
595
                if self.use_device == "privateuseone":
596
                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
597
                    if privateuse1_time > 0:
598
                        fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
599
                        fe.is_legacy = True
600
                elif self.use_device == "cuda":
601
                    # Check if we have CUDA time as a fallback
602
                    cuda_time = kineto_event.cuda_elapsed_us()
603
                    if cuda_time > 0:
604
                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
605
                        fe.is_legacy = True
606
            all_function_events.append(fe)
607
            corr_id = kineto_event.linked_correlation_id()
608
            if corr_id > 0:
609
                if corr_id not in device_corr_map:
610
                    device_corr_map[corr_id] = []
611
                device_corr_map[corr_id].append(fe)
612
            elif corr_id == 0:
613
                frontend_function_events.append(fe)
614
            else:
615
                raise RuntimeError(
616
                    f"Got negative correlation id {corr_id} in profiler post processing"
617
                )
618

619
        # associate device kernels and device runtime (CPU) with CPU events
620
        for fe in frontend_function_events:
621
            if (
622
                fe.device_type == DeviceType.CPU
623
                and not fe.is_async
624
                and fe.id in device_corr_map
625
            ):
626
                for f_evt in device_corr_map[fe.id]:
627
                    if (
628
                        f_evt.device_type == DeviceType.CUDA
629
                        or f_evt.device_type == DeviceType.PrivateUse1
630
                    ):
631
                        fe.append_kernel(
632
                            f_evt.name,
633
                            f_evt.device_index,
634
                            f_evt.time_range.end - f_evt.time_range.start,
635
                        )
636
                    elif f_evt.device_type == DeviceType.CPU:
637
                        # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated
638
                        # with the 'thread' of the corresponding linked PyTorch event to properly track
639
                        # parents and children
640
                        f_evt.thread = fe.thread
641

642
        def createFunctionEventForMemoryEvents(evt):
643
            rel_start_ns = evt.start_ns() - trace_start_ns
644
            fe = FunctionEvent(
645
                id=max_evt_id,
646
                name=evt.name(),
647
                trace_name=None,  # not outputting in the trace
648
                thread=evt.start_thread_id(),
649
                start_us=rel_start_ns / 1000,
650
                end_us=rel_start_ns / 1000,  # no duration
651
                fwd_thread=evt.start_thread_id(),
652
                input_shapes=[],
653
                stack=[],
654
                scope=0,  # RecordScope::FUNCTION
655
                use_device=self.use_device,
656
                cpu_memory_usage=_cpu_memory_usage(evt),
657
                device_memory_usage=_device_memory_usage(evt),
658
                is_async=False,
659
                sequence_nr=-1,
660
                device_type=DeviceType.CPU,
661
                device_index=0,
662
            )
663
            return fe
664

665
        # output top-level memory events
666
        for mem_record in mem_records:
667
            if not mem_record[1]:
668
                max_evt_id += 1
669
                fe = createFunctionEventForMemoryEvents(mem_record[0])
670
                all_function_events.append(fe)
671

672
        for oom_record in oom_records:
673
            max_evt_id += 1
674
            fe = createFunctionEventForMemoryEvents(oom_record)
675
            all_function_events.append(fe)
676

677
        all_function_events.sort(
678
            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
679
        )
680
        return all_function_events
681

682

683
class record_function(_ContextDecorator):
684
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
685
    Label will only appear if CPU activity tracing is enabled.
686

687
    It is useful for labeling a region of interest when inspecting the trace.
688

689
    Args:
690
        name (str): Label assigned to the block of code.
691
        node_id (int): ID of node, for distributed profiling. Unset in
            non-distributed cases.
693

694
    Example:
695
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
696
        >>> x = torch.randn((1, 1), requires_grad=True)
697
        >>> with torch.autograd.profiler.profile() as prof:
698
        ...     y = x ** 2
699
        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
700
        ...         z = y ** 3
701
        ...     y.backward()
702
        ...
703
        >>> # xdoctest: +IGNORE_WANT
704
        >>> # NOTE: some columns were removed for brevity
705
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
706
        -----------------------------------  ---------------  ---------------  ---------------
707
        Name                                 Self CPU total %  CPU time avg     Number of Calls
708
        -----------------------------------  ---------------  ---------------  ---------------
709
        pow                                  60.77%           47.470us         3
710
        mul                                  21.73%           25.465us         2
711
        PowBackward0                         12.03%           121.891us        1
712
        torch::autograd::AccumulateGrad      2.70%            6.324us          1
713
        label-z                              2.13%            12.421us         1
714
        torch::autograd::GraphRoot           0.64%            1.503us          1
715
        -----------------------------------  ---------------  ---------------  ---------------
716
        Self CPU time total: 234.344us
717
        CUDA time total: 0.000us
718

719
    """
720

721
    def __init__(self, name: str, args: Optional[str] = None):
722
        self.name: str = name
723
        self.args: Optional[str] = args
724
        # Whether or not we should run record function's end callbacks when exiting.
725
        self.run_callbacks_on_exit: bool = True
726
        # TODO: TorchScript ignores standard type annotation here
727
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
728
        self.record = torch.jit.annotate(
729
            Optional["torch.classes.profiler._RecordFunction"], None
730
        )
731

732
    def __enter__(self):
733
        self.record = torch.ops.profiler._record_function_enter_new(
734
            self.name, self.args
735
        )
736
        return self
737

738
    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
739
        if not self.run_callbacks_on_exit:
740
            return
741

742
        # Local variable is needed by TorchScript to refine Optional[T] to T
743
        record = self.record
744
        assert record is not None
745

746
        # TODO: Too slow with __torch_function__ handling enabled
747
        # See https://github.com/pytorch/pytorch/issues/76410
748
        if not torch.jit.is_scripting():
749
            with torch._C.DisableTorchFunctionSubclass():
750
                torch.ops.profiler._record_function_exit._RecordFunction(record)
751
        else:
752
            torch.ops.profiler._record_function_exit(record)
753

754
    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
755
        """Use for profiling async calls that return a future.
756

757
        Calling this function will extend recording beyond this scope, until the future is
758
        satisfied. It is useful for profiling the end to end time of asynchronous calls.
759
        This function should only be called once to attach the callback onto the future, and
760
        will throw if called multiple times.
761

762
        Args:
763
            fut (torch._C.Future): future for which to schedule the
            callback.
765

766
        Returns:
767
            A future that completes with the value of the passed-in future when
            the profiling callbacks have run.
769
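
        A minimal sketch (illustrative only; assumes an initialized RPC framework
        with a reachable worker ``"worker1"`` and tensors ``x`` and ``y`` defined
        elsewhere)::

            >>> # xdoctest: +SKIP
            >>> with torch.autograd.profiler.profile() as prof:
            ...     with torch.autograd.profiler.record_function("rpc_add") as rf:
            ...         fut = torch.distributed.rpc.rpc_async("worker1", torch.add, args=(x, y))
            ...         fut = rf._call_end_callbacks_on_future(fut)
            ...     result = fut.wait()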

770
        """
771
        # Throw if we have already attached a callback onto the future.
772
        if not self.run_callbacks_on_exit:
773
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")
774

775
        # We are scheduling to run this RecordFunction's end callbacks when the
776
        # passed in future completes, so don't run end callbacks on exit.
777
        self.run_callbacks_on_exit = False
778

779
        # Local variable is needed by TorchScript to refine Optional[T] to T
780
        record = self.record
781
        assert record is not None
782

783
        # TODO: Too slow with __torch_function__ handling enabled
784
        # See https://github.com/pytorch/pytorch/issues/76410
785
        if not torch.jit.is_scripting():
786
            with torch._C.DisableTorchFunctionSubclass():
787
                profiled_future = (
788
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
789
                        record, fut
790
                    )
791
                )
792
        else:
793
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
794
                record, fut
795
            )
796
        return profiled_future
797

798

799
class emit_itt:
800
    """Context manager that makes every autograd operation emit an ITT range.
801

802
    It is useful when running the program under Intel(R) VTune Profiler::
803

804
        vtune <--vtune-flags> <regular command here>
805

806
    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
807
    control the collection of trace data during its execution across different Intel tools.
808
    This context manager is used to annotate Intel(R) VTune Profiler traces. With the help of this context manager,
    you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.
810

811
    .. warning::
812
        This context manager should not be called recursively, i.e. at most one
813
        instance should be enabled at any given time.
814

815
    Args:
816
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
817
            Default: ``True``.
818
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
819
            each autograd op will append information about the sizes of Tensor arguments received
820
            by that op, in the following format:
821
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
822
            Non-tensor arguments will be represented by ``[]``.
823
            Arguments will be listed in the order they are received by the backend op.
824
            Please note that this order may not match the order in which those arguments were passed
825
            on the Python side.  Also note that shape recording may increase the overhead of itt range creation.
826
            Default: ``False``
827

828
    Example:
829
        >>> # xdoctest: +SKIP("Undefined variables")
830
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
831
        >>> with torch.autograd.profiler.emit_itt():
832
        ...     model(x)
833

834
    """
835

836
    def __init__(self, enabled=True, record_shapes=False):
837
        self.enabled = enabled
838
        self.entered = False
839
        self.record_shapes = record_shapes
840

841
    def __enter__(self):
842
        if not self.enabled:
843
            return
844
        if self.entered:
845
            raise RuntimeError("ITT annotation context manager is not reentrant")
846
        self.entered = True
847
        _run_on_profiler_start()
848
        _enable_profiler(
849
            ProfilerConfig(
850
                ProfilerState.ITT,
851
                self.record_shapes,
852
                False,
853
                False,
854
                False,
855
                False,
856
                _ExperimentalConfig(),
857
            ),
858
            set(),
859
        )
860
        return self
861

862
    def __exit__(self, exc_type, exc_val, exc_tb):
863
        if not self.enabled:
864
            return
865
        _disable_profiler()
866
        _run_on_profiler_stop()
867
        return False
868

869

870
class emit_nvtx:
871
    """Context manager that makes every autograd operation emit an NVTX range.
872

873
    It is useful when running the program under nvprof::
874

875
        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
876

877
    Unfortunately, there's no way to force nvprof to flush the data it collected
878
    to disk, so for CUDA profiling one has to use this context manager to annotate
879
    nvprof traces and wait for the process to exit before inspecting them.
880
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
881
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
882
    e.g. in Python REPL.
883

884
    .. warning::
885
        This context manager should not be called recursively, i.e. at most one
886
        instance should be enabled at any given time.
887

888
    Args:
889
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
890
            Default: ``True``.
891
        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
892
            each autograd op will append information about the sizes of Tensor arguments received
893
            by that op, in the following format:
894
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
895
            Non-tensor arguments will be represented by ``[]``.
896
            Arguments will be listed in the order they are received by the backend op.
897
            Please note that this order may not match the order in which those arguments were passed
898
            on the Python side.  Also note that shape recording may increase the overhead of nvtx range creation.
899
            Default: ``False``
900

901
    Example:
902
        >>> # xdoctest: +SKIP("undefined variables")
903
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
904
        >>> with torch.cuda.profiler.profile():
905
        ...     model(x)  # Warmup CUDA memory allocator and profiler
906
        ...     with torch.autograd.profiler.emit_nvtx():
907
        ...         model(x)
908

909
    **Forward-backward correlation**
910

911
    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
912
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
913
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
914
    generates.
915

916
    During the forward pass, each function range is decorated with ``seq=<N>``.  ``seq`` is a running
917
    counter, incremented each time a new backward Function object is created and stashed for backward.
918
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
919
    if a backward Function object is created by this forward function,
920
    the backward object will receive sequence number N.
921
    During the backward pass, the top-level range wrapping each C++ backward Function's
922
    ``apply()`` call is decorated with ``stashed seq=<M>``.  ``M`` is the sequence number that
923
    the backward object was created with.  By comparing ``stashed seq`` numbers in backward with ``seq``
924
    numbers in forward, you can track down which forward op created each backward Function.
925

926
    Any functions executed during the backward pass are also decorated with ``seq=<N>``.  During
927
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
928
    ``N`` may simply be 0 for all such functions.  Only the top-level ranges associated with
929
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
930
    objects with the earlier forward pass.
931

932
    **Double-backward**
933

934
    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
935
    if you are setting up for a double-backward), each function's execution during backward
936
    is given a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
937
    to be executed later during double-backward, just as the original functions in the forward pass did.
938
    The relationship between backward and double-backward is conceptually the same as the relationship
939
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
940
    the Function objects they create still stash those sequence numbers, and during the eventual
941
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
942
    numbers, which can be compared to `seq` numbers from the backward pass.
943

944
    .. warning::
945
        The sequence number is thread-local, and some forward functions don't create an associated
946
        backward Function object (instead delegating that to sub-functions further down the call chain).
947
        For these reasons, the correspondence of stashed sequence numbers in
948
        backward Function ``apply()`` ranges with `seq` numbers in forward-pass ranges is
949
        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
950
        disambiguate which forward function created which
951
        backward Function object.  You may need to make a judgment based on analytic knowledge of what
952
        the expected correspondence should be.
953
    """
954

955
    def __init__(self, enabled=True, record_shapes=False):
956
        self.enabled = enabled
957
        self.entered = False
958
        self.record_shapes = record_shapes
959

960
    def __enter__(self):
961
        if not self.enabled:
962
            return
963
        if self.entered:
964
            raise RuntimeError("NVTX annotation context manager is not reentrant")
965
        self.entered = True
966
        torch.cuda.synchronize()
967
        _run_on_profiler_start()
968
        _enable_profiler(
969
            ProfilerConfig(
970
                ProfilerState.NVTX,
971
                self.record_shapes,
972
                False,
973
                False,
974
                False,
975
                False,
976
                _ExperimentalConfig(),
977
            ),
978
            set(),
979
        )
980
        return self
981

982
    def __exit__(self, exc_type, exc_val, exc_tb):
983
        if not self.enabled:
984
            return
985
        torch.cuda.synchronize()
986
        _disable_profiler()
987
        _run_on_profiler_stop()
988
        return False
989

990

991
def load_nvprof(path):
    """Open an nvprof trace file and parse autograd annotations.
993

994
    Args:
995
        path (str): path to nvprof trace
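
    A minimal sketch (illustrative only; ``trace_name.prof`` is whatever file
    nvprof was asked to write)::

        >>> # xdoctest: +SKIP
        >>> events = torch.autograd.profiler.load_nvprof("trace_name.prof")
        >>> print(events.table(sort_by="self_cpu_time_total"))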
996
    """
997
    return EventList(parse_nvprof_trace(path))
998

999

1000
class EnforceUnique:
1001
    """Raises an error if a key is seen more than once."""
1002

1003
    def __init__(self):
1004
        self.seen = set()
1005

1006
    def see(self, *key):
1007
        r"""
1008
        Observe a key and raise an error if it is seen multiple times.
1009
        """
1010
        if key in self.seen:
1011
            raise RuntimeError("duplicate key: " + str(key))
1012
        self.seen.add(key)
1013

1014

1015
def parse_nvprof_trace(path):
1016
    import sqlite3
1017

1018
    conn = sqlite3.connect(path)
1019
    conn.row_factory = sqlite3.Row
1020

1021
    # Parse strings table
1022
    strings = {}
1023
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
1024
        strings[r["id"]] = torch._C._demangle(r["value"])
1025

1026
    # First, find all functions and create FunctionEvents for them
1027
    marker_query = """
1028
    SELECT
1029
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
1030
    FROM
1031
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
1032
        ON start.id = end.id
1033
    WHERE
1034
        start.name != 0 AND end.name = 0
1035
    """
1036
    functions = []
1037
    functions_map = {}
1038
    unique = EnforceUnique()
1039
    for row in conn.execute(marker_query):
1040
        unique.see(row["marker_id"])
1041
        evt = FunctionEvent(
1042
            id=row["marker_id"],
1043
            node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
1044
            # that pytorch doesn't crash when creating a FunctionEvent() object
1045
            name=strings[row["name"]],
1046
            start_us=row["start_time"],
1047
            end_us=row["end_time"],
1048
            thread=0,
1049
        )  # TODO: find in sqlite database
1050
        functions.append(evt)
1051
        functions_map[evt.id] = evt
1052

1053
    # Now, correlate all kernels with FunctionEvents
1054
    kernel_query = """
1055
    SELECT
1056
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
1057
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
1058
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
1059
    FROM
1060
        CUPTI_ACTIVITY_KIND_MARKER AS start
1061
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
1062
            ON start.id = end.id
1063
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
1064
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
1065
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
1066
            ON kernel.correlationId = runtime.correlationId
1067
    """
1068
    unique = EnforceUnique()
1069
    for row in conn.execute(kernel_query):
1070
        unique.see(row["marker_id"], row["runtime_id"])
1071
        # 211 is cudaKernelLaunch for cuda >= 9.2
1072
        assert row["cbid"] == 211
1073
        evt = functions_map[row["marker_id"]]
1074
        evt.append_kernel(
1075
            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
1076
        )
1077

1078
    functions.sort(key=lambda evt: evt.time_range.start)
1079
    return functions
1080

1081

1082
class KinetoStepTracker:
1083
    """Provides an abstraction for incrementing the step count globally.
1084

1085
    Previously, we only had one place to mark that a step() has occurred
    in the program, via the pytorch profiler step(). We will now add step hooks
    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446):
1088

1089
    - This could mean programs that already call profiler.step() every
      iteration can end up double incrementing the step count.
1091
    - If a model uses multiple optimizers we can also have double or more
1092
      counting of the step.
1093

1094
    We fix this by adding a layer of abstraction before calling step()
1095
    to the kineto library. The idea is to maintain steps per requester in a dict:
1096

1097
    .. code-block::
1098

1099
        {
1100
           "ProfilerStep": 100,  # triggered by profiler step() call
1101
           "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
1102
           "Optimizer2Step": 100,
1103
        }
1104

1105
    To figure out the global step count just take the max of dict values (100).
1106

1107
    If one of the counts increments, the max will go up.
1108

1109
    .. code-block::
1110

1111
        {
1112
           "ProfilerStep": 100,
1113
           "Optimizer1Step": 101,   # Optimizer1 got incremented first say
1114
           "Optimizer2Step": 100,
1115
        }
1116

1117
    Then the global step count is 101.
    We only call the kineto step() function when the global count increments.
1119

1120
    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
    for now. The result could be incorrect increments of the step count.
1122
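
    A minimal usage sketch (illustrative only; ``"MyRequester"`` is a made-up
    requester name)::

        >>> # xdoctest: +SKIP
        >>> from torch.autograd.profiler import KinetoStepTracker
        >>> KinetoStepTracker.init_step_count("MyRequester")
        >>> KinetoStepTracker.increment_step("MyRequester")  # advances the global step if it becomes the max
        >>> KinetoStepTracker.current_step()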
    """
1123

1124
    _current_step = 0
1125
    _step_dict: Dict[str, int] = defaultdict(int)
1126

1127
    @classmethod
1128
    def init_step_count(cls, requester: str):
1129
        r"""
1130
        Initialize for a given requester.
1131
        """
1132
        cls._step_dict[requester] = cls._current_step
1133

1134
    @classmethod
1135
    def erase_step_count(cls, requester: str) -> bool:
1136
        r"""
1137
        Remove a given requester.
1138
        """
1139
        return cls._step_dict.pop(requester, None) is not None
1140

1141
    @classmethod
1142
    def increment_step(cls, requester: str) -> int:
1143
        """Increments the step count for the requester.
1144

1145
        Additionally, if the max over all step counts has incremented, then
        trigger _kineto_step(). Returns the global step count.
1147
        """
1148
        if requester not in cls._step_dict:
1149
            cls.init_step_count(requester)
1150
        cls._step_dict[requester] += 1
1151

1152
        new_step = max(cls._step_dict.values())
1153
        if new_step > cls._current_step:
1154
            delta = new_step - cls._current_step
1155
            if delta > 1:
1156
                warn(
                    "Profiler step count has increased by more than 1 - "
                    f"current_step = {cls._current_step} step dict = {cls._step_dict}"
1159
                )
1160
            for _ in range(0, delta):
1161
                _kineto_step()
1162
            cls._current_step = new_step
1163
        return cls._current_step
1164

1165
    @classmethod
1166
    def current_step(cls) -> int:
1167
        r"""
1168
        Get the latest step for any requester
1169
        """
1170
        return cls._current_step