# mypy: allow-untyped-defs
from collections import defaultdict
from dataclasses import dataclass
from time import perf_counter_ns
from typing import Any, Dict, Iterable, List, Optional
from warnings import warn

import torch

from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig
from torch.autograd import (
    _disable_profiler,
    _enable_profiler,
    _kineto_step,
    _prepare_profiler,
    _ProfilerResult,
    _supported_activities,
    _toggle_collection_dynamic,
    DeviceType,
    kineto_available,
    ProfilerActivity,
    ProfilerConfig,
    ProfilerState,
)
from torch.autograd.profiler_util import (
    _filter_name,
    _filter_stack_entry,
    _rewrite_name,
    EventList,
    FunctionEvent,
    MEMORY_EVENT_NAME,
    MemRecordsAcc,
    OUT_OF_MEMORY_EVENT_NAME,
)
from torch.futures import Future

try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped
# global python state - whether profiler is currently enabled
# useful for fast python checks to reduce latency
_is_profiler_enabled: bool = False


def _set_is_profiler_enabled(enable: bool):
    global _is_profiler_enabled
    _is_profiler_enabled = enable


def _run_on_profiler_start():
    _set_is_profiler_enabled(True)


def _run_on_profiler_stop():
    _set_is_profiler_enabled(False)
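
# A minimal sketch of the intended fast path (hypothetical helper, not part of
# the public API): consult the Python-level flag before paying for any profiler
# bookkeeping, so the common profiler-off case stays cheap.
def _example_fast_path(x):
    if _is_profiler_enabled:
        # Only label the region when a profiler is actually recording.
        with record_function("_example_fast_path"):
            return x * 2
    return x * 2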

@dataclass
class _ProfilerStats:
    "Profiler timing and stats used by developers to catch issues/regressions"

    profiling_window_duration_sec: float = 0
    number_of_events: int = 0
    profiler_prepare_call_duration_us: int = 0
    profiler_enable_call_duration_us: int = 0
    profiler_disable_call_duration_us: int = 0
    parse_kineto_call_duration_us: int = 0
    function_events_build_tree_call_duration_us: int = 0
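
# Illustrative sketch (hypothetical helper, not part of the original file):
# after a profiling run, these developer stats can be read off the internal
# ``_stats`` attribute of a ``profile`` instance.
def _example_dump_profiler_stats(prof) -> None:
    s = prof._stats
    print(f"{s.number_of_events} events over {s.profiling_window_duration_sec:.6f}s")
    print(
        f"prepare/enable/disable took "
        f"{s.profiler_prepare_call_duration_us}/"
        f"{s.profiler_enable_call_duration_us}/"
        f"{s.profiler_disable_call_duration_us} us"
    )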

class profile:
    """Context manager that manages autograd profiler state and holds a summary of results.

    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: profiler is thread local and is automatically propagated into the async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.

        use_cuda (bool, optional): Enables timing of CUDA events as well
            using the cudaEvent API. (will be deprecated)

        use_device (str, optional): Enables timing of device events.
            Adds approximately 4us of overhead to each tensor operation when using CUDA.
            The valid device options are 'cuda', 'xpu', 'mtia' and 'privateuseone'.

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom-most events (in a case
            of nested function calls). But for higher-level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        with_flops (bool, optional): If with_flops is set, the profiler will estimate
            the FLOPs (floating point operations) value using the operator's input shape.
            This allows one to estimate the hardware performance. Currently,
            this option only works for the matrix multiplication and 2D convolution operators.

        profile_memory (bool, optional): track tensor memory allocation/deallocation.

        with_stack (bool, optional): record source information (file and line number) for the ops.

        with_modules (bool): record module hierarchy (including function names)
            corresponding to the callstack of the op. e.g. If module A's forward calls
            module B's forward, which contains an aten::add op,
            then aten::add's module hierarchy is A.B.
            Note that this support exists, at the moment, only for TorchScript models
            and not eager mode models.

        use_kineto (bool, optional): experimental, enable profiling with Kineto profiler.

        use_cpu (bool, optional): profile CPU events; setting to ``False`` requires
            ``use_kineto=True`` and can be used to lower the overhead for GPU-only profiling.

        experimental_config (_ExperimentalConfig) : A set of experimental options
            used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.

        acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles.

    .. warning::
        Enabling memory profiling or source attribution incurs additional profiler
        overhead.

    .. warning::
        This context manager should not be called recursively, i.e. no nested
        instances are allowed.

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_device = None`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        >>>     for _ in range(100):  # any normal python code, really!
        >>>         y = x ** 2
        >>>         y.backward()
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  ---------------  ---------------  ---------------
        Name                                 Self CPU total   CPU time avg     Number of Calls
        -----------------------------------  ---------------  ---------------  ---------------
        mul                                  32.048ms         32.048ms         200
        pow                                  27.041ms         27.041ms         200
        PowBackward0                         9.727ms          55.483ms         100
        torch::autograd::AccumulateGrad      9.148ms          9.148ms          100
        torch::autograd::GraphRoot           691.816us        691.816us        100
        -----------------------------------  ---------------  ---------------  ---------------
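
    A sketch of shape-based grouping (illustrative; assumes ``record_shapes=True``):

        >>> # xdoctest: +SKIP
        >>> with torch.autograd.profiler.profile(record_shapes=True) as prof:
        >>>     y = x.mm(x)
        >>> print(prof.key_averages(group_by_input_shape=True).table())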
    """

    def __init__(
        self,
        enabled=True,
        use_cuda=False,  # Deprecated
        use_device=None,
        record_shapes=False,
        with_flops=False,
        profile_memory=False,
        with_stack=False,
        with_modules=False,
        use_kineto=False,
        use_cpu=True,
        experimental_config=None,
        acc_events=False,
    ):
        self.enabled: bool = enabled
        if not self.enabled:
            return
        self.use_cuda = use_cuda
        if self.use_cuda:
            warn(
                "The attribute `use_cuda` will be deprecated soon, "
                "please use ``use_device = 'cuda'`` instead.",
            )
            self.use_device: Optional[str] = "cuda"
        else:
            self.use_device = use_device
        # TODO Consider changing _function_events into data structure with size cap
        self._function_events: Optional[EventList] = None
        self._old_function_events: Optional[EventList] = None
        # Function event processing is done lazily
        self._needs_processing = False
        self.entered = False
        self.record_shapes = record_shapes
        self.with_flops = with_flops
        self.record_shapes |= self.with_flops
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_modules = with_modules
        self.use_cpu = use_cpu
        self.acc_events = acc_events
        if experimental_config is None:
            experimental_config = _ExperimentalConfig()
        self.experimental_config = experimental_config
        self.kineto_results: Optional[_ProfilerResult] = None
        self.profiling_start_time_ns = 0
        self.profiling_end_time_ns = 0
        self._stats = _ProfilerStats()
        if not self.use_cpu:
            assert (
                use_kineto
            ), "Device-only events supported only with Kineto (use_kineto=True)"

        if self.use_device is not None:
            VALID_DEVICE_OPTIONS = ["cuda", "xpu", "mtia"]
            if _get_privateuse1_backend_name() != "privateuseone":
                VALID_DEVICE_OPTIONS.append(_get_privateuse1_backend_name())
            if self.use_device not in VALID_DEVICE_OPTIONS:
                warn(f"The {self.use_device} is not a valid device option.")
                self.use_device = None

        if self.use_device == "cuda" and not torch.cuda.is_available():
            warn("CUDA is not available, disabling CUDA profiling")
            self.use_cuda = False
            self.use_device = None

        if self.use_device == "xpu" and not torch.xpu.is_available():
            warn("XPU is not available, disabling XPU profiling")
            self.use_device = None

        self.kineto_activities = set()
        if self.use_cpu:
            self.kineto_activities.add(ProfilerActivity.CPU)

        self.profiler_kind = ProfilerState.KINETO
        if self.use_device == "cuda":
            if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
                assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.CUDA)
        elif self.use_device == "xpu":
            assert (
                use_kineto and ProfilerActivity.XPU in _supported_activities()
            ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
            self.kineto_activities.add(ProfilerActivity.XPU)
        elif self.use_device == "mtia":
            assert (
                use_kineto and ProfilerActivity.MTIA in _supported_activities()
            ), "Legacy MTIA profiling is not supported. Requires use_kineto=True on MTIA devices."
            self.kineto_activities.add(ProfilerActivity.MTIA)
        elif self.use_device is not None and self.use_device != "privateuseone":
            if (
                not use_kineto
                or ProfilerActivity.PrivateUse1 not in _supported_activities()
            ):
                assert (
                    self.use_cpu
                ), "Legacy custom backend profiling requires use_cpu=True"
                self.profiler_kind = ProfilerState.KINETO_PRIVATEUSE1_FALLBACK
            else:
                self.kineto_activities.add(ProfilerActivity.PrivateUse1)

        assert (
            len(self.kineto_activities) > 0
        ), "No activities specified for the profiler"

    def config(self):
        return ProfilerConfig(
            self.profiler_kind,
            self.record_shapes,
            self.profile_memory,
            self.with_stack,
            self.with_flops,
            self.with_modules,
            self.experimental_config,
        )
    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("Profiler context manager is not reentrant")
        self.entered = True
        self._prepare_trace()
        self._start_trace()
        return self

    def _prepare_trace(self):
        t0 = perf_counter_ns()
        _prepare_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_prepare_call_duration_us = int((t1 - t0) / 1000)

    def _start_trace(self):
        _run_on_profiler_start()
        t0 = perf_counter_ns()
        _enable_profiler(self.config(), self.kineto_activities)
        t1 = perf_counter_ns()
        self._stats.profiler_enable_call_duration_us = int((t1 - t0) / 1000)
        self.profiling_start_time_ns = t1
    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        if self.use_device and hasattr(torch, self.use_device):
            device_module = getattr(torch, self.use_device)
            if hasattr(device_module, "synchronize"):
                device_module.synchronize()

        if self._function_events and self.acc_events:
            self._old_function_events = self._function_events
        self._function_events = None
        self._needs_processing = True

        t0 = perf_counter_ns()
        self.kineto_results = _disable_profiler()
        t1 = perf_counter_ns()
        self._stats.profiler_disable_call_duration_us = int((t1 - t0) / 1000)
        self.profiling_end_time_ns = t0

        _run_on_profiler_stop()

        self._stats.profiling_window_duration_sec = (
            (self.profiling_end_time_ns - self.profiling_start_time_ns) * 1.0 / 1e9
        )

        # If we plan to accumulate events we should post process the function events
        # right away to retain the state across multiple start/stop calls
        if self.acc_events:
            self._ensure_function_events()
        return False
    def __repr__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return repr(self._function_events)

    def __str__(self):
        if self._needs_processing:
            self._ensure_function_events()
        if self._function_events is None:
            return "<unfinished torch.autograd.profile>"
        return str(self._function_events)
    def _ensure_function_events(self):
        """Process function events lazily if required"""
        if self._function_events is not None:
            return
        self._needs_processing = False

        t0 = perf_counter_ns()
        parsed_results = []
        if self.kineto_results:
            parsed_results = self._parse_kineto_results(self.kineto_results)
        t1 = perf_counter_ns()
        self._stats.parse_kineto_call_duration_us = int((t1 - t0) / 1000)

        self._function_events = EventList(
            parsed_results,
            use_device=self.use_device,
            profile_memory=self.profile_memory,
            with_flops=self.with_flops,
        )
        t0 = perf_counter_ns()
        self._function_events._build_tree()
        t1 = perf_counter_ns()
        self._stats.function_events_build_tree_call_duration_us = int((t1 - t0) / 1000)
        self._stats.number_of_events = len(self._function_events)

        if self._old_function_events and self.acc_events:
            for evt in self._old_function_events:
                self._function_events.append(evt)
            self._old_function_events = None

        if self._function_events is None:
            raise RuntimeError("Profiler didn't finish running")

    @property
    def function_events(self):
        if self._function_events is None or self._needs_processing:
            self._ensure_function_events()
        return self._function_events
    def table(
        self,
        max_src_column_width=75,
        max_name_column_width=55,
        max_shapes_column_width=80,
        top_level_events_only=False,
    ):
        self._ensure_function_events()
        assert self._function_events is not None
        return self._function_events.table(
            max_src_column_width=max_src_column_width,
            max_name_column_width=max_name_column_width,
            max_shapes_column_width=max_shapes_column_width,
            top_level_events_only=top_level_events_only,
        )

    table.__doc__ = EventList.table.__doc__
    def export_chrome_trace(self, path):
        """
        Exports the collected trace in Chrome JSON format. If kineto is enabled, only
        the last cycle in the schedule is exported.
        """
        if kineto_available():
            self.kineto_results.save(path)  # type: ignore[union-attr]
        else:
            self._ensure_function_events()
            return self._function_events.export_chrome_trace(path)  # type: ignore[union-attr]

    export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
    def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        assert self.with_stack, "export_stacks() requires with_stack=True"
        return self._function_events.export_stacks(path, metric)
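
    # Illustrative usage sketch (hypothetical paths; assumes the profile was
    # created with ``with_stack=True``): export self-CPU-time stacks in a
    # format consumable by flamegraph tooling, e.g.
    #
    #   with profile(with_stack=True) as prof:
    #       model(inp)
    #   prof.export_stacks("/tmp/profiler_stacks.txt", "self_cpu_time_total")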
    def toggle_collection_dynamic(
        self, enabled: bool, activities: Iterable[ProfilerActivity]
    ):
        """
        Toggles the collection of activities for the current profiler instance.
        """
        return _toggle_collection_dynamic(enabled, set(activities))
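
    # Illustrative sketch (assumes Kineto profiling with CPU activity): pause
    # and resume collection inside an active profiling window, e.g.
    #
    #   with profile(use_kineto=True) as prof:
    #       model(inp)  # collected
    #       prof.toggle_collection_dynamic(False, [ProfilerActivity.CPU])
    #       model(inp)  # not collected
    #       prof.toggle_collection_dynamic(True, [ProfilerActivity.CPU])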
    def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        return self._function_events.key_averages(
            group_by_input_shape, group_by_stack_n
        )

    key_averages.__doc__ = EventList.key_averages.__doc__

    def total_average(self):
        self._ensure_function_events()
        assert self._function_events is not None, "Expected profiling results"
        return self._function_events.total_average()

    total_average.__doc__ = EventList.total_average.__doc__

    @property
    def self_cpu_time_total(self):
        """Returns total time spent on CPU.

        The total time is a sum of all self times across all the events.
        """
        self._ensure_function_events()
        assert self._function_events is not None
        return self._function_events.self_cpu_time_total
    def _parse_kineto_results(self, result: _ProfilerResult):
        # result.events() has most of the events - PyTorch op-level and device-level events

        trace_start_ns = result.trace_start_ns()
        mem_records = [
            [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME
        ]
        oom_records = [
            evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME
        ]
        mem_records_acc = MemRecordsAcc(mem_records)

        def _cpu_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CPU, DeviceType.MKLDNN, DeviceType.IDEEP]
                else 0
            )

        def _device_memory_usage(mem_record):
            return (
                mem_record.nbytes()
                if mem_record.device_type()
                in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
                else 0
            )

        # Create and return FunctionEvent list, which contains all function events
        # Here 2 lists of function events are created:
        # all_function_events contains all events associated with each kineto event from result
        all_function_events = []
        # frontend_function_events contains the events in aten or torch frontend level,
        # whose correlation id is 0
        frontend_function_events = []
        device_corr_map: Dict[int, List[FunctionEvent]] = {}
        max_evt_id = 0
        for kineto_event in result.events():
            if _filter_name(kineto_event.name()):
                continue
            rel_start_ns = kineto_event.start_ns() - trace_start_ns
            rel_end_ns = kineto_event.end_ns() - trace_start_ns
            abs_end_ns = kineto_event.end_ns()

            cpu_memory_usage = 0
            device_memory_usage = 0
            if kineto_event.device_type() == DeviceType.CPU:
                # find the corresponding memory allocation events
                for mem_record in mem_records_acc.in_interval(
                    kineto_event.start_ns() / 1000, abs_end_ns / 1000
                ):
                    cpu_memory_usage += _cpu_memory_usage(mem_record[0])
                    device_memory_usage += _device_memory_usage(mem_record[0])
                    mem_record[1] = True

            is_async = kineto_event.is_async() or (
                kineto_event.start_thread_id() != kineto_event.end_thread_id()
            )

            fe = FunctionEvent(
                id=kineto_event.correlation_id(),
                name=_rewrite_name(name=kineto_event.name(), with_wildcard=True),
                trace_name=_rewrite_name(name=kineto_event.name(), with_wildcard=False),
                thread=kineto_event.start_thread_id(),
                start_us=rel_start_ns / 1000,
                end_us=rel_end_ns / 1000,
                fwd_thread=kineto_event.fwd_thread_id(),
                input_shapes=kineto_event.shapes(),
                concrete_inputs=kineto_event.concrete_inputs(),
                kwinputs=kineto_event.kwinputs(),
                stack=[
                    entry
                    for entry in kineto_event.stack()
                    if _filter_stack_entry(entry)
                ],
                scope=kineto_event.scope(),
                use_device=self.use_device,
                cpu_memory_usage=cpu_memory_usage,
                device_memory_usage=device_memory_usage,
                is_async=is_async,
                sequence_nr=kineto_event.sequence_nr(),
                device_type=kineto_event.device_type(),
                device_index=kineto_event.device_index(),
                device_resource_id=kineto_event.device_resource_id(),
                flops=kineto_event.flops(),
                is_user_annotation=kineto_event.is_user_annotation(),
            )
            max_evt_id = max(max_evt_id, fe.id)
            if fe.device_type == DeviceType.CPU and not fe.is_async:
                if self.use_device == "privateuseone":
                    privateuse1_time = kineto_event.privateuse1_elapsed_us()
                    if privateuse1_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
                        fe.is_legacy = True
                elif self.use_device == "cuda":
                    # Check if we have CUDA time as a fallback
                    cuda_time = kineto_event.cuda_elapsed_us()
                    if cuda_time > 0:
                        fe.append_kernel(fe.name, fe.device_index, cuda_time)
                        fe.is_legacy = True
            all_function_events.append(fe)
            corr_id = kineto_event.linked_correlation_id()
            if corr_id > 0:
                if corr_id not in device_corr_map:
                    device_corr_map[corr_id] = []
                device_corr_map[corr_id].append(fe)
            elif corr_id == 0:
                frontend_function_events.append(fe)
            else:
                raise RuntimeError(
                    f"Got negative correlation id {corr_id} in profiler post processing"
                )
        # associate device kernels and device runtime (CPU) with CPU events
        for fe in frontend_function_events:
            if (
                fe.device_type == DeviceType.CPU
                and not fe.is_async
                and fe.id in device_corr_map
            ):
                for f_evt in device_corr_map[fe.id]:
                    if (
                        f_evt.device_type == DeviceType.CUDA
                        or f_evt.device_type == DeviceType.PrivateUse1
                    ):
                        fe.append_kernel(
                            f_evt.name,
                            f_evt.device_index,
                            f_evt.time_range.end - f_evt.time_range.start,
                        )
                    elif f_evt.device_type == DeviceType.CPU:
                        # make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated
                        # with the 'thread' of the corresponding linked PyTorch event to properly track
                        # parents and children
                        f_evt.thread = fe.thread
        def createFunctionEventForMemoryEvents(evt):
            rel_start_ns = evt.start_ns() - trace_start_ns
            fe = FunctionEvent(
                id=max_evt_id,
                name=evt.name(),
                trace_name=None,  # not outputting in the trace
                thread=evt.start_thread_id(),
                start_us=rel_start_ns / 1000,
                end_us=rel_start_ns / 1000,  # no duration
                fwd_thread=evt.start_thread_id(),
                input_shapes=[],
                stack=[],
                scope=0,  # RecordScope::FUNCTION
                use_device=self.use_device,
                cpu_memory_usage=_cpu_memory_usage(evt),
                device_memory_usage=_device_memory_usage(evt),
                is_async=False,
                sequence_nr=-1,
                device_type=DeviceType.CPU,
                device_index=0,
            )
            return fe

        # output top-level memory events
        for mem_record in mem_records:
            if not mem_record[1]:
                max_evt_id += 1
                fe = createFunctionEventForMemoryEvents(mem_record[0])
                all_function_events.append(fe)

        for oom_record in oom_records:
            max_evt_id += 1
            fe = createFunctionEventForMemoryEvents(oom_record)
            all_function_events.append(fe)

        all_function_events.sort(
            key=lambda evt: [evt.time_range.start, -evt.time_range.end]
        )
        return all_function_events

class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.

    Label will only appear if CPU activity tracing is enabled.

    It is useful when tracing the code profile.

    Args:
        name (str): Label assigned to the block of code.
        node_id (int): ID of node, for distributed profiling. Unset in
            non-distributed cases.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> x = torch.randn((1, 1), requires_grad=True)
        >>> with torch.autograd.profiler.profile() as prof:
        ...     y = x ** 2
        ...     with torch.autograd.profiler.record_function("label-z"):  # label the block
        ...         z = y ** 3
        ...     y.backward()
        >>> # xdoctest: +IGNORE_WANT
        >>> # NOTE: some columns were removed for brevity
        >>> print(prof.key_averages().table(sort_by="self_cpu_time_total"))
        -----------------------------------  -----------------  ---------------  ---------------
        Name                                 Self CPU total %   CPU time avg     Number of Calls
        -----------------------------------  -----------------  ---------------  ---------------
        pow                                  60.77%             47.470us         3
        mul                                  21.73%             25.465us         2
        PowBackward0                         12.03%             121.891us        1
        torch::autograd::AccumulateGrad      2.70%              6.324us          1
        label-z                              2.13%              12.421us         1
        torch::autograd::GraphRoot           0.64%              1.503us          1
        -----------------------------------  -----------------  ---------------  ---------------
        Self CPU time total: 234.344us
        CUDA time total: 0.000us
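
    Since ``record_function`` derives from ``_ContextDecorator``, it can also
    be used as a function decorator (illustrative sketch):

        >>> # xdoctest: +SKIP
        >>> @torch.autograd.profiler.record_function("my_labeled_fn")
        ... def my_labeled_fn(t):
        ...     return t * 2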
    """

    def __init__(self, name: str, args: Optional[str] = None):
        self.name: str = name
        self.args: Optional[str] = args
        # Whether or not we should run record function's end callbacks when exiting.
        self.run_callbacks_on_exit: bool = True
        # TODO: TorchScript ignores standard type annotation here
        # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
        self.record = torch.jit.annotate(
            Optional["torch.classes.profiler._RecordFunction"], None
        )

    def __enter__(self):
        self.record = torch.ops.profiler._record_function_enter_new(
            self.name, self.args
        )
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if not self.run_callbacks_on_exit:
            return

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                torch.ops.profiler._record_function_exit._RecordFunction(record)
        else:
            torch.ops.profiler._record_function_exit(record)
    def _call_end_callbacks_on_future(self, fut: Future[Any]) -> Future[Any]:
        """Use for profiling async calls that return a future.

        Calling this function will extend recording beyond this scope, until the future is
        satisfied. It is useful for profiling the end to end time of asynchronous calls.
        This function should only be called once to attach the callback onto the future, and
        will throw if called multiple times.

        Args:
            fut (torch._C.Future): future for which to schedule a callback.

        Returns:
            A future that completes with the value of the passed in future when
            the profiling callbacks have run.

        """
        # Throw if we have already attached a callback onto the future.
        if not self.run_callbacks_on_exit:
            raise RuntimeError("_call_end_callbacks_on_future can only be called once.")

        # We are scheduling to run this RecordFunction's end callbacks when the
        # passed in future completes, so don't run end callbacks on exit.
        self.run_callbacks_on_exit = False

        # Local variable is needed by TorchScript to refine Optional[T] to T
        record = self.record
        assert record is not None

        # TODO: Too slow with __torch_function__ handling enabled
        # See https://github.com/pytorch/pytorch/issues/76410
        if not torch.jit.is_scripting():
            with torch._C.DisableTorchFunctionSubclass():
                profiled_future = (
                    torch.ops.profiler._call_end_callbacks_on_jit_fut._RecordFunction(
                        record, fut
                    )
                )
        else:
            profiled_future = torch.ops.profiler._call_end_callbacks_on_jit_fut(
                record, fut
            )
        return profiled_future
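
    # Illustrative sketch (assumes torch.distributed.rpc is initialized and a
    # peer named "worker1" exists): extend a record_function range until an
    # asynchronous RPC completes, e.g.
    #
    #   with record_function("rpc_block") as rf:
    #       fut = torch.distributed.rpc.rpc_async("worker1", torch.add, (t, t))
    #       fut = rf._call_end_callbacks_on_future(fut)
    #   result = fut.wait()  # the range ends when the future is satisfied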

class emit_itt:
    """Context manager that makes every autograd operation emit an ITT range.

    It is useful when running the program under Intel(R) VTune Profiler::

        vtune <--vtune-flags> <regular command here>

    The Instrumentation and Tracing Technology (ITT) API enables your application to generate and
    control the collection of trace data during its execution across different Intel tools.
    This context manager is used to annotate Intel(R) VTune Profiling traces. With the help of this
    context manager, you will be able to see labeled ranges in the Intel(R) VTune Profiler GUI.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the itt range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side. Also note that shape recording may increase the overhead of itt range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("Undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.autograd.profiler.emit_itt():
        ...     model(x)

    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("ITT annotation context manager is not reentrant")
        self.entered = True
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.ITT,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        _disable_profiler()
        _run_on_profiler_stop()
        return False

class emit_nvtx:
    """Context manager that makes every autograd operation emit an NVTX range.

    It is useful when running the program under nvprof::

        nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

    Unfortunately, there's no way to force nvprof to flush the data it collected
    to disk, so for CUDA profiling one has to use this context manager to annotate
    nvprof traces and wait for the process to exit before inspecting them.
    Then, either NVIDIA Visual Profiler (nvvp) can be used to visualize the timeline, or
    :func:`torch.autograd.profiler.load_nvprof` can load the results for inspection,
    e.g. in a Python REPL.

    .. warning::
        This context manager should not be called recursively, i.e. at most one
        instance should be enabled at any given time.

    Args:
        enabled (bool, optional): Setting ``enabled=False`` makes this context manager a no-op.
            Default: ``True``.
        record_shapes (bool, optional): If ``record_shapes=True``, the nvtx range wrapping
            each autograd op will append information about the sizes of Tensor arguments received
            by that op, in the following format:
            ``[[arg0.size(0), arg0.size(1), ...], [arg1.size(0), arg1.size(1), ...], ...]``
            Non-tensor arguments will be represented by ``[]``.
            Arguments will be listed in the order they are received by the backend op.
            Please note that this order may not match the order in which those arguments were passed
            on the Python side. Also note that shape recording may increase the overhead of nvtx range creation.
            Default: ``False``

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD_PROFILER)
        >>> with torch.cuda.profiler.profile():
        ...     model(x)  # Warmup CUDA memory allocator and profiler
        ...     with torch.autograd.profiler.emit_nvtx():
        ...         model(x)

    **Forward-backward correlation**

    When viewing a profile created using :class:`emit_nvtx` in the Nvidia Visual Profiler,
    correlating each backward-pass op with the corresponding forward-pass op can be difficult.
    To ease this task, :class:`emit_nvtx` appends sequence number information to the ranges it
    generates.

    During the forward pass, each function range is decorated with ``seq=<N>``. ``seq`` is a running
    counter, incremented each time a new backward Function object is created and stashed for backward.
    Thus, the ``seq=<N>`` annotation associated with each forward function range tells you that
    if a backward Function object is created by this forward function,
    the backward object will receive sequence number N.
    During the backward pass, the top-level range wrapping each C++ backward Function's
    ``apply()`` call is decorated with ``stashed seq=<M>``. ``M`` is the sequence number that
    the backward object was created with. By comparing ``stashed seq`` numbers in backward with ``seq``
    numbers in forward, you can track down which forward op created each backward Function.

    Any functions executed during the backward pass are also decorated with ``seq=<N>``. During
    default backward (with ``create_graph=False``) this information is irrelevant, and in fact,
    ``N`` may simply be 0 for all such functions. Only the top-level ranges associated with
    backward Function objects' ``apply()`` methods are useful, as a way to correlate these Function
    objects with the earlier forward pass.

    **Double-backward**

    If, on the other hand, a backward pass with ``create_graph=True`` is underway (in other words,
    if you are setting up for a double-backward), each function's execution during backward
    is given a nonzero, useful ``seq=<N>``. Those functions may themselves create Function objects
    to be executed later during double-backward, just as the original functions in the forward pass did.
    The relationship between backward and double-backward is conceptually the same as the relationship
    between forward and backward: The functions still emit current-sequence-number-tagged ranges,
    the Function objects they create still stash those sequence numbers, and during the eventual
    double-backward, the Function objects' ``apply()`` ranges are still tagged with ``stashed seq``
    numbers, which can be compared to ``seq`` numbers from the backward pass.
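
    For example (illustrative): a forward range tagged ``aten::mul, seq = 12``
    pairs with a backward-pass range tagged ``MulBackward0.apply, stashed seq = 12``.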

    .. warning::
        The sequence number is thread-local, and some forward functions don't create an associated
        backward Function object (instead delegating that to sub-functions further down the call chain).
        For these reasons, the correspondence of stashed sequence numbers in
        backward Function ``apply()`` ranges with ``seq`` numbers in forward-pass ranges is
        not guaranteed to be 1 to 1. The sequence numbers alone may not be enough to fully
        disambiguate which forward function created which
        backward Function object. You may need to make a judgment based on analytic knowledge of what
        the expected correspondence should be.
    """

    def __init__(self, enabled=True, record_shapes=False):
        self.enabled = enabled
        self.entered = False
        self.record_shapes = record_shapes

    def __enter__(self):
        if not self.enabled:
            return
        if self.entered:
            raise RuntimeError("NVTX annotation context manager is not reentrant")
        self.entered = True
        torch.cuda.synchronize()
        _run_on_profiler_start()
        _enable_profiler(
            ProfilerConfig(
                ProfilerState.NVTX,
                self.record_shapes,
                False,
                False,
                False,
                False,
                _ExperimentalConfig(),
            ),
            set(),
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.enabled:
            return
        torch.cuda.synchronize()
        _disable_profiler()
        _run_on_profiler_stop()
        return False

def load_nvprof(path):
    """Open an nvprof trace file and parse autograd annotations.

    Args:
        path (str): path to nvprof trace
    """
    return EventList(parse_nvprof_trace(path))
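
# Illustrative usage sketch (hypothetical file name): record a trace under
# nvprof with emit_nvtx, then load it for inspection, e.g.
#
#   # shell: nvprof --profile-from-start off -o trace_name.prof -- python script.py
#   events = torch.autograd.profiler.load_nvprof("trace_name.prof")
#   print(events.table(sort_by="self_cpu_time_total"))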

class EnforceUnique:
    """Raises an error if a key is seen more than once."""

    def __init__(self):
        self.seen = set()

    def see(self, *key):
        """
        Observe a key and raise an error if it is seen multiple times.
        """
        if key in self.seen:
            raise RuntimeError("duplicate key: " + str(key))
        self.seen.add(key)

def parse_nvprof_trace(path):
    import sqlite3

    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row

    # Parse strings table
    strings = {}
    for r in conn.execute("SELECT _id_ as id, value FROM StringTable"):
        strings[r["id"]] = torch._C._demangle(r["value"])

    # First, find all functions and create FunctionEvents for them
    marker_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp AS start_time, end.timestamp AS end_time
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
    WHERE
        start.name != 0 AND end.name = 0
    """
    functions = []
    functions_map = {}
    unique = EnforceUnique()
    for row in conn.execute(marker_query):
        unique.see(row["marker_id"])
        evt = FunctionEvent(
            id=row["marker_id"],
            node_id=0,  # missing a node_id when calling FunctionEvent. This is just to ensure
            # that pytorch doesn't crash when creating a FunctionEvent() object
            name=strings[row["name"]],
            start_us=row["start_time"],
            end_us=row["end_time"],
        )  # TODO: find in sqlite database
        functions.append(evt)
        functions_map[evt.id] = evt

    # Now, correlate all kernels with FunctionEvents
    kernel_query = """
    SELECT
        start.id AS marker_id, start.name, start.timestamp, end.timestamp,
        runtime._id_ AS runtime_id, runtime.cbid, runtime.start AS runtime_start, runtime.end AS runtime_end,
        kernel.start AS kernel_start, kernel.end AS kernel_end, kernel.name AS kernel_name
    FROM
        CUPTI_ACTIVITY_KIND_MARKER AS start
        INNER JOIN CUPTI_ACTIVITY_KIND_MARKER AS end
        ON start.id = end.id
        INNER JOIN CUPTI_ACTIVITY_KIND_RUNTIME as runtime
        ON (start.timestamp < runtime.start AND runtime.end < end.timestamp)
        INNER JOIN CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL AS kernel
        ON kernel.correlationId = runtime.correlationId
    """
    unique = EnforceUnique()
    for row in conn.execute(kernel_query):
        unique.see(row["marker_id"], row["runtime_id"])
        # 211 is cudaKernelLaunch for cuda >= 9.2
        assert row["cbid"] == 211
        evt = functions_map[row["marker_id"]]
        evt.append_kernel(
            row["kernel_name"], 0, row["kernel_end"] - row["kernel_start"]
        )

    functions.sort(key=lambda evt: evt.time_range.start)
    return functions

class KinetoStepTracker:
    """Provides an abstraction for incrementing the step count globally.

    Previously, we only had one place to mark that a step() has occurred
    in the program, via the pytorch profiler step(). We will now add step hooks
    in the Optimizer class (https://github.com/pytorch/pytorch/issues/88446).

    - This could mean programs that already call profiler.step() every
      iteration can end up double incrementing the step count.
    - If a model uses multiple optimizers we can also have double or more
      counting of the step.

    We fix this by adding a layer of abstraction before calling step()
    on the kineto library. The idea is to maintain steps per requester in a dict:

    .. code-block::

        {
            "ProfilerStep": 100,  # triggered by profiler step() call
            "Optimizer1Step": 100,  # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
            "Optimizer2Step": 100,
        }

    To figure out the global step count just take the max of the dict values (100).

    If one of the counts increments, the max will go up:

    .. code-block::

        {
            "ProfilerStep": 100,
            "Optimizer1Step": 101,  # Optimizer1 got incremented first, say
            "Optimizer2Step": 100,
        }

    Then the global step count is 101.
    We only call the kineto step() function when the global count increments.

    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
    for now. The result could be incorrect increments of the step count.
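
    A minimal usage sketch (illustrative; assumes the tracker starts at step 0):

    .. code-block::

        KinetoStepTracker.init_step_count("Optimizer1Step")  # starts at the current global step
        KinetoStepTracker.increment_step("ProfilerStep")     # global step -> 1, _kineto_step() fires
        KinetoStepTracker.increment_step("Optimizer1Step")   # max unchanged, global step stays 1
        assert KinetoStepTracker.current_step() == 1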
    """

    _current_step = 0
    _step_dict: Dict[str, int] = defaultdict(int)

    @classmethod
    def init_step_count(cls, requester: str):
        """
        Initialize for a given requester.
        """
        cls._step_dict[requester] = cls._current_step

    @classmethod
    def erase_step_count(cls, requester: str) -> bool:
        """
        Remove a given requester.
        """
        return cls._step_dict.pop(requester, None) is not None

    @classmethod
    def increment_step(cls, requester: str) -> int:
        """Increments the step count for the requester.

        Additionally, if the max over all step counts has incremented, trigger
        the _kineto_step() call. Returns the global step count.
        """
        if requester not in cls._step_dict:
            cls.init_step_count(requester)
        cls._step_dict[requester] += 1

        new_step = max(cls._step_dict.values())
        if new_step > cls._current_step:
            delta = new_step - cls._current_step
            if delta > 1:
                warn(
                    "Profiler step count has increased more than 1 - "
                    f"current_step = {cls._current_step} step dict = {cls._step_dict}"
                )
            for _ in range(0, delta):
                _kineto_step()
            cls._current_step = new_step
        return cls._current_step

    @classmethod
    def current_step(cls) -> int:
        """
        Get the latest step for any requester.
        """
        return cls._current_step