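"""Microbenchmarks for gpt-fast style workloads.

Runs the end-to-end model benchmarks imported from generate.py (Llama2-7B
in bf16 and int8, Mixtral-8x7B in int8) plus a set of micro-benchmarks
(MLP + layer norm, layer norm, gather + gemv, gemv), and appends one CSV
row per Experiment to the output file.

Usage (assuming this file is saved as benchmark.py):

    python benchmark.py --output gpt_fast_benchmark.csv

The hard-coded targets (peak TFLOPS, expected memory bandwidths) assume an
A100 40GB; numbers on other GPUs will differ.
"""
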
import argparse
import csv
import dataclasses
import os

from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8

import torch
import torch.nn as nn
from torch._inductor.runtime.benchmarking import benchmarker
from torch.utils.flop_counter import FlopCounterMode


WARMUP_ITER = 5

A100_40G_BF16_TFLOPS = 312


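# One Experiment becomes one row of the output CSV. `target` holds the
# expected value for the metric and `actual` the measured one; `is_model`
# marks end-to-end model runs (the micro-benchmarks below leave it False).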
@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    is_model: bool = False


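# A small Linear/LayerNorm stack used by run_mlp_layer_norm_gelu below.
# (Note: despite that benchmark's name, the stack contains no GELU.)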
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


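# FLOPS micro-benchmark: FlopCounterMode counts the FLOPs of one eager
# forward pass, the torch.compile'd module is timed on GPU after warmup,
# and a utilization score relative to A100 40GB peak bf16 throughput is
# averaged over the input shapes.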
def run_mlp_layer_norm_gelu(device: str = "cuda"):
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS

        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
            )
        )
    return results


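# Memory-bandwidth micro-benchmark for LayerNorm: each iteration reads and
# writes a (BS, D) tensor, so achieved bandwidth is estimated as
# 2 * BS * D * itemsize bytes over the measured time per iteration.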
def run_layer_norm(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to(device)

            x = torch.randn(BS, D, device=device, dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_mod(x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


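# Gather-then-gemv micro-benchmark, reminiscent of MoE expert selection:
# pick 2 of E = 8 (D, D) weight matrices by index, upcast them to the
# activation dtype, and multiply each with a length-D vector. The traffic
# estimate (2 * D * D * itemsize) counts only the two gathered matrices.
# Inductor's coordinate descent tuning is enabled via the decorator.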
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    E = 8
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1060",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            us_per_iter = (
                benchmarker.benchmark_gpu(lambda: compiled_fn(W, score_idxs, x)) * 1000
            )
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


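# Plain gemv micro-benchmark: a (D, D) weight (possibly stored in int8 and
# upcast to the bf16 activation dtype on the fly) times a length-D vector.
# The single weight read dominates, hence D * D * itemsize bytes per
# iteration.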
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.int8: "870",
        torch.bfloat16: "990",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            us_per_iter = benchmarker.benchmark_gpu(lambda: compiled_fn(W, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results


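# Append one result row to the CSV, creating the file on first use (and,
# for non-default paths, ensuring the parent directory exists). An existing
# header is widened if a previous run wrote fewer columns, and short rows
# are right-padded with "0" to the header width.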
def output_csv(output_file, headers, row):
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
        if headers and len(headers) > len(lines[0]):
            # if prior results failed the header might not be filled in yet
            lines[0] = headers
        else:
            headers = lines[0]
    else:
        lines = [headers]

    if output_file != DEFAULT_OUTPUT_FILE:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

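# The full benchmark suite: each entry is a callable (invoked with no
# arguments in main) that returns a list of Experiment results. Note this
# is a set, so execution order is not guaranteed. To add a micro-benchmark,
# define a function with the same shape as the ones above, e.g.
# (hypothetical):
#
#     def run_my_kernel(device: str = "cuda") -> list[Experiment]:
#         ...
#         return [Experiment("my_kernel", "memory_bandwidth(GB/s)", ...)]
#
# and register it in the set below.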
all_experiments = {
    # A list of GPT models: Llama, Mixtral, etc.
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # A list of micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
}


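# Run every experiment, then append one CSV row per Experiment result.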
def main(output_file=DEFAULT_OUTPUT_FILE):
    results = []

    for func in all_experiments:
        lst = func()
        for x in lst:
            results.append(dataclasses.astuple(x))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    args = parser.parse_args()

    main(output_file=args.output)