# pytorch
# 124 lines · 2.4 KB
import torch
import torch.fx as fx
from functorch import make_fx
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity
6
7
def profile_it(f, inp, *, warmup=5, itr=5):
    """Return the average profiled CUDA time of one call to ``f(inp)``.

    ``f`` is first called ``warmup`` times outside the profiler, then
    ``itr`` times inside a ``torch.profiler.profile`` context that records
    CUDA activity only.

    Args:
        f: Callable invoked as ``f(inp)``.
        inp: Argument passed unchanged to every call of ``f``.
        warmup: Number of unprofiled warm-up calls (default 5).
        itr: Number of profiled calls to average over (default 5).

    Returns:
        The sum of ``cuda_time_total`` over the events returned by
        ``prof.key_averages()``, divided by ``itr``.

    Raises:
        ValueError: If ``itr`` is not positive (the average is undefined).
    """
    if itr <= 0:
        raise ValueError(f"itr must be positive, got {itr}")

    # Warm-up calls happen before the profiler starts so they are not
    # included in the measurement.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    # The profiler totals cover all `itr` calls; divide to get a per-call
    # average.
    cuda_time_total = sum(e.cuda_time_total for e in prof.key_averages())
    return cuda_time_total / itr
22
23
def profile_function(name, f, inp):
    """Benchmark ``f`` before and after FX common-subexpression elimination.

    ``f`` is traced with ``make_fx`` using ``inp``, a second graph is built
    from the traced graph with ``fx_graph_cse``, and both graph modules are
    timed with ``profile_it``. One comma-separated line is printed.

    Args:
        name: Label printed as the first field of the result line.
        f: Function to trace; ``make_fx(f)`` is called with ``inp``.
        inp: Input used both for tracing and for timing.

    Prints:
        ``name``, average CUDA time of the traced graph, average CUDA time
        of the CSE'd graph, the decrease in node count, and the node count
        of the traced graph.
    """
    fx_g = make_fx(f)(inp)

    cse_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, cse_graph)

    # Do not benchmark against the scripted (torch.jit.script) version
    # because script already does some CSE.
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)

    num_nodes = len(fx_g.graph.nodes)
    num_node_decrease = num_nodes - len(new_g.graph.nodes)

    print(
        f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {num_nodes}"
    )
41
42
# Explicitly seeded CUDA generator used to draw the benchmark input.
g_gpu = torch.Generator(device="cuda")
g_gpu.manual_seed(2147483647)
# Shared benchmark input: 2**20 values drawn with torch.randn on the GPU.
inp = torch.randn(2**20, device="cuda", generator=g_gpu)
46
47
def f1(x: torch.Tensor) -> torch.Tensor:
    """Return ``cos(cos(x))`` via two chained ``cos`` calls."""
    return x.cos().cos()
50
51
# Benchmark f1 before and after CSE on the shared input.
profile_function("f1", f1, inp)
53
54
def fsum(x: torch.Tensor) -> torch.Tensor:
    """Return the sum of four separately computed ``x.sum()`` values.

    The four reductions are written out individually on purpose: this file
    benchmarks common-subexpression elimination on the traced graph, so the
    duplicates must not be folded by hand.
    """
    a = x.sum()
    b = x.sum()
    c = x.sum()
    d = x.sum()
    return a + b + c + d
61
62
# Benchmark fsum before and after CSE on the shared input.
profile_function("fsum", fsum, inp)
64
65
def fconcat(x: torch.Tensor) -> torch.Tensor:
    """Return the sum of two identical ``torch.cat((x, x))`` results.

    Both concatenations are kept as separate expressions because the
    duplicate is the workload for the CSE benchmark in this file.
    """
    a = torch.cat((x, x))
    b = torch.cat((x, x))
    return a + b
70
71
# Benchmark fconcat before and after CSE on the shared input.
profile_function("fconcat", fconcat, inp)
73
74
def fsum2(x: torch.Tensor) -> torch.Tensor:
    """Return ``x.sum()`` accumulated 31 times (one initial plus 30 added).

    Each ``x.sum()`` is recomputed inside the loop on purpose; the repeated
    reduction is the workload for the CSE benchmark in this file.
    """
    a = x.sum()
    for _ in range(30):
        a = a + x.sum()
    return a
80
81
# Benchmark fsum2 before and after CSE on the shared input.
profile_function("fsum2", fsum2, inp)
83
84
def fsummulti(x: torch.Tensor) -> torch.Tensor:
    """Alternately add and multiply by ``x.sum()`` for three iterations.

    Starting from 0, each iteration computes ``a = (a + x.sum()) * x.sum()``
    with ``x.sum()`` evaluated separately each time; the repeated reduction
    is the workload for the CSE benchmark in this file.
    """
    a = 0
    for _ in range(3):
        a = a + x.sum()
        # NOTE(review): indentation was lost in the extracted source; the
        # multiply is assumed to belong inside the loop — confirm.
        a = a * x.sum()
    return a
91
92
# Benchmark fsummulti before and after CSE on the shared input.
profile_function("fsummulti", fsummulti, inp)
94
95
def fsummulti2(x: torch.Tensor) -> torch.Tensor:
    """Alternately add and multiply by ``x.sum()`` for 30 iterations.

    Starting from 0, each iteration computes ``a = (a + x.sum()) * x.sum()``
    with ``x.sum()`` evaluated separately each time; the repeated reduction
    is the workload for the CSE benchmark in this file.
    """
    a = 0
    for _ in range(30):
        a = a + x.sum()
        # NOTE(review): indentation was lost in the extracted source; the
        # multiply is assumed to belong inside the loop — confirm.
        a = a * x.sum()
    return a
102
103
# Benchmark fsummulti2 before and after CSE on the shared input.
profile_function("fsummulti2", fsummulti2, inp)
105
106
def fcos(x: torch.Tensor) -> torch.Tensor:
    """Return ``x.cos()`` added to 0 three times.

    ``x.cos()`` is recomputed in every iteration on purpose; the repeated
    call is the workload for the CSE benchmark in this file.
    """
    a = 0
    for _ in range(3):
        a = a + x.cos()
    return a
112
113
# Benchmark fcos before and after CSE on the shared input.
profile_function("fcos", fcos, inp)
115
116
def fcos2(x: torch.Tensor) -> torch.Tensor:
    """Return ``x.cos()`` added to 0 thirty times.

    ``x.cos()`` is recomputed in every iteration on purpose; the repeated
    call is the workload for the CSE benchmark in this file.
    """
    a = 0
    for _ in range(30):
        a = a + x.cos()
    return a
122
123
# Benchmark fcos2 before and after CSE on the shared input.
profile_function("fcos2", fcos2, inp)
125