r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.

It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a Python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.

It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
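
For example, the following snippet should trigger a warning, because the side
stream reads tensor ``a`` without synchronizing with the default stream that
produced it (an illustrative sketch; the exact report text may differ)::

    import torch
    from torch.cuda._sanitizer import enable_cuda_sanitizer

    enable_cuda_sanitizer()

    a = torch.rand(4, 2, device="cuda")
    with torch.cuda.stream(torch.cuda.Stream()):
        torch.sum(a)
"""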
import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar

import torch
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


DEFAULT_STREAM_ID = 0

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")

# Type aliases for the identifiers used throughout the trace.
DataPtr = int
StreamId = int
EventId = int
SeqNum = int
logger = logging.getLogger(__name__)


class AccessType(enum.Enum):
    READ = enum.auto()
    WRITE = enum.auto()

    def __str__(self):
        return "reading from" if self is AccessType.READ else "writing to"
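

# Sketch of how the enum renders in reports: str(AccessType.READ) is
# "reading from" and str(AccessType.WRITE) is "writing to"; both are
# interpolated directly into the data race message built below.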
@dataclass
class Access:
    r"""Stores information about a single access to a tensor by a kernel.

    Args:
        type: either AccessType.READ or AccessType.WRITE.
        seq_num: the sequential number of the kernel performing the access.
        stream: the stream id of the stream executing the kernel.
        operator: the schema of the launched kernel, which lists the
            arguments and return type.
        aliases: the arguments in the schema this access corresponds to.
        is_output: whether the tensor was an output of the kernel.
        stack_trace: the stack summary object captured during access.
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary
class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""

    pass


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access
    def __str__(self):
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()
class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"
@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None
class _TensorsAccessed:
    def __init__(self):
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return bool(self.accesses[data_ptr].reads)

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        # A new write invalidates the accumulated reads: only accesses since
        # the last write can still race with future accesses.
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []
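

# A minimal sketch of the bookkeeping above (the pointer value and the
# `read_access`/`write_access` Access instances are hypothetical):
#
#     tensors = _TensorsAccessed()
#     tensors.ensure_tensor_exists(0x100)      # backfills a missing allocation
#     tensors.add_read(0x100, read_access)     # reads accumulate...
#     tensors.set_write(0x100, write_access)   # ...until a write resets them
#     tensors.get_reads(0x100)                 # -> []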
class StreamSynchronizations:
    def __init__(self):
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
        self.create_stream(DEFAULT_STREAM_ID)

    def _ensure_stream_exists(self, stream: StreamId) -> None:
        if stream not in self.current_sync_states:
            logger.info(
                "Found Stream with id: %s, but no matching stream "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                stream,
            )
            self.create_stream(stream)

    def _ensure_event_exists(self, event: EventId) -> None:
        if event not in self.recorded_sync_states:
            logger.info(
                "Found Event with id: %s, but no matching event "
                "creation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.create_event(event)

    def _ensure_event_does_not_exist(self, event: EventId) -> None:
        if event in self.recorded_sync_states:
            logger.info(
                "Found duplicate event creation in the trace for event with "
                "id: %s. Assuming the trace for event deletion wasn't caught "
                "and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                event,
            )
            self.delete_event(event)

    def create_stream(self, stream: StreamId) -> None:
        if stream in self.current_sync_states:
            logger.info(
                "Found duplicate Stream creation in the trace for Stream with "
                "id: %s. PyTorch Streams are only created once, so this "
                "trace entry is ignored.",
                stream,
            )
        else:
            self.host_sync_state[stream] = 0
            self.current_sync_states[stream] = self.host_sync_state.copy()

    def create_event(self, event: EventId) -> None:
        self._ensure_event_does_not_exist(event)
        self.recorded_sync_states[event] = {}

    def delete_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        del self.recorded_sync_states[event]

    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
        self._ensure_stream_exists(stream)
        self.current_sync_states[stream][stream] = seq_num

    def record_state(self, event: EventId, stream: StreamId) -> None:
        self._ensure_event_exists(event)
        self._ensure_stream_exists(stream)
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()

    def _state_wait_for_other(
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
    ) -> None:
        for stream, seq_num in other.items():
            state[stream] = max(state.get(stream, -1), seq_num)

    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
        self._ensure_stream_exists(stream)
        self._ensure_event_exists(event)
        self._state_wait_for_other(
            self.current_sync_states[stream], self.recorded_sync_states[event]
        )

    def all_streams_wait_for_event(self, event: EventId) -> None:
        self._ensure_event_exists(event)
        for stream in self.current_sync_states.keys():
            self.stream_wait_for_event(stream, event)

        self._state_wait_for_other(
            self.host_sync_state, self.recorded_sync_states[event]
        )

    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
        self._ensure_stream_exists(stream)
        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.current_sync_states[stream])

        self._state_wait_for_other(
            self.host_sync_state, self.current_sync_states[stream]
        )

    def sync_all_streams(self) -> None:
        for stream, state in self.current_sync_states.items():
            self.host_sync_state[stream] = state[stream]

        for state in self.current_sync_states.values():
            self._state_wait_for_other(state, self.host_sync_state)

    def is_ordered_after(
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
    ) -> bool:
        self._ensure_stream_exists(current_stream)
        self._ensure_stream_exists(other_stream)
        return seq_num <= self.current_sync_states[current_stream].get(
            other_stream, -1
        )
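

# The per-stream states above behave like vector clocks keyed by stream id: a
# kernel on stream B is ordered after seq_num N on stream A iff B's view of A
# has advanced to at least N. An illustrative sketch (stream/event ids are
# made up):
#
#     syncs = StreamSynchronizations()
#     syncs.create_stream(1)
#     syncs.create_stream(2)
#     syncs.update_seq_num(1, 1)           # kernel #1 launches on stream 1
#     syncs.create_event(7)
#     syncs.record_state(7, 1)             # event 7 snapshots stream 1's state
#     syncs.is_ordered_after(2, 1, 1)      # False: stream 2 never waited
#     syncs.stream_wait_for_event(2, 7)
#     syncs.is_ordered_after(2, 1, 1)      # True: ordering established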
class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it
        # must be reversed.
        stack_trace.reverse()

        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)
def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)
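

# Sketch: for a schema like aten::add.Tensor(Tensor self, Tensor other, *,
# Scalar alpha=1), zip_arguments(schema, (x, y), {"alpha": 2}) pairs `self`->x
# and `other`->y positionally, then `alpha`->2 by name; kwargs absent from the
# schema are skipped by the `if arg in b` guard in zip_by_key.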
class ArgumentHandler:
    def __init__(self):
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )
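

# Whether an input counts as a write comes from the schema's alias
# annotations: e.g. for an in-place op like aten::add_, the `self` argument is
# declared as `Tensor(a!)`, so alias_info.is_write is True and its data
# pointer lands in dataptrs_written rather than dataptrs_read.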
class CUDASanitizerDispatchMode(TorchDispatchMode):
    def __init__(self):
        self.event_handler = EventHandler()
        torch._C._activate_cuda_trace()
        cuda_trace.register_callback_for_cuda_event_creation(
            self.event_handler._handle_event_creation
        )
        cuda_trace.register_callback_for_cuda_event_deletion(
            self.event_handler._handle_event_deletion
        )
        cuda_trace.register_callback_for_cuda_event_record(
            self.event_handler._handle_event_record
        )
        cuda_trace.register_callback_for_cuda_event_wait(
            self.event_handler._handle_event_wait
        )
        cuda_trace.register_callback_for_cuda_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        cuda_trace.register_callback_for_cuda_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        cuda_trace.register_callback_for_cuda_stream_creation(
            self.event_handler._handle_stream_creation
        )
        cuda_trace.register_callback_for_cuda_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        cuda_trace.register_callback_for_cuda_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        cuda_trace.register_callback_for_cuda_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs
class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self):
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        if self.enabled:
            self.dispatch.__exit__(None, None, None)
def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    cuda_sanitizer.enable()


cuda_sanitizer = CUDASanitizer()
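
# The module-level instance above backs enable_cuda_sanitizer(). Per the module
# docstring, exporting the TORCH_CUDA_SANITIZER environment variable (e.g.
# running `TORCH_CUDA_SANITIZER=1 python train.py`) enables the sanitizer at
# startup without any code changes.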