# Source: pytorch — Fork: 0 / microbenchmarks.py (334 lines · 9.0 KB)
1
import argparse
2
import operator
3
import time
4

5
import matplotlib.pyplot as plt
6
import numpy as np
7
import pandas as pd
8
import seaborn as sns
9

10
import torch
11
import torch._C._te as te
12

13

14
class kernel_arena_scope:
    """Context manager that holds an NNC ``te.KernelScope`` for a ``with`` body.

    ``__enter__`` creates the scope and stores it on ``self.scope``;
    ``__exit__`` drops that reference again. ``__enter__`` returns ``None``,
    so use it as ``with kernel_arena_scope():`` without an ``as`` target.
    """

    def __enter__(self):
        new_scope = te.KernelScope()
        self.scope = new_scope

    def __exit__(self, typ, val, traceback):
        # Release our reference. Returning None means exceptions raised in the
        # ``with`` body are not suppressed.
        self.scope = None
20

21

22
# (name, torch_op) pairs for the unary benchmarks. ``name`` is both the label
# used in the results and the name of the method called on the loaded NNC
# expression (see gen_unary_nnc_fun); ``torch_op`` is called with ``out=``.
# Each op appears exactly once (a duplicated "expm1" entry was removed).
unary_ops = [
    ("sin", torch.sin),
    ("cos", torch.cos),
    ("tan", torch.tan),
    ("asin", torch.asin),
    ("acos", torch.acos),
    ("atan", torch.atan),
    ("sinh", torch.sinh),
    ("cosh", torch.cosh),
    ("tanh", torch.tanh),
    ("sigmoid", torch.sigmoid),
    ("exp", torch.exp),
    ("expm1", torch.expm1),
    ("abs", torch.abs),
    ("log", torch.log),
    # NNC's ``fast_log`` is compared against the regular ``torch.log``.
    ("fast_log", torch.log),
    ("log2", torch.log2),
    ("log10", torch.log10),
    ("log1p", torch.log1p),
    ("erf", torch.erf),
    ("erfc", torch.erfc),
    ("sqrt", torch.sqrt),
    ("rsqrt", torch.rsqrt),
    ("ceil", torch.ceil),
    ("floor", torch.floor),
    ("round", torch.round),
    ("trunc", torch.trunc),
    ("lgamma", torch.lgamma),
    # ("frac", torch.frac),  # seems unimplemented
    # ("isnan", torch.isnan),  # no out variant
]
54

55

56
def gen_unary_nnc_fun(nnc_name):
    """Build an NNC compute factory for the unary op named ``nnc_name``.

    The returned ``nnc_fun(A, B)`` produces ``compute(i, j)``, which loads
    ``A[i, j]`` and calls the method named ``nnc_name`` on the loaded
    expression. ``B`` is accepted only so all factories share the same
    ``(A, B)`` signature; it is not used.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            element = A.load([i, j])
            method = getattr(element, nnc_name)
            return method()

        return compute

    return nnc_fun
64

65

66
def gen_unary_torch_fun(torch_op):
    """Build a PyTorch benchmark factory for the unary ``torch_op``.

    ``torch_fun(a, b, out)`` returns a zero-argument callable that evaluates
    ``torch_op(a, out=out)`` and returns its result. ``b`` is unused.
    """

    def torch_fun(a, b, out):
        return lambda: torch_op(a, out=out)

    return torch_fun
74

75

76
def gen_binary_nnc_fun(fn):
    """Build an NNC compute factory that combines ``A[i, j]`` and ``B[i, j]``.

    The returned ``nnc_fun(A, B)`` produces ``compute(i, j)``, which loads
    both elements (``A`` first, then ``B``) and returns ``fn(lhs, rhs)``.
    """

    def nnc_fun(A, B):
        def compute(i, j):
            lhs = A.load([i, j])
            rhs = B.load([i, j])
            return fn(lhs, rhs)

        return compute

    return nnc_fun
84

85

86
def gen_binary_torch_fun(fn):
    """Build a PyTorch benchmark factory for the binary op ``fn``.

    ``pt_fun(a, b, out)`` returns a zero-argument callable that evaluates
    ``fn(a, b, out=out)`` and returns its result.
    """

    def pt_fun(a, b, out):
        return lambda: fn(a, b, out=out)

    return pt_fun
94

95

96
def gen_int_comparison_tensors(N, M):
    """Return ``(lhs, rhs, out)`` inputs for an integer comparison benchmark.

    ``lhs`` and ``rhs`` are ``(N, M)`` tensors from ``torch.randint`` with
    values in ``{0, 1, 2}``. ``out`` is an uninitialized ``(N, M)`` bool
    tensor that receives the result.
    """
    shape = (N, M)
    lhs = torch.randint(0, 3, shape)
    rhs = torch.randint(0, 3, shape)
    out = torch.empty(shape, dtype=torch.bool)
    return lhs, rhs, out
102

103

104
def gen_float_comparison_tensors(N, M):
    """Return ``(lhs, rhs, out)`` inputs for a floating-point comparison benchmark.

    ``lhs`` and ``rhs`` are ``(N, M)`` tensors from ``torch.rand`` (uniform
    on ``[0, 1)``). ``out`` is an uninitialized ``(N, M)`` bool tensor.
    """
    lhs = torch.rand(N, M)
    rhs = torch.rand(N, M)
    out = torch.empty((N, M), dtype=torch.bool)
    return lhs, rhs, out
106

107

108
te_bool = te.Dtype.Bool


def _bool_comparison(compare):
    """Return an NNC compute fn that applies ``compare`` and casts to ``te_bool``."""

    def nnc_compare(a, b):
        return te.Cast.make(te_bool, compare(a, b))

    return nnc_compare


# Binary benchmarks as (name, nnc_fn, torch_fn[, input_generator]).
# ``nnc_fn(a, b)`` combines two loaded NNC expressions, and ``torch_fn`` is
# called with ``out=``. Entries without an input generator are padded with
# None by normalize_benchmarks and then get the default random float inputs.
binary_ops = [
    ("add", operator.add, torch.add),
    ("mul", operator.mul, torch.mul),
    ("sub", operator.sub, torch.sub),
    ("div", operator.truediv, torch.div),
    # Comparisons produce bool outputs. "eq" draws small random integers
    # (presumably so equal pairs actually occur); the others use floats.
    ("eq", _bool_comparison(operator.eq), torch.eq, gen_int_comparison_tensors),
    ("gt", _bool_comparison(operator.gt), torch.gt, gen_float_comparison_tensors),
    ("lt", _bool_comparison(operator.lt), torch.lt, gen_float_comparison_tensors),
    (
        "gte",
        _bool_comparison(operator.ge),
        torch.greater_equal,
        gen_float_comparison_tensors,
    ),
    (
        "lte",
        _bool_comparison(operator.le),
        torch.less_equal,
        gen_float_comparison_tensors,
    ),
    # ("neq", (lambda a, b: a != b), None),  # no one-op equivalent
    # ("&", (lambda a, b: a & b), torch.bitwise_and),  # requires more work to test
]
147

148

149
def nnc_relu(A, B):
    """NNC compute factory for ReLU: ``X[i, j] = 0 if A[i, j] < 0 else A[i, j]``.

    ``B`` is unused; it is accepted so the signature matches the other NNC
    factories. The returned ``relu_at(i, j)`` builds an ``ifThenElse``
    expression from two separate loads of ``A[i, j]``.
    """
    nnc = torch._C._te

    def relu_at(i, j):
        is_negative = A.load([i, j]) < nnc.ExprHandle.float(0)
        zero = nnc.ExprHandle.float(0)
        value = A.load([i, j])
        return nnc.ifThenElse(is_negative, zero, value)

    return relu_at
158

159

160
def pt_relu(a, b, c):
    """PyTorch ReLU reference that returns ``relu(a)`` as a new tensor.

    ``b`` and ``c`` are unused. In particular the preallocated output ``c``
    is not written, so callers must use the returned tensor.
    """
    return a.relu()
162

163

164
# Hand-written benchmarks as (name, nnc_fn, torch_fn[, input_generator]).
# Unlike the unary/binary tables, nnc_fn is used directly as the
# (A, B) -> compute(i, j) factory, and torch_fn is called as torch_fn(a, b, out).
custom_ops = [("relu", nnc_relu, pt_relu)]
# Disabled candidates (no implementation of these helpers appears in this file):
#   ("nnc_mul_relu", nnc_mul_relu, pt_mul_relu)
#   ("manual_sigmoid", nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
169

170

171
def gen_custom_torch_fun(fn):
    """Build a PyTorch benchmark factory for a custom op ``fn(a, b, out)``.

    ``pt_fun(a, b, out)`` returns a zero-argument callable that forwards all
    three arguments positionally to ``fn`` and returns its result.
    """

    def pt_fun(a, b, out):
        return lambda: fn(a, b, out)

    return pt_fun
179

180

181
def normalize_benchmarks(ops):
    """Return ``ops`` as a list of 4-tuples.

    Three-element tuples get a trailing ``None`` (no input generator); every
    other entry is passed through unchanged.
    """
    normalized = []
    for entry in ops:
        if len(entry) == 3:
            entry = entry + (None,)
        normalized.append(entry)
    return normalized
183

184

185
# Flatten the three op tables into one registry of
# (name, nnc_fun, torch_fun, shape_fn) entries. Unary ops always use the
# default inputs (shape_fn None); binary and custom ops may supply their own.
_registry = (
    [
        (op_name, gen_unary_nnc_fun(op_name), gen_unary_torch_fun(torch_op), None)
        for op_name, torch_op in unary_ops
    ]
    + [
        (op_name, gen_binary_nnc_fun(nnc_op), gen_binary_torch_fun(torch_op), make_inputs)
        for op_name, nnc_op, torch_op, make_inputs in normalize_benchmarks(binary_ops)
    ]
    + [
        (op_name, nnc_op, gen_custom_torch_fun(torch_op), make_inputs)
        for op_name, nnc_op, torch_op, make_inputs in normalize_benchmarks(custom_ops)
    ]
)

names = [entry[0] for entry in _registry]
nnc_fns = [entry[1] for entry in _registry]
pt_fns = [entry[2] for entry in _registry]
shape_fns = [entry[3] for entry in _registry]

benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
209

210

211
def _nnc_dtype(dtype):
    """Map a torch input dtype to the NNC dtype used for the placeholders.

    Raises:
        ValueError: if ``dtype`` is neither ``torch.float`` nor ``torch.long``,
            instead of silently returning ``None``.
    """
    if dtype == torch.float:
        return torch._C._te.Dtype.Float
    if dtype == torch.long:
        return torch._C._te.Dtype.Long
    raise ValueError(f"Unsupported input dtype for NNC benchmark: {dtype}")


def _make_inputs(shape_fn, N, M):
    """Create ``(tA, tB, tX, tR)``: the two operands and the NNC/PyTorch outputs.

    Without ``shape_fn``, all four tensors have shape ``(M, N)``. The operands
    are random floats clamped to ``[0.01, 0.99]`` (presumably to keep ops like
    log/asin/rsqrt inside their domain). Otherwise ``shape_fn(M, N)`` supplies
    ``(tA, tB, tX)`` and ``tR`` is a clone of ``tX``.
    """
    if shape_fn is None:
        tA = torch.rand(M, N).clamp(0.01, 0.99)
        tB = torch.rand(M, N).clamp(0.01, 0.99)
        tX = torch.empty(M, N)
        tR = torch.empty(M, N)
    else:
        tA, tB, tX = shape_fn(M, N)
        tR = tX.clone()
    return tA, tB, tX, tR


def _build_nnc_kernel(nnc_fun, dtype, M, N):
    """Compile ``X[m, n] = nnc_fun(A, B)(m, n)`` over ``M x N`` buffers with LLVM.

    Buffers are bound positionally as ``A, B, X``, so the returned codegen is
    invoked as ``cg.call([tA, tB, tX])``.
    """
    te_mod = torch._C._te
    dM = te_mod.ExprHandle.int(M)
    dN = te_mod.ExprHandle.int(N)
    A = te_mod.Placeholder("A", dtype, [dM, dN])
    B = te_mod.Placeholder("B", dtype, [dM, dN])
    dim_args = [te_mod.DimArg(dM, "m"), te_mod.DimArg(dN, "n")]
    X = te_mod.Compute("X", dim_args, nnc_fun(A, B))
    loopnest = te_mod.LoopNest([X])
    loopnest.prepare_for_codegen()
    stmt = te_mod.simplify(loopnest.root_stmt())
    return te_mod.construct_codegen(
        "llvm", stmt, [te_mod.BufferArg(x) for x in [A, B, X]]
    )


def _time_calls(fn, iters, warmup=10):
    """Call ``fn`` ``warmup`` times untimed, then ``iters`` times timed.

    Returns:
        ``(elapsed_seconds, last_result)``, where ``last_result`` is the value
        returned by the final call (warmup included, if ``iters`` is 0).
    """
    result = None
    for _ in range(warmup):
        result = fn()
    start = time.perf_counter()  # monotonic, high-resolution clock
    for _ in range(iters):
        result = fn()
    return time.perf_counter() - start, result


def run_benchmarks(benchmarks, sizes):
    """Time each benchmark with NNC and with eager PyTorch and check the results agree.

    Args:
        benchmarks: iterable of ``(name, nnc_fun, torch_fun, shape_fn)`` tuples.
            ``nnc_fun(A, B)`` returns an NNC ``compute(i, j)``.
            ``torch_fun(a, b, out)`` returns a zero-argument callable producing
            the PyTorch result. ``shape_fn(M, N)`` optionally supplies
            ``(a, b, out)`` tensors.
        sizes: iterable of ``(N, M)`` pairs. Default inputs have shape ``(M, N)``.

    Returns:
        A ``pandas.DataFrame`` with one row per benchmark/size and the columns
        ``name, N, M, nnc_time, torch_time, ratio``. ``ratio`` is
        ``torch_time / nnc_time``, so values above 1 mean NNC was faster.

    Raises:
        AssertionError: if the NNC and PyTorch outputs are not ``np.allclose``.
        ValueError: if an input tensor has a dtype NNC is not set up for here.
    """
    columns = ["name", "N", "M", "nnc_time", "torch_time", "ratio"]
    # DataFrame.append was removed in pandas 2.0, so collect plain rows and
    # build the frame once at the end.
    rows = []
    with torch.no_grad():
        for name, nnc_fun, torch_fun, shape_fn in benchmarks:
            for N, M in sizes:
                # Scale the iteration count down with problem size, but always
                # time at least one call.
                iters = max(1, int(1e6 / (N + M)))
                with kernel_arena_scope():
                    tA, tB, tX, tR = _make_inputs(shape_fn, N, M)
                    cg = _build_nnc_kernel(nnc_fun, _nnc_dtype(tA.dtype), M, N)

                    nnc_time, _ = _time_calls(lambda: cg.call([tA, tB, tX]), iters)
                    # Keep the PyTorch result: custom ops (e.g. relu) return a
                    # new tensor instead of writing into tR.
                    torch_time, tR = _time_calls(torch_fun(tA, tB, tR), iters)

                    ratio = torch_time / nnc_time
                    rows.append(
                        {
                            "name": name,
                            "N": N,
                            "M": M,
                            "nnc_time": nnc_time,
                            "torch_time": torch_time,
                            "ratio": ratio,
                        }
                    )
                    print(name, N, M)

                    print(ratio, nnc_time, torch_time)
                    print()

                    if not np.allclose(tX, tR):
                        raise AssertionError(
                            f"NNC and PyTorch results differ for {name} (N={N}, M={M})"
                        )
    return pd.DataFrame(rows, columns=columns)
294

295

296
def dump_plot(df, sizes):
    """Save a heatmap of PyTorch/NNC time ratios for square sizes to ``nnc.png``.

    Args:
        df: results table with at least ``name``, ``N``, ``M`` and ``ratio``
            columns. Only rows with ``N == M`` are plotted.
        sizes: the square sizes that were benchmarked. The square rows are
            assumed to be grouped by op, with one row per size in this order.

    Raises:
        ValueError: if ``sizes`` is empty, or the number of square rows is not
            a multiple of ``len(sizes)``.
    """
    num_sizes = len(sizes)
    if num_sizes == 0:
        raise ValueError("dump_plot needs at least one size")
    square = df[df["N"] == df["M"]]
    if len(square) % num_sizes:
        raise ValueError(
            f"{len(square)} square result rows cannot be split into rows of "
            f"{num_sizes} sizes"
        )

    # Each op contributes num_sizes consecutive rows, so every num_sizes-th
    # name labels one heatmap row.
    keys = square["name"].tolist()[::num_sizes]
    ratios = square["ratio"].to_numpy(dtype=float).reshape(-1, num_sizes)

    sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})

    cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
    ax = sns.heatmap(ratios, annot=True, cmap=cmap, center=1.0, yticklabels=True)
    plt.yticks(rotation=0)
    plt.title("PyTorch performance divided by NNC performance (single core)")
    plt.xlabel("Size of NxN matrix")
    plt.ylabel("Operation")
    ax.set_yticklabels(keys)
    ax.set_xticklabels(sizes)

    plt.savefig("nnc.png")
    # Close the figure so a later call starts fresh instead of drawing a
    # second heatmap on top of this one.
    plt.close(ax.figure)
318

319

320
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
    cli.add_argument(
        "--multi-threaded",
        "--multi_threaded",
        action="store_true",
        help="Run with more than one thread",
    )
    options = cli.parse_args()

    # Unless asked otherwise, pin PyTorch to a single thread.
    if not options.multi_threaded:
        torch.set_num_threads(1)

    square_sizes = [1, 4, 16, 64, 256, 1024]
    results = run_benchmarks(benchmarks, [(size, size) for size in square_sizes])
    dump_plot(results, square_sizes)
335

# Use of cookies

# We use cookies in accordance with the Privacy Policy and the Cookie Policy.

# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order to improve our website and the GitVerse Service and to make them more convenient to use.

# You can disable cookies yourself in your browser settings.