benchmark.py
import argparse
import csv
import dataclasses
import os

from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8

import torch
import torch.nn as nn
from torch._inductor.runtime.benchmarking import benchmarker
from torch.utils.flop_counter import FlopCounterMode


# Warm-up calls issued before timing so that torch.compile compilation and
# autotuning costs are excluded from the measurement.
WARMUP_ITER = 5

# Peak dense BF16 tensor-core throughput of an NVIDIA A100 40GB, in TFLOPS.
A100_40G_BF16_TFLOPS = 312


@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    # Note: callers pass target/actual as pre-formatted strings; dataclass
    # annotations are not enforced at runtime.
    target: float
    actual: float
    dtype: str
    device: str
    # True for end-to-end model runs, False for the micro-benchmarks below.
    is_model: bool = False


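# Each Experiment is flattened with dataclasses.astuple() into one CSV row,
# e.g. (the "actual" value here is illustrative):
#     ("layer_norm", "memory_bandwidth(GB/s)", "950", "912.34", "bfloat16",
#      "cuda", False)
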
# A small Linear + LayerNorm stack for the MLP micro-benchmark below.
# Note: despite the name run_mlp_layer_norm_gelu, no GELU is applied.
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


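# Rough FLOP arithmetic for the benchmark below, at D = 16384: the two Linear
# layers dominate, and FlopCounterMode counts 2 * M * N * K FLOPs per matmul,
# so one forward pass costs about
#     2 * (2 * D * intermediate_size) = 2 * 2 * 16384 * 14336 ≈ 9.4e8 FLOPs,
# with the LayerNorms contributing comparatively little.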
def run_mlp_layer_norm_gelu(device: str = "cuda"):
    # Target FLOPS utilization per dtype; values are strings because they are
    # written to the CSV verbatim.
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            # Count the FLOPs of one eager forward pass.
            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            # benchmark_gpu returns milliseconds per iteration; convert to us.
            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS

        # Average the metric over all input shapes.
        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
            )
        )
    return results


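# Bandwidth model used by the benchmarks below: a kernel that moves B bytes
# in t microseconds sustains
#     B / (t * 1e-6) / 1e9  GB/s,
# which appears in the code as (1e6 / us_per_iter) * B / 1e9.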
def run_layer_norm(device: str = "cuda"):
    # Target memory bandwidth in GB/s per dtype.
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to(device)

            x = torch.randn(BS, D, device=device, dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
            # 2 * BS * D elements moved: the input is read once and the
            # output written once.
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


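# coordinate_descent_tuning makes Inductor tune generated-kernel block sizes
# by coordinate descent rather than only trying a fixed set of configurations;
# it tends to help latency-sensitive, bandwidth-bound kernels like the GEMVs
# below.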
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    E = 8  # number of expert weight matrices to gather from
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1060",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                # Gather two of the E expert matrices, upcast, and multiply.
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            us_per_iter = (
                benchmarker.benchmark_gpu(lambda: compiled_fn(W, score_idxs, x)) * 1000
            )
            # Traffic is dominated by the two gathered D x D matrices.
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


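# Like gather_gemv above, gemv stores W in the benchmarked dtype (e.g. int8)
# and upcasts it to x.dtype (bfloat16) on the fly, so the traffic counted
# below is that of the stored matrix, not of the upcast copy.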
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.int8: "870",
        torch.bfloat16: "990",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_fn(W, x)) * 1000
            # One matvec streams the whole D x D matrix once.
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


def output_csv(output_file, headers, row):
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
            if headers and len(headers) > len(lines[0]):
                # If prior results failed, the header might not be filled in yet.
                lines[0] = headers
            else:
                headers = lines[0]
    else:
        lines = [headers]

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            # Pad short rows with "0" so every row matches the header's width.
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

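# With the default file name, the CSV starts with the Experiment field names
# and then holds one row per result, e.g. (values illustrative):
#     name,metric,target,actual,dtype,device,is_model
#     gemv,memory_bandwidth(GB/s),870,845.12,int8,cuda,False
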
all_experiments = {
    # End-to-end GPT models: Llama 2, Mixtral, etc. (imported from generate).
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # Micro-benchmarks defined above.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
}


def main(output_file=DEFAULT_OUTPUT_FILE):
    results = []

    for func in all_experiments:
        lst = func()
        for x in lst:
            results.append(dataclasses.astuple(x))

    # The CSV header comes from the Experiment dataclass field names.
    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    args = parser.parse_args()

    main(output_file=args.output)

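# Example invocation:
#     python benchmark.py --output gpt_fast_benchmark.csv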