from collections import defaultdict
from typing import Any, Dict, List, Optional
from warnings import warn

import torch

import torch.cuda
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig

from torch.autograd import (
    _disable_profiler,
    _enable_profiler,
    _kineto_step,
    _prepare_profiler,
    _ProfilerResult,
    _supported_activities,
    DeviceType,
    kineto_available,
    ProfilerActivity,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
    MemRecordsAcc,
    OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future

__all__ = [
    "profile",
    "record_function",
    "emit_itt",
    "emit_nvtx",
    "load_nvprof",
    "EnforceUnique",
    "parse_nvprof_trace",
    "KinetoStepTracker",
    "EventList",
    "FunctionEvent",
    "MemRecordsAcc",
]

try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped


# global python state - whether profiler is currently enabled
# useful for fast python checks to reduce latency
_is_profiler_enabled: bool = False


def _set_is_profiler_enabled(enable: bool):
    global _is_profiler_enabled
    _is_profiler_enabled = enable


def _run_on_profiler_start():
    _set_is_profiler_enabled(True)


def _run_on_profiler_stop():
    _set_is_profiler_enabled(False)
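
# NOTE (editorial, illustrative only): the flag above allows a cheap check on hot
# paths before paying for a ``record_function`` range. A hypothetical usage sketch
# (the "forward_step" label and ``x`` are assumptions, not part of this module):
#
#     >>> import torch.autograd.profiler as profiler
#     >>> if profiler._is_profiler_enabled:
#     ...     with profiler.record_function("forward_step"):
#     ...         y = x * 2
#
# ``_is_profiler_enabled`` is module-internal state, toggled by the two hooks above.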


class profile:
    """Context manager that manages autograd profiler state and holds a summary of results.

    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: the profiler is thread local and is automatically propagated into async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.

        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.

        use_xpu (bool, optional): Enables timing of XPU events.
            Only supports Kineto profiling while XPU backend is available.

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottommost events (in the case
            of nested function calls). But for higher level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        with_flops (bool, optional): If with_flops is set, the profiler will estimate
            the FLOPs (floating point operations) value using the operator's input shape.
            This allows one to estimate the hardware performance. Currently,
            this option only works for the matrix multiplication and 2D convolution operators.

        profile_memory (bool, optional): track tensor memory allocation/deallocation.

        with_stack (bool, optional): record source information (file and line number) for the ops.

        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. if module A's forward calls
            module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.

        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.

        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.

        experimental_config (_ExperimentalConfig): A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.


    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead.

    .. warning::
        This context manager should not be called recursively, i.e. no nested
        instances are allowed.

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        >>>     for _ in range(100):  # any normal python code, really!
        >>>         y = x ** 2
        >>>         y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total   CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        mul                                  32.048ms         32.048ms         200
        pow                                  27.041ms         27.041ms         200
        PowBackward0                         9.727ms          55.483ms         100
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
        torch::autograd::GraphRoot           691.816us        691.816us        100
        -----------------------------------  ---------------  ---------------  ---------------

    """

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        use_xpu=False,
        use_device=None,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
        use_kineto=False,
        use_cpu=True,
        use_mtia=False,
        experimental_config=None,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        self.use_xpu = use_xpu
        self.use_device: Optional[str] = (
            use_device if use_device != "privateuseone" else None
        )
        self.function_events: Optional[EventList] = None
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.use_cpu = use_cpu
        self.use_mtia = use_mtia
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.kineto_results: Optional[_ProfilerResult] = None

        if not self.use_cpu:
            assert (
                use_kineto
            ), "Device-only events supported only with Kineto (use_kineto=True)"

        if self.use_device == "cuda":
            self.use_device = None
            self.use_cuda = True
            self.use_xpu = False
        elif self.use_device == "xpu":
            self.use_device = None
            self.use_cuda = False
            self.use_xpu = True

        if self.use_device and self.use_device != _get_privateuse1_backend_name():
            warn(f"{self.use_device} doesn't support profiling.")
            self.use_device = None

        if self.use_cuda and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False

        if self.use_xpu and (
            not hasattr(torch, "xpu") or not torch.xpu.is_available()
        ):  # type: ignore[attr-defined]
            warn("XPU is not available, disabling XPU profiling")
            self.use_xpu = False

        self.kineto_activities = set()
        if self.use_cpu:
            self.kineto_activities.add(ProfilerActivity.CPU)
        if self.use_mtia:
            self.kineto_activities.add(ProfilerActivity.MTIA)

        self.profiler_kind = ProfilerState.KINETO
        if self.use_cuda:
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.CUDA)
        if self.use_xpu:
            assert (
                use_kineto or ProfilerActivity.XPU in _supported_activities()
            ), "Profiling on XPU backend must enable Kineto"
            self.kineto_activities.add(ProfilerActivity.XPU)
        if self.use_device:
            if (
                not use_kineto
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
            ):
                assert (
                    self.use_cpu
                ), "Legacy custom backend profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1

        assert (
            len(self.kineto_activities) > 0
        ), "No activities specified for the profiler"

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config,
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self._prepare_trace()
        self._start_trace()
        return self

    def _prepare_trace(self):
        self.entered = True
        _prepare_profiler(self.config(), self.kineto_activities)

    def _start_trace(self):
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(self.config(), self.kineto_activities)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()
        if self.use_xpu:
            torch.xpu.synchronize()  # type: ignore[attr-defined]
        self.kineto_results = _disable_profiler()
        _run_on_profiler_stop()
        parsed_results = self._parse_kineto_results(self.kineto_results)
        self.function_events = EventList(
            parsed_results,
            use_cuda=self.use_cuda,
            use_xpu=self.use_xpu,
            use_device=self.use_device,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False

    def __repr__(self):
        if self.function_events is None:
            return "<unfinished torch.autograd.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished torch.autograd.profile>"
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__

    def export_chrome_trace(self, path):
        self._check_finish()
        if kineto_available():
            self.kineto_results.save(path)  # type: ignore[union-attr]
        else:
            return self.function_events.export_chrome_trace(path)  # type: ignore[union-attr]

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)

    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Returns total time spent on CPU.

        The total time is a sum of all self times across all the events.
        """
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total

    def _parse_kineto_results(self, result: _ProfilerResult):
        # result.events() has most of the events - PyTorch op-level and device-level events

        trace_start_us = result.trace_start_us()
        mem_records = [
            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
        ]
        oom_records = [
            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
        ]
        mem_records_acc = MemRecordsAcc(mem_records)

        def _cpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
                else 0
            )

        def _cuda_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP]
                else 0
            )

        def _xpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.XPU]
                else 0
            )

        def _privateuse1_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.PrivateUse1]
                else 0
            )

        # Create and return FunctionEvent list
        function_events = []
        device_corr_map: Dict[int, List[FunctionEvent]] = {}
        device_corr_map_values: List[FunctionEvent] = []
        max_evt_id = 0
        for kineto_event in result.events():
            if _filter_name(kineto_event.name()):
                continue
            rel_start_us = kineto_event.start_us() - trace_start_us
            rel_end_us = rel_start_us + kineto_event.duration_us()
            abs_end_us = kineto_event.start_us() + kineto_event.duration_us()

            cpu_memory_usage = 0
            cuda_memory_usage = 0
            xpu_memory_usage = 0
            privateuse1_memory_usage = 0
            if kineto_event.device_type() == DeviceType.CPU:
                # find the corresponding memory allocation events
                for mem_record in mem_records_acc.in_interval(
                    kineto_event.start_us(), abs_end_us
                ):
                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
                    cuda_memory_usage += _cuda_memory_usage(mem_record[0])
                    xpu_memory_usage += _xpu_memory_usage(mem_record[0])
                    privateuse1_memory_usage += _privateuse1_memory_usage(mem_record[0])
                    mem_record[1] = True

            is_async = kineto_event.is_async() or (
                kineto_event.start_thread_id() != kineto_event.end_thread_id()
            )

            fe = FunctionEvent(
                id=kineto_event.correlation_id(),
                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
                thread=kineto_event.start_thread_id(),
                start_us=rel_start_us,
                end_us=rel_end_us,
                fwd_thread=kineto_event.fwd_thread_id(),
                input_shapes=kineto_event.shapes(),
                concrete_inputs=kineto_event.concrete_inputs(),
                stack=[
                    entry
                    for entry in kineto_event.stack()
                    if _filter_stack_entry(entry)
                ],
                scope=kineto_event.scope(),
                use_device=self.use_device,
                cpu_memory_usage=cpu_memory_usage,
                cuda_memory_usage=cuda_memory_usage,
                xpu_memory_usage=xpu_memory_usage,
                privateuse1_memory_usage=privateuse1_memory_usage,
                is_async=is_async,
                sequence_nr=kineto_event.sequence_nr(),
                device_type=kineto_event.device_type(),
                device_index=kineto_event.device_index(),
                flops=kineto_event.flops(),
            )
            max_evt_id = max(max_evt_id, fe.id)
            if fe.device_type == DeviceType.CPU and not fe.is_async:
                if self.use_device:
                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
                    if privateuse1_time > 0:
                        fe.append_privateuse1_kernel(
                            fe.name, fe.device_index, privateuse1_time
                        )
                        fe.is_legacy = True
                else:
                    # Check if we have CUDA time as a fallback
                    cuda_time = kineto_event.cuda_elapsed_us()
                    if cuda_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
                        fe.is_legacy = True
            function_events.append(fe)
            corr_id = kineto_event.linked_correlation_id()
            if corr_id > 0:
                if corr_id not in device_corr_map:
                    device_corr_map[corr_id] = []
                device_corr_map[corr_id].append(fe)
        # prepare the values list for the corr map to avoid appending a kernel twice
        # when some CPU Kineto events (e.g. CUDA/XPU Runtime Activity) share a
        # correlation ID with CPU events
        device_corr_map_values = [
            fe for fe_list in device_corr_map.values() for fe in fe_list
        ]

        # associate CUDA/XPU kernels and CUDA/XPU runtime (CPU) with CPU events
        for fe in function_events:
            if (
                fe.device_type == DeviceType.CPU
                and not fe.is_async
                and fe.id in device_corr_map
                and fe not in device_corr_map_values
            ):
                for f_evt in device_corr_map[fe.id]:
                    if f_evt.device_type == DeviceType.CUDA:
                        fe.append_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.XPU:
                        fe.append_xpu_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.PrivateUse1:
                        fe.append_privateuse1_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.CPU:
                        # make sure that 'thread' of a CPU Kineto (e.g. CUDA/XPU Runtime) event is associated
                        # with the 'thread' of the corresponding linked PyTorch event to properly track
                        # parents and children
                        f_evt.thread = fe.thread

        def createFunctionEventForMemoryEvents(evt):
            rel_start_us = evt.start_us() - trace_start_us
            fe = FunctionEvent(
                id=max_evt_id,
                name=evt.name(),
                trace_name=None,  # not outputting in the trace
                thread=evt.start_thread_id(),
                start_us=rel_start_us,
                end_us=rel_start_us,  # no duration
                fwd_thread=evt.start_thread_id(),
                input_shapes=[],
                stack=[],
                scope=0,  # RecordScope::FUNCTION
                use_device=self.use_device,
                cpu_memory_usage=_cpu_memory_usage(evt),
                cuda_memory_usage=_cuda_memory_usage(evt),
                xpu_memory_usage=_xpu_memory_usage(evt),
                privateuse1_memory_usage=_privateuse1_memory_usage(evt),
                is_async=False,
                sequence_nr=-1,
                device_type=DeviceType.CPU,
                device_index=0,
            )
            return fe

        # output top-level memory events
        for mem_record in mem_records:
            if not mem_record[1]:
                max_evt_id += 1
                fe = createFunctionEventForMemoryEvents(mem_record[0])
                function_events.append(fe)

        for oom_record in oom_records:
            max_evt_id += 1
            fe = createFunctionEventForMemoryEvents(oom_record)
            function_events.append(fe)

        function_events.sort(
            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
        )
        return function_events
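
# NOTE (editorial, illustrative only): a typical end-to-end use of the class above,
# assuming a CUDA-capable build; ``model`` and ``x`` are hypothetical names.
#
#     >>> with torch.autograd.profiler.profile(
#     ...     use_cuda=True, use_kineto=True, record_shapes=True
#     ... ) as prof:
#     ...     y = model(x)
#     ...     y.sum().backward()
#     >>> print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))
#     >>> prof.export_chrome_trace("trace.json")  # viewable in chrome://tracing or Perfetto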


class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.

    It is useful when tracing the code profile.

    Args:
        name (str): Label assigned to the block of code.
        node_id (int): ID of node, for distributed profiling. Unset in
            non-distributed cases.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     with torch.autograd.profiler.record_function("label-z"): # label the block
        ...         z = y ** 3
        ...     y.backward()
        ...
        >>> # xdoctest: +IGNORE_WANT
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total %  CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        pow                                  60.77%           47.470us         3
        mul                                  21.73%           25.465us         2
        PowBackward0                         12.03%           121.891us        1
        torch::autograd::AccumulateGrad      2.70%            6.324us          1
        label-z                              2.13%            12.421us         1
        torch::autograd::GraphRoot           0.64%            1.503us          1
        -----------------------------------  ---------------  ---------------  ---------------
        Self CPU time total: 234.344us
        CUDA time total: 0.000us

    """

    def __init__(self, name: str, args: Optional[str] = None):
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )

    def __enter__(self):
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if not self.run_callbacks_on_exit:
            return

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            torch.ops.profiler._record_function_exit(record)

    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end-to-end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut (torch._C.Future): future for which to schedule the
                callback.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have run.

        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future
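
# NOTE (editorial, illustrative only): ``record_function`` also works as a decorator,
# since it derives from ``_ContextDecorator``, and its end callbacks can be tied to a
# future for async work. ``my_helper`` and ``some_async_call`` below are hypothetical.
#
#     >>> @record_function("my_helper")
#     ... def my_helper(x):
#     ...     return x + 1
#
#     >>> with record_function("async_label") as rf:
#     ...     fut = some_async_call()  # hypothetical; returns a torch.futures.Future
#     ...     fut = rf._call_end_callbacks_on_future(fut)
#     ... # the range now ends when ``fut`` completes, not at the end of the block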


class emit_itt:
    """Context manager that makes every autograd operation emit an ITT range.

    It is useful when running the program under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager annotates the Intel(R) VTune Profiling trace. With the help of this context manager,
    you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of itt range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)

    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ITT annotation context manager is not reentrant")
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.ITT,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        _disable_profiler()
        _run_on_profiler_stop()
        return False


class emit_nvtx:
    """Context manager that makes every autograd operation emit an NVTX range.

    It is useful when running the program under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately, there's no way to force nvprof to flush the data it collected
    to disk, so for CUDA profiling one has to use this context manager to annotate
    nvprof traces and wait for the process to exit before inspecting them.
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection
    e.g. in Python REPL.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side.  Also note that shape recording may increase the overhead of nvtx range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)

    **Forward-backward correlation**

    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
    generates.

    During the forward pass, each function range is decorated with ``seq=<N>``.  ``seq`` is a running
    counter, incremented each time a new backward Function object is created and stashed for backward.
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
    if a backward Function object is created by this forward function,
    the backward object will receive sequence number N.
    During the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``.  ``M`` is the sequence number that
    the backward object was created with.  By comparing ``stashed seq`` numbers in backward with ``seq``
    numbers in forward, you can track down which forward op created each backward Function.

    Any functions executed during the backward pass are also decorated with ``seq=<N>``.  During
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
    ``N`` may simply be 0 for all such functions.  Only the top-level ranges associated with
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
    objects with the earlier forward pass.

    **Double-backward**

    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
    if you are setting up for a double-backward), each function's execution during backward
    is given a nonzero, useful ``seq=<N>``.  Those functions may themselves create Function objects
    to be executed later during double-backward, just as the original functions in the forward pass did.
    The relationship between backward and double-backward is conceptually the same as the relationship
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
    the Function objects they create still stash those sequence numbers, and during the eventual
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
    numbers, which can be compared to ``seq`` numbers from the backward pass.

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1.  The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object.  You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.NVTX,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        torch.cuda.synchronize()
        _disable_profiler()
        _run_on_profiler_stop()
        return False
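
# NOTE (editorial, illustrative only): a typical nvprof round trip with the class
# above; the trace path and the script name are hypothetical.
#
#     $ nvprof --profile-from-start off -o trace_name.prof -- python train.py
#
#     >>> prof_events = torch.autograd.profiler.load_nvprof("trace_name.prof")
#     >>> print(prof_events.table(sort_by="self_cpu_time_total"))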


def load_nvprof(path):
    """Open an nvprof trace file and parse autograd annotations.

    Args:
        path (str): path to nvprof trace
    """
    return EventList(parse_nvprof_trace(path))


class EnforceUnique:
    """Raises an error if a key is seen more than once."""

    def __init__(self):
        self.seen = set()

    def see(self, *key):
        r"""
        Observe a key and raise an error if it is seen multiple times.
        """
        if key in self.seen:
            raise RuntimeError("duplicate key: " + str(key))
        self.seen.add(key)
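
# NOTE (editorial, illustrative only): ``EnforceUnique`` is a small guard used by
# ``parse_nvprof_trace`` below; seeing the same key twice raises immediately.
#
#     >>> unique = EnforceUnique()
#     >>> unique.see(1, "a")
#     >>> unique.see(1, "a")
#     Traceback (most recent call last):
#         ...
#     RuntimeError: duplicate key: (1, 'a')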


def parse_nvprof_trace(path):
    import sqlite3

    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row

    # Parse strings table
    strings = {}
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
        strings[r["id"]] = torch._C._demangle(r["value"])

    # First, find all functions and create FunctionEvents for them
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    functions = []
    functions_map = {}
    unique = EnforceUnique()
    for row in conn.execute(marker_query):
        unique.see(row["marker_id"])
        evt = FunctionEvent(
            id=row["marker_id"],
            node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
            # that pytorch doesn't crash when creating a FunctionEvent() object
            name=strings[row["name"]],
            start_us=row["start_time"],
            end_us=row["end_time"],
            thread=0,
        )  # TODO: find in sqlite database
        functions.append(evt)
        functions_map[evt.id] = evt

    # Now, correlate all kernels with FunctionEvents
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
            ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
            ON kernel.correlationId = runtime.correlationId
    """
    unique = EnforceUnique()
    for row in conn.execute(kernel_query):
        unique.see(row["marker_id"], row["runtime_id"])
        # 211 is cudaKernelLaunch for cuda >= 9.2
        assert row["cbid"] == 211
        evt = functions_map[row["marker_id"]]
        evt.append_kernel(
            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
        )

    functions.sort(key=lambda evt: evt.time_range.start)
    return functions

class KinetoStepTracker:
    """Provides an abstraction for incrementing the step count globally.

    Previously, we only had one place to mark that a step() has occurred
    in the program, via the pytorch profiler step(). We will now add step hooks
    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446):

    - This could mean programs that already call profiler.step() every
      iteration can end up double-incrementing the step count.
    - If a model uses multiple optimizers, we can also have double or more
      counting of the step.

    We fix this by adding a layer of abstraction before calling step()
    to the kineto library. The idea is to maintain steps per requester in a dict:

    .. code-block::

        {
           "ProfilerStep": 100,  # triggered by profiler step() call
           "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
           "Optimizer2Step": 100,
        }

    To figure out the global step count just take the max of dict values (100).

    If one of the counts increments, the max will go up.

    .. code-block::

        {
           "ProfilerStep": 100,
           "Optimizer1Step": 101,   # Optimizer1 got incremented first say
           "Optimizer2Step": 100,
        }

    Then the global step count is 101.
    We only call the kineto step() function when the global count increments.

    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
    for now. The result could be incorrect increments of the step count.
    """

    _current_step = -1
    _step_dict: Dict[str, int] = defaultdict(int)

    @classmethod
    def init_step_count(cls, requester: str):
        r"""
        Initialize for a given requester.
        """
        cls._step_dict[requester] = cls._current_step

    @classmethod
    def erase_step_count(cls, requester: str) -> bool:
        r"""
        Remove a given requester.
        """
        return cls._step_dict.pop(requester, None) is not None

    @classmethod
    def increment_step(cls, requester: str) -> int:
        """Increments the step count for the requester.

        Additionally, if the max over all step counts has incremented, trigger the
        _kineto_step(). Returns the global step count.
        """
        if requester not in cls._step_dict:
            cls.init_step_count(requester)
        cls._step_dict[requester] += 1

        new_step = max(cls._step_dict.values())
        if new_step > cls._current_step:
            delta = new_step - cls._current_step
            if delta > 1:
                warn(
                    "Profiler step count has increased more than 1 - "
                    f"current_step = {cls._current_step} step dict =  {cls._step_dict}"
                )
            for _ in range(0, delta):
                _kineto_step()
            cls._current_step = new_step
        return cls._current_step

    @classmethod
    def current_step(cls) -> int:
        r"""
        Get the latest step for any requester.
        """
        return cls._current_step
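
# NOTE (editorial, illustrative only): how the per-requester counts above translate
# into a single global step; the requester names are hypothetical.
#
#     >>> KinetoStepTracker.init_step_count("ProfilerStep")
#     >>> KinetoStepTracker.init_step_count("Optimizer1Step")
#     >>> KinetoStepTracker.increment_step("ProfilerStep")    # max goes up -> one _kineto_step()
#     >>> KinetoStepTracker.increment_step("Optimizer1Step")  # max unchanged -> no extra step
#     >>> KinetoStepTracker.current_step()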