r"""
This module introduces CUDA Sanitizer, a tool for detecting synchronization errors between kernels run on different streams.

It stores information on accesses to tensors to determine if they are synchronized
or not. When enabled in a Python program and a possible data race is detected, a
detailed warning will be printed and the program will exit.

It can be enabled either by importing this module and calling
:func:`enable_cuda_sanitizer()` or by exporting the ``TORCH_CUDA_SANITIZER``
environment variable.
"""

import enum
import functools
import inspect
import io
import logging
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar

import torch
import torch.utils._cuda_trace as cuda_trace
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode


DEFAULT_STREAM_ID = 0

TK = TypeVar("TK")
TVa = TypeVar("TVa")
TVb = TypeVar("TVb")

DataPtr = int
StreamId = int
EventId = int
SeqNum = int

logger = logging.getLogger(__name__)


class AccessType(enum.Enum):
45
    READ = enum.auto()
46
    WRITE = enum.auto()
47

48
    def __str__(self):
49
        return "reading from" if self is AccessType.READ else "writing to"
50

51

52
@dataclass
53
class Access:
54
    r"""Stores information about a single access to a tensor by a kernel.
55

56
    Args:
57
        type: either AccessType.READ or AccessType.Write.
58
        seq_num: the sequential number of the kernel performing the access.
59
        stream: the stream id of the stream executing the kernel.
60
        operator: the schema of the launched kernel, which lists the
61
            arguments and return type.
62
        aliases: the arguments in the schema this access corresponds to.
63
        is_output: Whether the tensor was an output of the kernel.
64
        stack_trace: the stack summary object captured during access.
65
    """

    type: AccessType
    seq_num: SeqNum
    stream: StreamId
    operator: str
    aliases: List[str]
    is_output: bool
    stack_trace: traceback.StackSummary


class SynchronizationError(Exception):
    """Base class for errors detected by CUDA Sanitizer."""

    pass


class UnsynchronizedAccessError(SynchronizationError):
    """Stores information about two unsynchronized accesses to one data pointer."""

    def __init__(
        self,
        data_ptr: DataPtr,
        allocation_stack_trace: Optional[traceback.StackSummary],
        current_access: Access,
        previous_access: Access,
    ):
        self.data_ptr = data_ptr
        self.allocation_stack_trace = allocation_stack_trace
        self.current_access = current_access
        self.previous_access = previous_access

    def __str__(self):
        def format_access(access: Access):
            message.write(f"{access.operator}\n{access.type}")
            if access.aliases:
                message.write(" argument(s) " + ", ".join(access.aliases))
                if access.is_output:
                    message.write(", and to")
            if access.is_output:
                message.write(" the output")
            message.write(
                f"\nWith stack trace:\n{''.join(access.stack_trace.format())}\n"
            )

        with io.StringIO() as message:
            message.write(
                textwrap.dedent(
                    f"""\
                    ============================
                    CSAN detected a possible data race on tensor with data pointer {self.data_ptr}
                    Access by stream {self.current_access.stream} during kernel:
                    """
                )
            )
            format_access(self.current_access)

            message.write(
                f"Previous access by stream {self.previous_access.stream} during kernel:\n"
            )
            format_access(self.previous_access)

            if self.allocation_stack_trace:
                message.write(
                    "Tensor was allocated with stack trace:\n"
                    f"{''.join(self.allocation_stack_trace.format())}"
                )
            else:
                message.write("Trace for tensor allocation not found.")
            return message.getvalue()


class CUDASanitizerErrors(Exception):
    """Wrapper class for errors reported by CUDA Sanitizer."""

    def __init__(self, errors: List[SynchronizationError]):
        self.errors = errors

    def __str__(self):
        return f"detected {len(self.errors)} errors"


@dataclass
class TensorInfo:
    r"""Stores information about a single tensor and recent accesses to it.

    Args:
        allocation_stack_trace: the stack summary object captured during tensor
            allocation. Can be ``None`` if the allocation wasn't caught by CSAN.
        reads: list of read accesses to the tensor that were performed since
            the last write.
        write: the last write access to the tensor.
    """

    allocation_stack_trace: Optional[traceback.StackSummary]
    reads: List[Access] = field(default_factory=list)
    write: Optional[Access] = None


class _TensorsAccessed:
    def __init__(self):
        self.accesses: Dict[DataPtr, TensorInfo] = {}

    def ensure_tensor_exists(self, data_ptr: DataPtr) -> None:
        if data_ptr not in self.accesses:
            logger.info(
                "Found tensor with pointer: %s, but no matching tensor "
                "allocation in the trace. Backfilling the trace now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.create_tensor(data_ptr, None)

    def ensure_tensor_does_not_exist(self, data_ptr: DataPtr) -> None:
        if data_ptr in self.accesses:
            logger.info(
                "Found duplicate tensor allocation in the trace for tensor with "
                "pointer: %s. Assuming the trace for tensor deallocation "
                "wasn't caught and backfilling it now. "
                "Perhaps the sanitizer was enabled after some torch operations?",
                data_ptr,
            )
            self.delete_tensor(data_ptr)

    def create_tensor(
        self, data_ptr: DataPtr, stack_trace: Optional[traceback.StackSummary]
    ) -> None:
        self.accesses[data_ptr] = TensorInfo(stack_trace)

    def delete_tensor(self, data_ptr: DataPtr) -> None:
        del self.accesses[data_ptr]

    def were_there_reads_since_last_write(self, data_ptr: DataPtr) -> bool:
        return True if self.accesses[data_ptr].reads else False

    def get_allocation_stack_trace(
        self, data_ptr: DataPtr
    ) -> Optional[traceback.StackSummary]:
        return self.accesses[data_ptr].allocation_stack_trace

    def get_write(self, data_ptr: DataPtr) -> Optional[Access]:
        return self.accesses[data_ptr].write

    def get_reads(self, data_ptr: DataPtr) -> List[Access]:
        return self.accesses[data_ptr].reads

    def add_read(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].reads.append(access)

    def set_write(self, data_ptr: DataPtr, access: Access) -> None:
        self.accesses[data_ptr].write = access
        self.accesses[data_ptr].reads = []


class StreamSynchronizations:
220
    def __init__(self):
221
        self.current_sync_states: Dict[StreamId, Dict[StreamId, SeqNum]] = {}
222
        self.recorded_sync_states: Dict[EventId, Dict[StreamId, SeqNum]] = {}
223
        self.host_sync_state: Dict[StreamId, SeqNum] = {}
224
        self.create_stream(DEFAULT_STREAM_ID)
225

226
    def _ensure_stream_exists(self, stream: StreamId) -> None:
227
        if stream not in self.current_sync_states:
228
            logger.info(
229
                "Found Stream with id: %s, but no matching stream "
230
                "creation in the trace. Backfilling the trace now. "
231
                "Perhaps the sanitizer was enabled after some torch operations?",
232
                stream,
233
            )
234
            self.create_stream(stream)
235

236
    def _ensure_event_exists(self, event: EventId) -> None:
237
        if event not in self.recorded_sync_states:
238
            logger.info(
239
                "Found Event with id: %s, but no matching event "
240
                "creation in the trace. Backfilling the trace now. "
241
                "Perhaps the sanitizer was enabled after some torch operations?",
242
                event,
243
            )
244
            self.create_event(event)
245

246
    def _ensure_event_does_not_exist(self, event: EventId) -> None:
247
        if event in self.recorded_sync_states:
248
            logger.info(
249
                "Found duplicate event creation in the trace for event with "
250
                "id: %s. Assuming the trace for event deletion wasn't caught "
251
                "and backfilling it now. "
252
                "Perhaps the sanitizer was enabled after some torch operations?",
253
                event,
254
            )
255
            self.delete_event(event)
256

257
    def create_stream(self, stream: StreamId) -> None:
258
        if stream in self.current_sync_states:
259
            logger.info(
260
                "Found duplicate Stream creation in the trace for Stream with "
261
                "id: %s. PyTorch Streams are only created once, so this "
262
                "trace entry is ignored.",
263
                stream,
264
            )
265
        else:
266
            self.host_sync_state[stream] = 0
267
            self.current_sync_states[stream] = self.host_sync_state.copy()
268

269
    def create_event(self, event: EventId) -> None:
270
        self._ensure_event_does_not_exist(event)
271
        self.recorded_sync_states[event] = {}
272

273
    def delete_event(self, event: EventId) -> None:
274
        self._ensure_event_exists(event)
275
        del self.recorded_sync_states[event]
276

277
    def update_seq_num(self, stream: StreamId, seq_num: SeqNum) -> None:
278
        self._ensure_stream_exists(stream)
279
        self.current_sync_states[stream][stream] = seq_num
280

281
    def record_state(self, event: EventId, stream: StreamId) -> None:
282
        self._ensure_event_exists(event)
283
        self._ensure_stream_exists(stream)
284
        self.recorded_sync_states[event] = self.current_sync_states[stream].copy()
285

286
    def _state_wait_for_other(
287
        self, state: Dict[StreamId, SeqNum], other: Dict[StreamId, SeqNum]
288
    ) -> None:
289
        for stream, seq_num in other.items():
290
            state[stream] = max(state.get(stream, -1), seq_num)
291

292
    def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None:
293
        self._ensure_stream_exists(stream)
294
        self._ensure_event_exists(event)
295
        self._state_wait_for_other(
296
            self.current_sync_states[stream], self.recorded_sync_states[event]
297
        )
298

299
    def all_streams_wait_for_event(self, event: EventId) -> None:
300
        self._ensure_event_exists(event)
301
        for stream in self.current_sync_states.keys():
302
            self.stream_wait_for_event(stream, event)
303

304
        self._state_wait_for_other(
305
            self.host_sync_state, self.recorded_sync_states[event]
306
        )
307

308
    def all_streams_wait_for_stream(self, stream: StreamId) -> None:
309
        self._ensure_stream_exists(stream)
310
        for state in self.current_sync_states.values():
311
            self._state_wait_for_other(state, self.current_sync_states[stream])
312

313
        self._state_wait_for_other(
314
            self.host_sync_state, self.current_sync_states[stream]
315
        )
316

317
    def sync_all_streams(self) -> None:
318
        for stream, state in self.current_sync_states.items():
319
            self.host_sync_state[stream] = state[stream]
320

321
        for state in self.current_sync_states.values():
322
            self._state_wait_for_other(state, self.host_sync_state)
323

324
    def is_ordered_after(
325
        self, current_stream: StreamId, seq_num: SeqNum, other_stream: StreamId
326
    ) -> bool:
327
        self._ensure_stream_exists(current_stream)
328
        self._ensure_stream_exists(other_stream)
329
        return seq_num <= self.current_sync_states[current_stream].get(other_stream, -1)
330

331

332
class EventHandler:
    """Analyzes CSAN trace for synchronization errors.

    Stores information on each stream's synchronizations with other streams as well
    as tensor accesses to determine whether a given kernel launch might cause a
    data race.
    """

    def __init__(self):
        self.tensors_accessed = _TensorsAccessed()
        self.syncs = StreamSynchronizations()
        self.seq_num: SeqNum = 0

    def _handle_kernel_launch(
        self,
        stream: StreamId,
        read_only: Set[DataPtr],
        read_write: Set[DataPtr],
        outputs: Set[DataPtr],
        operator: str,
        tensor_aliases: Dict[int, List[str]],
    ) -> List[SynchronizationError]:
        def check_conflict(
            data_ptr: DataPtr, current_access: Access, previous_access: Optional[Access]
        ) -> None:
            if previous_access is None:
                return
            if not self.syncs.is_ordered_after(
                current_access.stream, previous_access.seq_num, previous_access.stream
            ):
                error_list.append(
                    UnsynchronizedAccessError(
                        data_ptr,
                        self.tensors_accessed.get_allocation_stack_trace(data_ptr),
                        current_access,
                        previous_access,
                    )
                )

        error_list: List[SynchronizationError] = []
        self.seq_num += 1
        self.syncs.update_seq_num(stream, self.seq_num)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()

        for data_ptr in read_only:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.READ,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            check_conflict(
                data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
            )
            self.tensors_accessed.add_read(data_ptr, current_access)

        for data_ptr in read_write:
            self.tensors_accessed.ensure_tensor_exists(data_ptr)
            current_access = Access(
                AccessType.WRITE,
                self.seq_num,
                stream,
                operator,
                tensor_aliases[data_ptr],
                data_ptr in outputs,
                stack_trace,
            )
            if self.tensors_accessed.were_there_reads_since_last_write(data_ptr):
                for previous_access in self.tensors_accessed.get_reads(data_ptr):
                    check_conflict(data_ptr, current_access, previous_access)
            else:
                check_conflict(
                    data_ptr, current_access, self.tensors_accessed.get_write(data_ptr)
                )
            self.tensors_accessed.set_write(data_ptr, current_access)

        return error_list

    def _handle_event_creation(self, event: EventId) -> None:
        self.syncs.create_event(event)

    def _handle_event_deletion(self, event: EventId) -> None:
        self.syncs.delete_event(event)

    def _handle_event_record(self, event: EventId, stream: StreamId) -> None:
        self.syncs.record_state(event, stream)

    def _handle_event_wait(self, event: EventId, stream: StreamId) -> None:
        self.syncs.stream_wait_for_event(stream, event)

    def _handle_memory_allocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_does_not_exist(data_ptr)
        stack_trace = traceback.StackSummary.extract(
            traceback.walk_stack(inspect.currentframe()), lookup_lines=False
        )
        # The stack trace generated in this way is in the inverse order, so it must be
        # reversed.
        stack_trace.reverse()
        self.tensors_accessed.create_tensor(
            data_ptr,
            stack_trace,
        )

    def _handle_memory_deallocation(self, data_ptr: DataPtr) -> None:
        self.tensors_accessed.ensure_tensor_exists(data_ptr)
        self.tensors_accessed.delete_tensor(data_ptr)

    def _handle_stream_creation(self, stream: StreamId) -> None:
        self.syncs.create_stream(stream)

    def _handle_device_synchronization(self) -> None:
        self.syncs.sync_all_streams()

    def _handle_stream_synchronization(self, stream: StreamId) -> None:
        self.syncs.all_streams_wait_for_stream(stream)

    def _handle_event_synchronization(self, event: EventId) -> None:
        self.syncs.all_streams_wait_for_event(event)


def zip_by_key(a: Dict[TK, TVa], b: Dict[TK, TVb]) -> Iterator[Tuple[TK, TVa, TVb]]:
    for arg, value in a.items():
        if arg in b:
            yield arg, value, b[arg]


def zip_arguments(
    schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterator[Tuple[torch.Argument, Any]]:
    schema_args = schema.arguments[: len(args)]
    schema_kwargs = {arg.name: arg for arg in schema.arguments[len(args) :]}

    yield from zip(schema_args, args)

    for _, argument, value in zip_by_key(schema_kwargs, kwargs):
        yield (argument, value)


class ArgumentHandler:
    def __init__(self):
        self.dataptrs_read: Set[DataPtr] = set()
        self.dataptrs_written: Set[DataPtr] = set()
        self.tensor_aliases: Dict[DataPtr, List[str]] = dict()
        self.outputs: Set[DataPtr] = set()

    def _handle_argument(
        self,
        value: Any,
        is_write: bool,
        name: Optional[str] = None,
        is_output: bool = False,
    ) -> None:
        if isinstance(value, torch.Tensor) and value.is_cuda:
            data_ptr = value.data_ptr()
            if is_write:
                self.dataptrs_written.add(data_ptr)
            else:
                self.dataptrs_read.add(data_ptr)

            self.tensor_aliases.setdefault(data_ptr, [])
            if name is not None:
                self.tensor_aliases[data_ptr].append(name)
            if is_output:
                self.outputs.add(data_ptr)

    def parse_inputs(
        self,
        schema: torch.FunctionSchema,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> None:
        for argument, value in zip_arguments(schema, args, kwargs):
            is_write = argument.alias_info is not None and argument.alias_info.is_write
            pytree.tree_map_(
                functools.partial(
                    self._handle_argument, is_write=is_write, name=argument.name
                ),
                value,
            )

    def parse_outputs(self, outputs: Any) -> None:
        pytree.tree_map_(
            functools.partial(self._handle_argument, is_write=True, is_output=True),
            outputs,
        )


class CUDASanitizerDispatchMode(TorchDispatchMode):
    def __init__(self):
        self.event_handler = EventHandler()
        torch._C._activate_cuda_trace()
        cuda_trace.register_callback_for_cuda_event_creation(
            self.event_handler._handle_event_creation
        )
        cuda_trace.register_callback_for_cuda_event_deletion(
            self.event_handler._handle_event_deletion
        )
        cuda_trace.register_callback_for_cuda_event_record(
            self.event_handler._handle_event_record
        )
        cuda_trace.register_callback_for_cuda_event_wait(
            self.event_handler._handle_event_wait
        )
        cuda_trace.register_callback_for_cuda_memory_allocation(
            self.event_handler._handle_memory_allocation
        )
        cuda_trace.register_callback_for_cuda_memory_deallocation(
            self.event_handler._handle_memory_deallocation
        )
        cuda_trace.register_callback_for_cuda_stream_creation(
            self.event_handler._handle_stream_creation
        )
        cuda_trace.register_callback_for_cuda_device_synchronization(
            self.event_handler._handle_device_synchronization
        )
        cuda_trace.register_callback_for_cuda_stream_synchronization(
            self.event_handler._handle_stream_synchronization
        )
        cuda_trace.register_callback_for_cuda_event_synchronization(
            self.event_handler._handle_event_synchronization
        )

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        argument_handler = ArgumentHandler()
        argument_handler.parse_inputs(func._schema, args, kwargs)

        outputs = func(*args, **kwargs)

        argument_handler.parse_outputs(outputs)
        errors = self.event_handler._handle_kernel_launch(
            torch.cuda.current_stream().cuda_stream,
            argument_handler.dataptrs_read - argument_handler.dataptrs_written,
            argument_handler.dataptrs_written,
            argument_handler.outputs,
            func._schema,
            argument_handler.tensor_aliases,
        )
        if errors:
            for error in errors:
                print(error, file=sys.stderr)
            raise CUDASanitizerErrors(errors)

        return outputs


class CUDASanitizer:
    """Manages the lifetime of a CUDASanitizer dispatch mode object.

    The CUDASanitizer class wraps the entering/exiting functions of the dispatch mode
    context manager in the enable function/destructor, respectively. This is to
    explicitly set the lifetime of the dispatch mode object to that of the application.
    This approach was deemed more elegant than using the atexit module.
    """

    def __init__(self):
        self.dispatch = CUDASanitizerDispatchMode()
        self.enabled = False

    def enable(self):
        self.dispatch.__enter__()
        self.enabled = True

    def __del__(self):
        if self.enabled:
            self.dispatch.__exit__(None, None, None)


def enable_cuda_sanitizer():
    """Enable CUDA Sanitizer.

    The sanitizer will begin to analyze low-level CUDA calls invoked by torch functions
    for synchronization errors. All data races found will be printed to the standard
    error output along with stack traces of suspected causes. For best results, the
    sanitizer should be enabled at the very beginning of the program.
    """
    cuda_sanitizer.enable()


cuda_sanitizer = CUDASanitizer()
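
# End-to-end sketch of a run that trips the sanitizer (illustrative only; the tensor
# shape and stream handling are made up, and the report text comes from
# UnsynchronizedAccessError above):
#
#     import torch
#     from torch.cuda._sanitizer import enable_cuda_sanitizer
#
#     enable_cuda_sanitizer()
#
#     tensor = torch.ones(10000, device="cuda")  # write on the default stream
#     side_stream = torch.cuda.Stream()
#     with torch.cuda.stream(side_stream):
#         # No event or stream synchronization with the default stream, so this
#         # in-place write is not ordered after the allocation above: CSAN prints a
#         # "CSAN detected a possible data race" report to stderr and raises
#         # CUDASanitizerErrors.
#         tensor.mul_(2)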