import sqlite3
from collections import defaultdict
from typing import Any, Dict, List, Optional
from warnings import warn

import torch
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
    _disable_profiler, _enable_profiler, _kineto_step, _prepare_profiler,
    _ProfilerResult, _supported_activities, DeviceType, kineto_available,
    ProfilerActivity, ProfilerConfig, ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name, _filter_stack_entry, _rewrite_name, EventList, FunctionEvent,
    MEMORY_EVENT_NAME, MemRecordsAcc, OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future
try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped
# global python state - whether profiler is currently enabled
# useful for fast python checks to reduce latency
_is_profiler_enabled: bool = False


def _set_is_profiler_enabled(enable: bool):
    global _is_profiler_enabled
    _is_profiler_enabled = enable


def _run_on_profiler_start():
    _set_is_profiler_enabled(True)


def _run_on_profiler_stop():
    _set_is_profiler_enabled(False)
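# A hedged sketch (not part of the original source) of the fast-path check the
# comment above refers to: a hot code path can consult ``_is_profiler_enabled``
# before paying for a ``record_function`` range. ``expensive_step()`` is a
# hypothetical helper used only for illustration.
#
#     def hot_loop_body(x):
#         if _is_profiler_enabled:
#             with torch.autograd.profiler.record_function("hot_loop_body"):
#                 return expensive_step(x)
#         return expensive_step(x)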
class profile:
    """Context manager that manages autograd profiler state and holds a summary of results.

    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: the profiler is thread local and is automatically propagated into async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.
        use_xpu (bool, optional): Enables timing of XPU events.
            Only supports Kineto profiling while the XPU backend is available.
        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom-most events (in the case
            of nested function calls). But for higher-level functions, the total
            self cpu time might be artificially increased because of the shape
            collection.
        with_flops (bool, optional): If with_flops is set, the profiler will estimate
            the FLOPs (floating point operations) value using the operator's input shape.
            This allows one to estimate the hardware performance. Currently,
            this option only works for the matrix multiplication and 2D convolution operators
            (see the combined example after the output table below).
        profile_memory (bool, optional): track tensor memory allocation/deallocation.
        with_stack (bool, optional): record source information (file and line number) for the ops.
        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.
        use_kineto (bool, optional): experimental, enable profiling with the Kineto profiler.
        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.
        experimental_config (_ExperimentalConfig): A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.

    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead.

    .. warning::
        This context manager should not be called recursively, i.e. no nested
        instances are allowed.

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.
    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        >>>     for _ in range(100):  # any normal python code, really!
        >>>         y = x ** 2
        >>>         y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total   CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        mul                                  32.048ms         32.048ms         200
        pow                                  27.041ms         27.041ms         200
        PowBackward0                         9.727ms          55.483ms         100
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
        torch::autograd::GraphRoot           691.816us        691.816us        100
        -----------------------------------  ---------------  ---------------  ---------------
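
    A hedged sketch (not part of the original docstring) of combining shape recording
    with FLOPs estimation; the matmul sizes are illustrative only::

        >>> # xdoctest: +SKIP
        >>> with torch.autograd.profiler.profile(record_shapes=True, with_flops=True) as prof:
        ...     torch.mm(torch.randn(128, 128), torch.randn(128, 128))
        >>> print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))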
    """

    def __init__(
        self,
        enabled=True,
        *,
        use_cuda=False,
        use_xpu=False,
        use_device=None,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
        use_kineto=False,
        use_cpu=True,
        use_mtia=False,
        experimental_config=None,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.entered = False
        self.use_cuda = use_cuda
        self.use_xpu = use_xpu
        self.use_device: Optional[str] = (
            use_device if use_device != "privateuseone" else None
        )
        self.function_events: Optional[EventList] = None
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.use_cpu = use_cpu
        self.use_mtia = use_mtia
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.kineto_results: Optional[_ProfilerResult] = None

        if not self.use_cpu:
            assert (
                use_kineto
            ), "Device-only events supported only with Kineto (use_kineto=True)"
        if self.use_device == "cuda":
            self.use_device = None
        elif self.use_device == "xpu":
            self.use_device = None
            self.use_cuda = False

        if self.use_device and self.use_device != _get_privateuse1_backend_name():
            warn(f"{self.use_device} does not support profiling.")
            self.use_device = None

        if self.use_cuda and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False

        if self.use_xpu and (
            not hasattr(torch, "xpu") or not torch.xpu.is_available()
        ):  # type: ignore[attr-defined]
            warn("XPU is not available, disabling XPU profiling")
            self.use_xpu = False

        self.kineto_activities = set()
        if self.use_cpu:
            self.kineto_activities.add(ProfilerActivity.CPU)
        if self.use_mtia:
            self.kineto_activities.add(ProfilerActivity.MTIA)

        self.profiler_kind = ProfilerState.KINETO
        if self.use_cuda:
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.CUDA)
        if self.use_xpu:
            assert (
                use_kineto or ProfilerActivity.XPU in _supported_activities()
            ), "Profiling on XPU backend must enable Kineto"
            self.kineto_activities.add(ProfilerActivity.XPU)
        if self.use_device:
            if (
                not use_kineto
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
            ):
                assert (
                    self.use_cpu
                ), "Legacy custom backend profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1

        assert (
            len(self.kineto_activities) > 0
        ), "No activities specified for the profiler"
    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config,
        )

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self._prepare_trace()
        self._start_trace()
        return self

    def _prepare_trace(self):
        self.entered = True
        _prepare_profiler(self.config(), self.kineto_activities)

    def _start_trace(self):
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(self.config(), self.kineto_activities)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_cuda:
            torch.cuda.synchronize()
        if self.use_xpu:
            torch.xpu.synchronize()  # type: ignore[attr-defined]
        self.kineto_results = _disable_profiler()
        _run_on_profiler_stop()
        parsed_results = self._parse_kineto_results(self.kineto_results)
        self.function_events = EventList(
            parsed_results,
            use_cuda=self.use_cuda,
            use_xpu=self.use_xpu,
            use_device=self.use_device,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        self.function_events._build_tree()
        return False
    def __repr__(self):
        if self.function_events is None:
            return "<unfinished torch.autograd.profile>"
        return repr(self.function_events)

    def __str__(self):
        if self.function_events is None:
            return "<unfinished torch.autograd.profile>"
        return str(self.function_events)

    def _check_finish(self):
        if self.function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    def table(
        self,
        sort_by=None,
        row_limit=100,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        header=None,
        top_level_events_only=False,
    ):
        self._check_finish()
        assert self.function_events is not None
        return self.function_events.table(
            sort_by=sort_by,
            row_limit=row_limit,
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            header=header,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__
    def export_chrome_trace(self, path):
        self._check_finish()
        if kineto_available():
            self.kineto_results.save(path)  # type: ignore[union-attr]
        else:
            return self.function_events.export_chrome_trace(path)  # type: ignore[union-attr]

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__

    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self.function_events.export_stacks(path, metric)
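
    # A hedged usage sketch (not part of the original source): the exported stack
    # file is plain text that external flame-graph tools can consume. The file
    # name below is illustrative only.
    #
    #     with torch.autograd.profiler.profile(with_stack=True) as prof:
    #         model(inputs)
    #     prof.export_stacks("profiler_stacks.txt", "self_cpu_time_total")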
    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._check_finish()
        assert self.function_events is not None, "Expected profiling results"
        return self.function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Returns total time spent on CPU.

        The total time is a sum of all self times across all the events.
        """
        assert self.function_events is not None
        return self.function_events.self_cpu_time_total
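
    # A hedged note (an assumption, not from the original source): this is the same
    # sum that feeds the "Self CPU time total" line of ``table()``, expressed in
    # microseconds, e.g.:
    #
    #     with torch.autograd.profiler.profile() as prof:
    #         torch.mm(torch.randn(32, 32), torch.randn(32, 32))
    #     print(prof.self_cpu_time_total)  # e.g. 123.0 (us)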
    def _parse_kineto_results(self, result: _ProfilerResult):
        # result.events() has most of the events - PyTorch op-level and device-level events
        trace_start_us = result.trace_start_us()
        mem_records = [
            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
        ]
        oom_records = [
            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
        ]
        mem_records_acc = MemRecordsAcc(mem_records)

        def _cpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
                else 0
            )

        def _cuda_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP]
                else 0
            )

        def _xpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.XPU]
                else 0
            )

        def _privateuse1_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type() in [DeviceType.PrivateUse1]
                else 0
            )

        # Create and return FunctionEvent list
        function_events = []
        device_corr_map: Dict[int, List[FunctionEvent]] = {}
        device_corr_map_values: List[FunctionEvent] = []
        max_evt_id = 0
        for kineto_event in result.events():
            if _filter_name(kineto_event.name()):
                continue
            rel_start_us = kineto_event.start_us() - trace_start_us
            rel_end_us = rel_start_us + kineto_event.duration_us()
            abs_end_us = kineto_event.start_us() + kineto_event.duration_us()

            cpu_memory_usage = 0
            cuda_memory_usage = 0
            xpu_memory_usage = 0
            privateuse1_memory_usage = 0
            if kineto_event.device_type() == DeviceType.CPU:
                # find the corresponding memory allocation events
                for mem_record in mem_records_acc.in_interval(
                    kineto_event.start_us(), abs_end_us
                ):
                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
                    cuda_memory_usage += _cuda_memory_usage(mem_record[0])
                    xpu_memory_usage += _xpu_memory_usage(mem_record[0])
                    privateuse1_memory_usage += _privateuse1_memory_usage(mem_record[0])
                    mem_record[1] = True

            is_async = kineto_event.is_async() or (
                kineto_event.start_thread_id() != kineto_event.end_thread_id()
            )

            fe = FunctionEvent(
                id=kineto_event.correlation_id(),
                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
                thread=kineto_event.start_thread_id(),
                start_us=rel_start_us,
                end_us=rel_end_us,
                fwd_thread=kineto_event.fwd_thread_id(),
                input_shapes=kineto_event.shapes(),
                concrete_inputs=kineto_event.concrete_inputs(),
                stack=[
                    entry
                    for entry in kineto_event.stack()
                    if _filter_stack_entry(entry)
                ],
                scope=kineto_event.scope(),
                use_device=self.use_device,
                cpu_memory_usage=cpu_memory_usage,
                cuda_memory_usage=cuda_memory_usage,
                xpu_memory_usage=xpu_memory_usage,
                privateuse1_memory_usage=privateuse1_memory_usage,
                is_async=is_async,
                sequence_nr=kineto_event.sequence_nr(),
                device_type=kineto_event.device_type(),
                device_index=kineto_event.device_index(),
                flops=kineto_event.flops(),
            )
            max_evt_id = max(max_evt_id, fe.id)
            if fe.device_type == DeviceType.CPU and not fe.is_async:
                if self.use_device:
                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
                    if privateuse1_time > 0:
                        fe.append_privateuse1_kernel(
                            fe.name, fe.device_index, privateuse1_time
                        )
                        fe.is_legacy = True
                else:
                    # Check if we have CUDA time as a fallback
                    cuda_time = kineto_event.cuda_elapsed_us()
                    if cuda_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
                        fe.is_legacy = True
            function_events.append(fe)
            corr_id = kineto_event.linked_correlation_id()
            if corr_id > 0:
                if corr_id not in device_corr_map:
                    device_corr_map[corr_id] = []
                device_corr_map[corr_id].append(fe)
        # prepare the values list for the corr map to avoid appending kernels twice,
        # since some CPU Kineto events (e.g. CUDA/XPU Runtime Activity) share a
        # correlation ID with CPU events
        device_corr_map_values = [
            fe for fe_list in device_corr_map.values() for fe in fe_list
        ]
        # associate CUDA/XPU kernels and CUDA/XPU runtime (CPU) with CPU events
        for fe in function_events:
            if (
                fe.device_type == DeviceType.CPU
                and not fe.is_async
                and fe.id in device_corr_map
                and fe not in device_corr_map_values
            ):
                for f_evt in device_corr_map[fe.id]:
                    if f_evt.device_type == DeviceType.CUDA:
                        fe.append_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.XPU:
                        fe.append_xpu_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.PrivateUse1:
                        fe.append_privateuse1_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.CPU:
                        # make sure that 'thread' of a CPU Kineto (e.g. CUDA/XPU Runtime) event is
                        # associated with the 'thread' of the corresponding linked PyTorch event to
                        # properly track parents and children
                        f_evt.thread = fe.thread
        def createFunctionEventForMemoryEvents(evt):
            rel_start_us = evt.start_us() - trace_start_us
            fe = FunctionEvent(
                id=max_evt_id,
                name=evt.name(),
                trace_name=None,  # not outputting in the trace
                thread=evt.start_thread_id(),
                start_us=rel_start_us,
                end_us=rel_start_us,  # no duration
                fwd_thread=evt.start_thread_id(),
                input_shapes=[],
                stack=[],
                scope=0,  # RecordScope::FUNCTION
                use_device=self.use_device,
                cpu_memory_usage=_cpu_memory_usage(evt),
                cuda_memory_usage=_cuda_memory_usage(evt),
                xpu_memory_usage=_xpu_memory_usage(evt),
                privateuse1_memory_usage=_privateuse1_memory_usage(evt),
                is_async=False,
                sequence_nr=-1,
                device_type=DeviceType.CPU,
                device_index=0,
            )
            return fe

        # output top-level memory events
        for mem_record in mem_records:
            if not mem_record[1]:
                max_evt_id += 1
                fe = createFunctionEventForMemoryEvents(mem_record[0])
                function_events.append(fe)

        for oom_record in oom_records:
            max_evt_id += 1
            fe = createFunctionEventForMemoryEvents(oom_record)
            function_events.append(fe)

        function_events.sort(
            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
        )
        return function_events
class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.

    It is useful when tracing the code profile.

    Args:
        name (str): Label assigned to the block of code.
        node_id (int): ID of node, for distributed profiling. Unset in
            non-distributed cases.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     with torch.autograd.profiler.record_function("label-z"):  # label the block
        ...         z = y ** 3
        ...     y.backward()
        ...
        >>> # xdoctest: +IGNORE_WANT
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ----------------  ---------------  ---------------
        Name                                 Self CPU total %  CPU time avg     Number of Calls
        -----------------------------------  ----------------  ---------------  ---------------
        pow                                  60.77%            47.470us         3
        mul                                  21.73%            25.465us         2
        PowBackward0                         12.03%            121.891us        1
        torch::autograd::AccumulateGrad      2.70%             6.324us          1
        label-z                              2.13%             12.421us         1
        torch::autograd::GraphRoot           0.64%             1.503us          1
        -----------------------------------  ----------------  ---------------  ---------------
        Self CPU time total: 234.344us
        CUDA time total: 0.000us
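
    Since ``record_function`` extends ``_ContextDecorator``, it can also be used as a
    function decorator; a hedged sketch (not part of the original docstring)::

        >>> # xdoctest: +SKIP
        >>> @torch.autograd.profiler.record_function("my_function")
        ... def my_function(t):
        ...     return t * 2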
    """

    def __init__(self, name: str, args: Optional[str] = None):
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )

    def __enter__(self):
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if not self.run_callbacks_on_exit:
            return
        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            torch.ops.profiler._record_function_exit(record)
    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end-to-end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut (torch._C.Future): future for which to schedule a callback.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have run.
        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future
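
    # A hedged usage sketch (not part of the original source): extend a
    # record_function range until an async result is ready. torch.jit.fork
    # returns a torch.futures.Future and stands in for any async call here.
    #
    #     with torch.autograd.profiler.profile() as prof:
    #         with torch.autograd.profiler.record_function("forked_add") as rf:
    #             fut = torch.jit.fork(torch.add, torch.ones(2), torch.ones(2))
    #             fut = rf._call_end_callbacks_on_future(fut)
    #         torch.jit.wait(fut)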
class emit_itt:
    """Context manager that makes every autograd operation emit an ITT range.

    It is useful when running the program under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager is used to annotate Intel(R) VTune profiling traces. With the help of this
    context manager, you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side. Also note that shape recording may increase the overhead of itt range creation.

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)
    """
    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ITT annotation context manager is not reentrant")
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.ITT, self.record_shapes, False, False, False, False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        _disable_profiler()
        _run_on_profiler_stop()
        return False
class emit_nvtx:
    """Context manager that makes every autograd operation emit an NVTX range.

    It is useful when running the program under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately, there's no way to force nvprof to flush the data it collected
    to disk, so for CUDA profiling one has to use this context manager to annotate
    nvprof traces and wait for the process to exit before inspecting them.
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side. Also note that shape recording may increase the overhead of nvtx range creation.

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)
    **Forward-backward correlation**

    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
    generates.

    During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running
    counter, incremented each time a new backward Function object is created and stashed for backward.
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
    if a backward Function object is created by this forward function,
    the backward object will receive sequence number N.
    During the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that
    the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq``
    numbers in forward, you can track down which forward op created each backward Function.

    Any functions executed during the backward pass are also decorated with ``seq=<N>``. During
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
    ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
    objects with the earlier forward pass.

    **Double-backward**

    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
    if you are setting up for a double-backward), each function's execution during backward
    is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects
    to be executed later during double-backward, just as the original functions in the forward pass did.
    The relationship between backward and double-backward is conceptually the same as the relationship
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
    the Function objects they create still stash those sequence numbers, and during the eventual
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
    numbers, which can be compared to ``seq`` numbers from the backward pass.

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object. You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
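
    A hedged sketch (not part of the original docstring) of a double-backward run under
    :class:`emit_nvtx`, where the nonzero ``seq``/``stashed seq`` annotations described
    above appear (``model`` and ``x`` are assumed to be defined, with ``x`` requiring grad)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> with torch.cuda.profiler.profile():
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         out = model(x).sum()
        ...         (grad,) = torch.autograd.grad(out, x, create_graph=True)
        ...         grad.sum().backward()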
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.NVTX, self.record_shapes, False, False, False, False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        torch.cuda.synchronize()
        _disable_profiler()
        _run_on_profiler_stop()
        return False
def load_nvprof(path):
    """Open an nvprof trace file and parse autograd annotations.

    Args:
        path (str): path to nvprof trace
    """
    return EventList(parse_nvprof_trace(path))
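
# A hedged usage sketch (not part of the original source); the trace file name
# is illustrative and must come from an earlier `nvprof -o` run:
#
#     events = torch.autograd.profiler.load_nvprof("trace_name.prof")
#     print(events.table(sort_by="self_cpu_time_total"))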
class EnforceUnique:
    """Raises an error if a key is seen more than once."""

    def __init__(self):
        self.seen = set()

    def see(self, *key):
        """Observe a key and raise an error if it is seen multiple times."""
        if key in self.seen:
            raise RuntimeError("duplicate key: " + str(key))
        self.seen.add(key)
def parse_nvprof_trace(path):
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row

    # Parse strings table
    strings = {}
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
        strings[r["id"]] = torch._C._demangle(r["value"])

    # First, find all functions and create FunctionEvents for them
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    functions = []
    functions_map = {}
    unique = EnforceUnique()
    for row in conn.execute(marker_query):
        unique.see(row["marker_id"])
        evt = FunctionEvent(
            id=row["marker_id"],
            node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
            # that pytorch doesn't crash when creating a FunctionEvent() object
            name=strings[row["name"]],
            start_us=row["start_time"],
            end_us=row["end_time"],
            thread=0,
        )  # TODO: find in sqlite database
        functions.append(evt)
        functions_map[evt.id] = evt
    # Now, correlate all kernels with FunctionEvents
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
            ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
            ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
            ON kernel.correlationId = runtime.correlationId
    """
    unique = EnforceUnique()
    for row in conn.execute(kernel_query):
        unique.see(row["marker_id"], row["runtime_id"])
        # 211 is cudaKernelLaunch for cuda >= 9.2
        assert row["cbid"] == 211
        evt = functions_map[row["marker_id"]]
        evt.append_kernel(
            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
        )

    functions.sort(key=lambda evt: evt.time_range.start)
    return functions
class KinetoStepTracker:
    """Provides an abstraction for incrementing the step count globally.

    Previously, we only had one place to mark that a step() has occurred
    in the program, via the pytorch profiler step(). We will now add step hooks
    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446).

    - This could mean programs that already call profiler.step() every
      iteration can end up double-incrementing the step count.
    - If a model uses multiple optimizers, we can also have double or more
      counting of the step.

    We fix this by adding a layer of abstraction before calling step()
    to the kineto library. The idea is to maintain steps per requester in a dict::

        {
           "ProfilerStep": 100,  # triggered by profiler step() call
           "Optimizer1Step": 100,  # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
           "Optimizer2Step": 100,
        }

    To figure out the global step count just take the max of dict values (100).

    If one of the counts increments, the max will go up::

        {
           "ProfilerStep": 100,
           "Optimizer1Step": 101,  # Optimizer1 got incremented first, say
           "Optimizer2Step": 100,
        }

    Then the global step count is 101.
    We only call the kineto step() function when the global count increments.

    NOTE: Please do not use the KinetoStepTracker in modules beside the Optimizer
    for now. The result could be incorrect increments of the step count.
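
    A hedged usage sketch (not part of the original docstring) of the per-requester
    bookkeeping described above: two requesters each advance their own count, and the
    global step only moves when the maximum over all requesters increases::

        >>> # xdoctest: +SKIP
        >>> from torch.autograd.profiler import KinetoStepTracker
        >>> KinetoStepTracker.init_step_count("Optimizer1Step")
        >>> KinetoStepTracker.init_step_count("Optimizer2Step")
        >>> KinetoStepTracker.increment_step("Optimizer1Step")  # global step advances
        >>> KinetoStepTracker.increment_step("Optimizer2Step")  # global step unchanged
        >>> KinetoStepTracker.current_step()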
    """

    _current_step = 0
    _step_dict: Dict[str, int] = defaultdict(int)

    @classmethod
    def init_step_count(cls, requester: str):
        """Initialize for a given requester."""
        cls._step_dict[requester] = cls._current_step

    @classmethod
    def erase_step_count(cls, requester: str) -> bool:
        """Remove a given requester."""
        return cls._step_dict.pop(requester, None) is not None

    @classmethod
    def increment_step(cls, requester: str) -> int:
        """Increments the step count for the requester.

        Additionally, if the max over all step counts has incremented, trigger
        ``_kineto_step()``. Returns the global step count.
        """
        if requester not in cls._step_dict:
            cls.init_step_count(requester)
        cls._step_dict[requester] += 1

        new_step = max(cls._step_dict.values())
        if new_step > cls._current_step:
            delta = new_step - cls._current_step
            if delta > 1:
                warn(
                    "Profiler step count has increased more than 1 - "
                    f"current_step = {cls._current_step} step dict = {cls._step_dict}"
                )
            for _ in range(0, delta):
                _kineto_step()
            cls._current_step = new_step
        return cls._current_step

    @classmethod
    def current_step(cls) -> int:
        """Get the latest step for any requester."""
        return cls._current_step