test_kernel.py

import os
import time

import torch
import torch.nn as nn

import quant_cuda

# Force synchronous kernel launches so errors and timings attach to the right call.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Disable TF32 so the baselines below use true FP16/FP32 matmuls.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

print('Benchmarking LLaMa-7B FC2 matvec ...')

DEV = torch.device('cuda:0')

B = 5       # batch size
L = 128     # sequence length (used by the correctness checks below)
M = 4096    # LLaMA-7B hidden size
N = 11008   # LLaMA-7B MLP intermediate size

DTYPE = torch.half
mat = torch.randn((M, N), device=DEV, dtype=DTYPE)
vec = torch.randn((B, M), device=DEV, dtype=DTYPE)
mul = torch.zeros((B, N), device=DEV, dtype=DTYPE)

# FP16 cuBLAS baseline: average wall-clock seconds per matmul call.
COUNT = 1000
tick = time.time()
for _ in range(COUNT):
    torch.matmul(vec, mat, out=mul)
    torch.cuda.synchronize()
print('FP16:', (time.time() - tick) / COUNT)
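
# Note: the host-timer loop above also counts Python and synchronization
# overhead. A CUDA-event-based helper (a sketch, not used by this script)
# would measure device time only:
def bench_cuda_events(fn, iters=1000, warmup=10):
    # Warm up so one-time costs (context init, caches) are not measured.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1000.0  # seconds per call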

DTYPE = torch.float
mat = mat.to(DTYPE)
vec = vec.to(DTYPE)
mul = mul.to(DTYPE)

# Dummy packed-weight and packed-zero buffers filled with random ints. The
# same buffers are reused for every bit width below, so only the timings are
# meaningful; correctness is verified separately at the end of the script.
mat = torch.randint(-1000000000, 1000000000, (M // 256 * 32, N), device=DEV, dtype=torch.int)
scales = torch.randn(N, device=DEV, dtype=DTYPE)
zeros = torch.randint(-1000000000, 1000000000, (1, N // 256 * 32), device=DEV, dtype=torch.int)
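
# For intuition, a minimal sketch of how sixteen 2-bit values could be packed
# into the low 32 bits of one integer word. The packing actually used by the
# kernels lives in QuantLinear.pack in quant.py; its exact bit order may
# differ from this illustration.
def pack2bit_sketch(q):
    # q: 1-D tensor of integers in [0, 3] whose length is a multiple of 16.
    q = q.to(torch.int64).view(-1, 16)
    packed = torch.zeros(q.shape[0], dtype=torch.int64)
    for i in range(16):
        packed |= q[:, i] << (2 * i)  # value i occupies bits [2i, 2i+1]
    return packed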

COUNT = 1000
tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant2matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('2bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant4matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('4bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant8matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('8bit:', (time.time() - tick) / COUNT)

print('Verifying kernel correctness ...')

M = 4096
N = 11008

from quant import *

# For each bit width: fake-quantize a Linear layer's weights in FP32, pack the
# same quantized weights into a QuantLinear, and print both outputs on one
# random input. The two tensors should agree closely.
for bits in [2, 3, 4, 8]:
    layer = nn.Linear(M, N)
    vec = torch.randn(B, L, M).to(DEV)

    quantizer = Quantizer()
    quantizer.configure(bits, perchannel=True, sym=False, mse=False)
    quantizer.find_params(layer.weight.data, weight=True)
    layer.weight.data = quantize(
        layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
    )

    qlayer = QuantLinear(bits, -1, layer.in_features, layer.out_features)
    qlayer.pack(layer, quantizer.scale, quantizer.zero)

    qlayer = qlayer.to(DEV)
    layer = layer.to(DEV)

    with torch.no_grad():
        # 'Simu' is the fake-quantized FP layer; 'Kern' runs the CUDA kernel.
        print('%dbit Simu:' % bits, layer(vec))
        print('%dbit Kern:' % bits, qlayer(vec))
        print('\n')
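
# A minimal numeric cross-check sketch: compare the two paths from the last
# loop iteration above (8-bit) by their largest elementwise deviation instead
# of eyeballing printed tensors. A small nonzero difference is expected,
# since the kernel accumulates in a different order than the FP reference.
with torch.no_grad():
    ref = layer(vec)
    out = qlayer(vec)
    err = (ref - out).abs().max()
    print('8bit max abs err:', err.item())
    print('8bit max rel err:', (err / ref.abs().max()).item())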
