# mypy: ignore-errors

import contextlib
import time
import os
import json

import torch
from torch.profiler import profile, ProfilerActivity


def synchronize():
    pass


def dump_chrome_trace(f, input, trace_filename, optimize_ctx, activities, num_runs=1,
                      devices=None, kwargs_for_f=None, kwargs_for_profiler=None):
    """
    Output the chrome trace of running f(input, **kwargs_for_f) with [optimize_ctx]
    [num_runs] times to [trace_filename].

    [activities] are the activities that the profiler will record, e.g. ProfilerActivity.CUDA.

    Writes the chrome trace to [trace_filename] and returns the total runtime, in seconds,
    of the [num_runs] timed runs measured without the profiler attached.
    """

    if devices is None:
        devices = ["cuda"]

    # Rebind the module-level no-op synchronize to torch.cuda.synchronize when the
    # workload runs on GPU, so each timed iteration waits for queued kernels.
    global synchronize
    if devices != ["cpu"] and torch.cuda.is_available():
        synchronize = torch.cuda.synchronize

    if kwargs_for_f is None:
        kwargs_for_f = {}
    if kwargs_for_profiler is None:
        kwargs_for_profiler = {}

    # Phase 1: warm up, then time [num_runs] runs without the profiler attached.
    with optimize_ctx:
        torch.manual_seed(1337)
        for _ in range(5):  # warmup runs
            f(input, **kwargs_for_f)
            synchronize()
        torch.manual_seed(1337)
        t0 = time.perf_counter()
        for _ in range(num_runs):
            f(input, **kwargs_for_f)
            synchronize()
        t1 = time.perf_counter()
    timing = t1 - t0

    # Phase 2: run the same workload under the profiler and export the chrome trace.
    with profile(activities=activities, **kwargs_for_profiler) as prof:
        with optimize_ctx:
            synchronize()
            torch.manual_seed(1337)
            for _ in range(num_runs):
                f(input, **kwargs_for_f)
                synchronize()
    prof.export_chrome_trace(trace_filename)

    return timing
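

# A minimal usage sketch for dump_chrome_trace (illustrative only; the helper name,
# toy workload and file name below are made up for the example and nothing in this
# module calls it). Assumes a CUDA device is available.
def _example_dump_chrome_trace(trace_filename="example_chrome_trace.json"):
    def toy(x):
        return (x @ x).relu()

    x = torch.rand(1024, 1024, device="cuda")
    # nullcontext means "no optimization context"; pass any context manager that
    # enables the optimization under test instead.
    return dump_chrome_trace(
        toy, x, trace_filename, contextlib.nullcontext(),
        [ProfilerActivity.CUDA], num_runs=3,
    )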


def get_chrome_trace_events(filename):
    # Use a context manager so the trace file is closed after loading.
    with open(filename) as f:
        data = json.load(f)
    events = data["traceEvents"]
    return events
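

# Notes on the chrome trace events consumed below (Chrome Trace Event Format, as
# emitted by the PyTorch profiler): each event is a dict with fields such as "name",
# "pid", "ph" (phase; "X" marks a complete event that has a duration), "ts" (start
# timestamp, microseconds) and "dur" (duration, microseconds). Metadata events named
# "process_labels" carry the labels used to tell GPU processes apart from CPU ones.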


def is_gpu_compute_event(event):
    global gpu_pids
    return "pid" in event and event["pid"] in gpu_pids and "ph" in event and event["ph"] == "X"


def get_sorted_gpu_events(events):
    sorted_gpu_events = []
    for event in events:
        if not is_gpu_compute_event(event):
            continue
        sorted_gpu_events.append(event)
    return sorted(sorted_gpu_events, key=lambda x: x["ts"])


def get_duration(sorted_gpu_events):
    if len(sorted_gpu_events) == 0:
        return 0
    event = sorted_gpu_events[0]
    current_end_time = event["ts"] + event["dur"]
    total_duration = event["dur"]
    for event in sorted_gpu_events[1:]:
        start_time = max(event["ts"], current_end_time)
        end_time = event["ts"] + event["dur"]
        total_duration = total_duration + max(end_time - start_time, 0)
        current_end_time = max(current_end_time, end_time)
    return total_duration
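
# get_duration computes the union of the (possibly overlapping) kernel intervals, so
# concurrent kernels are not double counted. For example (values in microseconds):
# kernels spanning [0, 10) and [5, 15) contribute 15, not 20; a further kernel at
# [20, 25) brings the total busy time to 20.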


def get_sorted_gpu_mm_conv_events(events):
    def is_mm_conv_event(event):
        return "name" in event and ("gemm" in event["name"] or "conv" in event["name"]
                                    or "cutlass" in event["name"] or "wgrad" in event["name"])
    gpu_events = get_sorted_gpu_events(events)
    sorted_events = []
    for event in gpu_events:
        if not is_mm_conv_event(event):
            continue
        sorted_events.append(event)
    return sorted_events


gpu_pids = []


def compute_utilization(filename: str, total_length: float):
    """
    Process the chrome trace output by the PyTorch profiler to compute GPU utilization
    and the percent of time spent on matmul and convolution.

    Args:
        filename(str): name of the chrome trace file produced by the PyTorch profiler

        total_length(float): total length of the process without the profiler, in seconds

    Return:
        tuple: (GPU utilization, percent of time spent on matmul and convolution)
    """
    events = get_chrome_trace_events(filename)

    # get pids of GPU events
    global gpu_pids
    gpu_pids = []
    for event in events:
        if "name" not in event:
            continue
        if event["name"] == 'process_labels' and "GPU" in event["args"]["labels"]:
            gpu_pids.append(event["pid"])

    # convert seconds to microseconds to match the trace timestamps
    total_length = total_length * 1e6
    sorted_gpu_events = get_sorted_gpu_events(events)
    utilization = get_duration(sorted_gpu_events) / total_length

    sorted_gpu_mm_conv_events = get_sorted_gpu_mm_conv_events(events)
    mm_conv_utilization = get_duration(sorted_gpu_mm_conv_events) / total_length

    return utilization, mm_conv_utilization
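
# Worked example of the unit conversion in compute_utilization: with 40,000 us of
# merged GPU busy time and total_length = 0.1 s (i.e. 100,000 us), utilization is
# 40,000 / 100,000 = 0.4, meaning the GPU was busy for 40% of the wall-clock time.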


def benchmark_utilization(f, input, trace_folder, optimize_ctx=None, trace_file_name="tmp_chrome_trace", num_runs=1):
    """
    Benchmark the GPU utilization and the percent of time spent on matmul and convolution
    operations when running f(input) under [optimize_ctx], [num_runs] times.
    It produces a chrome trace file at trace_folder/trace_file_name.json.

    Example:

    ```
    def f(a):
        return a.sum()
    a = torch.rand(2**20, device="cuda")
    utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name="tmp_chrome_trace")
    ```

    Args:
        f: function to benchmark

        input: input to :attr:`f`

        trace_folder: name of the folder to store the chrome trace

        optimize_ctx: the context in which f will run

        trace_file_name: name of the dumped chrome trace file, defaults to "tmp_chrome_trace"

        num_runs: number of times to run f, excluding the warm-up runs, defaults to 1.

    Return:
        tuple: (GPU utilization, percent of time spent on matmul and convolution)

    """
    if not os.path.exists(trace_folder):
        os.makedirs(trace_folder)
        print("created folder " + trace_folder)

    if optimize_ctx is None:
        optimize_ctx = contextlib.nullcontext()

    chrome_trace_file_name = os.path.join(trace_folder, trace_file_name + ".json")
    total_length = dump_chrome_trace(f, input, chrome_trace_file_name, optimize_ctx,
                                     [ProfilerActivity.CUDA], num_runs=num_runs, devices=["cuda"])
    utilization, mm_conv_utilization = compute_utilization(chrome_trace_file_name, total_length)

    return utilization, mm_conv_utilization
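

# A small demo entry point (an illustrative sketch based on the docstring example
# above). Requires a CUDA device; the "tmp" folder name mirrors that example.
if __name__ == "__main__":
    def f(a):
        return a.sum()

    a = torch.rand(2**20, device="cuda")
    utilization, mm_conv_utilization = benchmark_utilization(f, a, "tmp", trace_file_name="tmp_chrome_trace")
    print(f"GPU utilization: {utilization:.2%}, matmul/conv share: {mm_conv_utilization:.2%}")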