pytorch

Форк
0
/
_cuda_trace.py 
99 строк · 3.1 Кб
1
import logging
2
from typing import Callable, Generic, List
3

4
from typing_extensions import ParamSpec  # Python 3.10+
5

6
logger = logging.getLogger(__name__)
7
P = ParamSpec("P")
8

9

10
class CallbackRegistry(Generic[P]):
11
    def __init__(self, name: str):
12
        self.name = name
13
        self.callback_list: List[Callable[P, None]] = []
14

15
    def add_callback(self, cb: Callable[P, None]) -> None:
16
        self.callback_list.append(cb)
17

18
    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
19
        for cb in self.callback_list:
20
            try:
21
                cb(*args, **kwargs)
22
            except Exception as e:
23
                logger.exception(
24
                    "Exception in callback for %s registered with CUDA trace", self.name
25
                )
26

27

28
# One module-level registry per traced CUDA operation. Each annotation names
# the argument list callbacks in that registry accept ("[]" = no arguments).
# The register_callback_for_* functions in this module add callbacks to them.

# Events.
CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event creation"
)
CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event deletion"
)
CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    name="CUDA event record"
)
CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    name="CUDA event wait"
)

# Memory.
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA memory allocation"
)
CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA memory deallocation"
)

# Streams.
CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA stream creation"
)

# Synchronization.
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
    name="CUDA device synchronization"
)
CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA stream synchronization"
)
CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event synchronization"
)
58

59

60
def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA event creation registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAEventCreationCallbacks
    registry.add_callback(cb)
62

63

64
def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA event deletion registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAEventDeletionCallbacks
    registry.add_callback(cb)
66

67

68
def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the CUDA event record registry.

    ``cb`` is called with two ``int`` arguments whenever that registry fires.
    """
    registry = CUDAEventRecordCallbacks
    registry.add_callback(cb)
70

71

72
def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the CUDA event wait registry.

    ``cb`` is called with two ``int`` arguments whenever that registry fires.
    """
    registry = CUDAEventWaitCallbacks
    registry.add_callback(cb)
74

75

76
def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA memory allocation registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAMemoryAllocationCallbacks
    registry.add_callback(cb)
78

79

80
def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA memory deallocation registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAMemoryDeallocationCallbacks
    registry.add_callback(cb)
82

83

84
def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA stream creation registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAStreamCreationCallbacks
    registry.add_callback(cb)
86

87

88
def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
    """Add ``cb`` to the CUDA device synchronization registry.

    ``cb`` takes no arguments; it is called whenever that registry fires.
    """
    registry = CUDADeviceSynchronizationCallbacks
    registry.add_callback(cb)
90

91

92
def register_callback_for_cuda_stream_synchronization(
    cb: Callable[[int], None]
) -> None:
    """Add ``cb`` to the CUDA stream synchronization registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAStreamSynchronizationCallbacks
    registry.add_callback(cb)
96

97

98
def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the CUDA event synchronization registry.

    ``cb`` is called with one ``int`` argument whenever that registry fires.
    """
    registry = CUDAEventSynchronizationCallbacks
    registry.add_callback(cb)
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.