import argparse
import operator
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch._C._te as te


class kernel_arena_scope:
    def __enter__(self):
        self.scope = te.KernelScope()

    def __exit__(self, typ, val, traceback):
        self.scope = None
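
# Note on kernel_arena_scope: te.KernelScope() provides the arena that owns the NNC
# IR nodes created while it is alive. Each benchmarked kernel below is built inside
# its own scope so the IR can be reclaimed once that benchmark iteration finishes.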
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    ("fast_log", torch.log),  # NNC's fast_log, timed against eager torch.log
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac), # seems unimplemented
    # ("isnan", torch.isnan), # no out variant
]


def gen_unary_nnc_fun(nnc_name):
    def nnc_fun(A, B):
        def compute(i, j):
            return getattr(A.load([i, j]), nnc_name)()

        return compute

    return nnc_fun


def gen_unary_torch_fun(torch_op):
    def torch_fun(a, b, out):
        def fun():
            return torch_op(a, out=out)

        return fun

    return torch_fun
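
# For the "sin" entry, for example, gen_unary_nnc_fun("sin") returns a compute(i, j)
# that emits A.load([i, j]).sin() as an NNC expression, while
# gen_unary_torch_fun(torch.sin) returns a thunk that runs torch.sin(a, out=out)
# eagerly; the second tensor argument is ignored for unary ops.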


def gen_binary_nnc_fun(fn):
    def nnc_fun(A, B):
        def compute(i, j):
            return fn(A.load([i, j]), B.load([i, j]))

        return compute

    return nnc_fun


def gen_binary_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out=out)

        return fun

    return pt_fun


def gen_int_comparison_tensors(N, M):
    return (
        torch.randint(0, 3, (N, M)),
        torch.randint(0, 3, (N, M)),
        torch.empty((N, M), dtype=torch.bool),
    )


def gen_float_comparison_tensors(N, M):
    return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))
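
# These shape functions supply two inputs plus a bool output buffer for the
# comparison ops below; the int variant draws values from {0, 1, 2} so that "eq"
# produces a mix of True and False results.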


te_bool = te.Dtype.Bool
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    (
        "eq",
        (lambda a, b: te.Cast.make(te_bool, a == b)),
        torch.eq,
        gen_int_comparison_tensors,
    ),
    (
        "gt",
        (lambda a, b: te.Cast.make(te_bool, a > b)),
        torch.gt,
        gen_float_comparison_tensors,
    ),
    (
        "lt",
        (lambda a, b: te.Cast.make(te_bool, a < b)),
        torch.lt,
        gen_float_comparison_tensors,
    ),
    (
        "gte",
        (lambda a, b: te.Cast.make(te_bool, a >= b)),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        (lambda a, b: te.Cast.make(te_bool, a <= b)),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ('neq', (lambda a, b: a != b), None), # no one-op equivalent
    # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test
]
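
# The comparison entries wrap their result in te.Cast.make(te_bool, ...) because the
# comparison on ExprHandles does not itself yield a Bool-typed value, and the
# preallocated output tensor is torch.bool; the fourth element names the shape
# function used to build the inputs and output buffer.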


def nnc_relu(A, B):
    def f(i, j):
        return torch._C._te.ifThenElse(
            A.load([i, j]) < torch._C._te.ExprHandle.float(0),
            torch._C._te.ExprHandle.float(0),
            A.load([i, j]),
        )

    return f


def pt_relu(a, b, c):
    return torch.relu(a)


custom_ops = [
    ("relu", nnc_relu, pt_relu),
    # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
    # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]
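
# nnc_relu builds relu symbolically as ifThenElse(x < 0, 0, x). Note that the eager
# counterpart pt_relu ignores its preallocated output and lets torch.relu allocate a
# fresh tensor, so this comparison is slightly asymmetric with the other ops.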


def gen_custom_torch_fun(fn):
    def pt_fun(a, b, out):
        def fun():
            return fn(a, b, out)

        return fun

    return pt_fun


def normalize_benchmarks(ops):
    return [i + (None,) if len(i) == 3 else i for i in ops]
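
# normalize_benchmarks pads three-element entries with a trailing None shape_fn, so
# e.g. ("add", operator.add, torch.add) becomes ("add", operator.add, torch.add, None)
# and every entry unpacks to (name, nnc_fn, torch_fn, shape_fn).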


names = []
nnc_fns = []
pt_fns = []
shape_fns = []

for nnc_name, pt_op in unary_ops:
    names.append(nnc_name)
    nnc_fns.append(gen_unary_nnc_fun(nnc_name))
    pt_fns.append(gen_unary_torch_fun(pt_op))
    shape_fns.append(None)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
    names.append(name)
    nnc_fns.append(gen_binary_nnc_fun(lmbda))
    pt_fns.append(gen_binary_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
    names.append(name)
    nnc_fns.append(lmbda)
    pt_fns.append(gen_custom_torch_fun(pt_fn))
    shape_fns.append(shape_fn)

benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
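
# Each benchmark is a (name, nnc_compute_builder, torch_thunk_builder, shape_fn)
# tuple; shape_fn is None for the ordinary float ops, in which case run_benchmarks
# allocates float inputs and output buffers itself.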


def run_benchmarks(benchmarks, sizes):
    df = pd.DataFrame(columns=["name", "N", "M", "nnc_time", "torch_time", "ratio"])
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                iters = int(1e6 / (N + M))
                with kernel_arena_scope():
                    if shape_fn is None:
                        tA = torch.rand(M, N).clamp(0.01, 0.99)
                        tB = torch.rand(M, N).clamp(0.01, 0.99)
                        tX = torch.empty(M, N)
                        tR = torch.empty(M, N)
                    else:
                        tA, tB, tX = shape_fn(M, N)
                        tR = tX.clone()
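
                    # Inputs are clamped away from 0 and 1, which keeps ops like log,
                    # rsqrt, and asin inside their valid domain; tX receives the NNC
                    # result and tR the eager PyTorch result.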

                    def get_nnc_type(dtype):
                        if dtype == torch.float:
                            return torch._C._te.Dtype.Float
                        elif dtype == torch.long:
                            return torch._C._te.Dtype.Long

                    dtype = get_nnc_type(tA.dtype)

                    dM = torch._C._te.ExprHandle.int(M)
                    dN = torch._C._te.ExprHandle.int(N)

                    A = torch._C._te.Placeholder("A", dtype, [dM, dN])
                    B = torch._C._te.Placeholder("B", dtype, [dM, dN])

                    dim_args = [
                        torch._C._te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]
                    ]
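
                    # Build and JIT the NNC kernel: Compute describes X[m, n]
                    # pointwise, LoopNest lowers it to loops, prepare_for_codegen and
                    # simplify normalize the statement, and construct_codegen compiles
                    # it with the LLVM backend.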
                    compute = nnc_fun(A, B)
                    X = torch._C._te.Compute("X", dim_args, compute)
                    loopnest = torch._C._te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = torch._C._te.simplify(loopnest.root_stmt())
                    cg = torch._C._te.construct_codegen(
                        "llvm", stmt, [torch._C._te.BufferArg(x) for x in [A, B, X]]
                    )

                    # warmup
                    for _ in range(10):
                        cg.call([tA, tB, tX])
                    start = time.time()
                    for it in range(iters):
                        cg.call([tA, tB, tX])
                    time1 = time.time() - start

                    fn = torch_fun(tA, tB, tR)
                    # warmup
                    for _ in range(10):
                        tR = fn()
                    start = time.time()
                    for it in range(iters):
                        tR = fn()
                    time2 = time.time() - start

                    new_row = {
                        "name": name,
                        "N": N,
                        "M": M,
                        "nnc_time": time1,
                        "torch_time": time2,
                        "ratio": time2 / time1,
                    }
                    # DataFrame.append was removed in pandas 2.x; concat the row instead.
                    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
                    print(name, N, M)

                    print(time2 / time1, time1, time2)
                    print()

                    def check_correctness(a, b):
                        if not np.allclose(a, b):
                            print(name)
                            assert np.allclose(a, b)

                    check_correctness(tX, tR)
    return df
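
# dump_plot keeps only the square (N == M) runs, takes one row per op and one column
# per size, and plots torch_time / nnc_time, so cells above 1.0 mean NNC was faster
# than eager PyTorch for that op and size.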


def dump_plot(df, sizes):
    keys = []
    vals = []
    indexed = df[df["N"] == df["M"]]
    for index, row in indexed.iterrows():
        keys.append(row["name"])
        vals.append(row["ratio"])

    keys = keys[:: len(sizes)]
    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    np_vals = np.array([vals]).reshape(-1, len(sizes))
    g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    g.set_yticklabels(keys)
    g.set_xticklabels(sizes)

    plt.savefig("nnc.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    parser.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    args = parser.parse_args()
    if not args.multi_threaded:
        torch.set_num_threads(1)

    sizes = [1, 4, 16, 64, 256, 1024]
    df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
    dump_plot(df, sizes)