pytorch

Форк
0
124 строки · 2.4 Кб
1
import torch
2
import torch.fx as fx
3
from functorch import make_fx
4
from torch._functorch.compile_utils import fx_graph_cse
5
from torch.profiler import profile, ProfilerActivity
6

7

8
def _event_device_time(evt):
    """Return the total device (CUDA) time recorded on one profiler entry."""
    # NOTE(review): newer PyTorch releases appear to expose this value as
    # ``device_time_total`` and deprecate ``cuda_time_total`` -- confirm
    # against the installed version. Older releases only have the latter.
    total = getattr(evt, "device_time_total", None)
    return evt.cuda_time_total if total is None else total


def profile_it(f, inp, itr=5, warmup=5):
    """Return the average CUDA time of ``f(inp)`` per profiled call.

    ``f`` is first called ``warmup`` times outside the profiler, then ``itr``
    times inside a ``torch.profiler.profile`` session that records CUDA
    activity only. The total device time of every entry returned by
    ``prof.key_averages()`` is summed and divided by ``itr``.

    Args:
        f: Callable invoked as ``f(inp)``.
        inp: The single argument passed to ``f`` on every call.
        itr: Number of profiled calls; must be positive. Defaults to 5.
        warmup: Number of un-profiled calls made first; must not be
            negative. Defaults to 5.

    Returns:
        The summed device time of all profiler entries divided by ``itr``,
        in whatever unit the profiler reports.

    Raises:
        ValueError: If ``itr`` is not positive or ``warmup`` is negative.
    """
    # Validate up front: itr is the divisor of the returned average.
    if itr <= 0:
        raise ValueError(f"itr must be a positive integer, got {itr!r}")
    if warmup < 0:
        raise ValueError(f"warmup must be non-negative, got {warmup!r}")

    # Warm-up calls run outside the profiler so they are not measured.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    cuda_time_total = sum(_event_device_time(e) for e in prof.key_averages())
    return cuda_time_total / itr
22

23

24
def profile_function(name, f, inp):
    """Benchmark ``f`` before and after FX common-subexpression elimination.

    ``f`` is traced with ``make_fx`` on ``inp``; ``fx_graph_cse`` is applied
    to the traced graph and the result is wrapped in a new
    ``fx.GraphModule``. Both modules are timed with ``profile_it`` and one
    comma-separated line is printed:

        name, time of traced module, time of CSE module,
        node-count decrease, node count of the traced graph

    Args:
        name: Label printed as the first field of the output line.
        f: Function to trace and benchmark.
        inp: Input used both for tracing and for timing.
    """
    traced = make_fx(f)(inp)
    cse_graph = fx_graph_cse(traced.graph)
    cse_module = fx.GraphModule(traced, cse_graph)

    # The FX modules are timed directly rather than their torch.jit.script
    # versions, because scripting already performs some CSE of its own.
    baseline_time = profile_it(traced, inp)
    cse_time = profile_it(cse_module, inp)

    nodes_before = len(traced.graph.nodes)
    nodes_removed = nodes_before - len(cse_module.graph.nodes)

    print(f"{name}, {baseline_time}, {cse_time}, {nodes_removed}, {nodes_before}")
41

42

43
# Fixed-seed CUDA generator so the random benchmark input is reproducible.
# manual_seed() returns the generator itself, so the call can be chained.
g_gpu = torch.Generator(device="cuda").manual_seed(2147483647)

# Shared benchmark input: 2**20 normally distributed values on the CUDA device.
inp = torch.randn(1 << 20, generator=g_gpu, device="cuda")
46

47

48
def f1(x):
    """Return the cosine applied twice to ``x``, i.e. ``cos(cos(x))``."""
    inner = x.cos()
    return inner.cos()
50

51

52
# Benchmark f1 on the shared input and print its comma-separated result row.
profile_function("f1", f1, inp)
53

54

55
def fsum(x):
    """Return ``x.sum()`` computed four separate times and added together.

    The four identical reductions are kept as separate calls rather than
    hoisted: this function is passed to ``profile_function``, which measures
    what ``fx_graph_cse`` removes from the traced graph.
    """
    # All four reductions are issued first, then added left to right.
    partials = [x.sum() for _ in range(4)]
    total = partials[0]
    for term in partials[1:]:
        total = total + term
    return total
61

62

63
# Benchmark fsum on the shared input and print its comma-separated result row.
profile_function("fsum", fsum, inp)
64

65

66
def fconcat(x):
    """Return the sum of two separately built ``torch.cat((x, x))`` tensors.

    The concatenation is performed twice on purpose rather than reused:
    this function is passed to ``profile_function``, which measures what
    ``fx_graph_cse`` removes from the traced graph.
    """
    first, second = [torch.cat((x, x)) for _ in range(2)]
    return first + second
70

71

72
# Benchmark fconcat on the shared input and print its comma-separated result row.
profile_function("fconcat", fconcat, inp)
73

74

75
def fsum2(x):
    """Return ``x.sum()`` accumulated 31 times (one initial value plus 30 adds).

    Every addend is a fresh ``x.sum()`` call; the repetition is kept because
    this function is passed to ``profile_function``, which measures what
    ``fx_graph_cse`` removes from the traced graph.
    """
    acc = x.sum()
    step = 0
    while step < 30:
        # Out-of-place add, matching ``a = a + x.sum()``.
        acc = acc + x.sum()
        step += 1
    return acc
80

81

82
# Benchmark fsum2 on the shared input and print its comma-separated result row.
profile_function("fsum2", fsum2, inp)
83

84

85
def fsummulti(x):
    """Apply ``a = (a + x.sum()) * x.sum()`` three times, starting from 0.

    Each step issues two separate ``x.sum()`` calls; they are kept distinct
    because this function is passed to ``profile_function``, which measures
    what ``fx_graph_cse`` removes from the traced graph.
    """
    acc = 0
    rounds_left = 3
    while rounds_left > 0:
        # Add first, then multiply, each with its own reduction of x.
        acc = (acc + x.sum()) * x.sum()
        rounds_left -= 1
    return acc
91

92

93
# Benchmark fsummulti on the shared input and print its comma-separated result row.
profile_function("fsummulti", fsummulti, inp)
94

95

96
def fsummulti2(x):
    """Apply ``a = (a + x.sum()) * x.sum()`` thirty times, starting from 0.

    Each step issues two separate ``x.sum()`` calls; they are kept distinct
    because this function is passed to ``profile_function``, which measures
    what ``fx_graph_cse`` removes from the traced graph.
    """
    result = 0
    for _ in range(30):
        added = result + x.sum()
        result = added * x.sum()
    return result
102

103

104
# Benchmark fsummulti2 on the shared input and print its comma-separated result row.
profile_function("fsummulti2", fsummulti2, inp)
105

106

107
def fcos(x):
    """Return ``x.cos()`` added to a running total three times, starting from 0.

    Every addend is a fresh ``x.cos()`` call; the repetition is kept because
    this function is passed to ``profile_function``, which measures what
    ``fx_graph_cse`` removes from the traced graph.
    """
    total = 0
    remaining = 3
    while remaining:
        total = total + x.cos()
        remaining -= 1
    return total
112

113

114
# Benchmark fcos on the shared input and print its comma-separated result row.
profile_function("fcos", fcos, inp)
115

116

117
def fcos2(x):
    """Return ``x.cos()`` added to a running total thirty times, starting from 0.

    Every addend is a fresh ``x.cos()`` call; the repetition is kept because
    this function is passed to ``profile_function``, which measures what
    ``fx_graph_cse`` removes from the traced graph.
    """
    running = 0
    for _ in range(30):
        cosine = x.cos()
        running = running + cosine
    return running
122

123

124
# Benchmark fcos2 on the shared input and print its comma-separated result row.
profile_function("fcos2", fcos2, inp)
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.