import os
import time

import torch
import torch.nn as nn

# Assumed import: the Quantizer/quantize/QuantLinear utilities used below come
# from the repo's quant module; adjust to match your checkout.
from quant import Quantizer, quantize, QuantLinear

import quant_cuda

# Force synchronous kernel launches so CUDA errors surface at the offending call.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Disable TF32 so the FP16 baseline is measured without TF32 acceleration.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

print('Benchmarking LLaMa-7B FC2 matvec ...')

DEV = torch.device('cuda:0')
# Problem shape (assumed values; the original definitions are elided here).
# LLaMa-7B's FC2 (down projection) maps 11008 -> 4096 features.
B = 1          # assumption: single-token batch for the matvec benchmark
L = 128        # assumption: sequence length for the correctness checks below
M = 11008      # assumption: input features
N = 4096       # assumption: output features
DTYPE = torch.half

mat = torch.randn((M, N), device=DEV, dtype=DTYPE)
vec = torch.randn((B, M), device=DEV, dtype=DTYPE)
mul = torch.zeros((B, N), device=DEV, dtype=DTYPE)
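
# Note: mul is preallocated and reused via matmul's out= argument, so output
# allocation stays out of the timing loops.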
COUNT = 1000  # assumption: iteration count per timing loop (elided above)
tick = time.time()
for _ in range(COUNT):
    torch.matmul(vec, mat, out=mul)
    torch.cuda.synchronize()
print('FP16:', (time.time() - tick) / COUNT)
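
# Optional sketch (not in the original): a matvec is memory-bound, so the
# per-call latency maps to an effective bandwidth via the bytes of fp16 weight
# data streamed per call (M * N * 2). For example:
# elapsed = (time.time() - tick) / COUNT  # capture right after the loop
# print('FP16 eff. bandwidth (GB/s):', M * N * 2 / elapsed / 1e9)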

# Assumed from the elided lines: the custom kernels take float32 activations,
# so the benchmark buffers are converted before timing them.
DTYPE = torch.float
mat = mat.to(DTYPE)
vec = vec.to(DTYPE)
mul = mul.to(DTYPE)

mat = torch.randint(-1000000000, 1000000000, (M // 256 * 32, N), device=DEV, dtype=torch.int)
scales = torch.randn(N, device=DEV, dtype=DTYPE)
zeros = torch.randint(-1000000000, 1000000000, (1, N // 256 * 32), device=DEV, dtype=torch.int)
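
# The packed buffers above hold random bits and are shared across all
# bit-widths below: this section measures kernel runtime only, not
# correctness. Each 32-bit int stores several low-bit weights (16 x 2-bit,
# 8 x 4-bit, 4 x 8-bit; 3-bit values straddle int boundaries), which is why
# the packed matrix has far fewer rows than M.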

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant2matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('2bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant3matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('3bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant4matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('4bit:', (time.time() - tick) / COUNT)

tick = time.time()
for _ in range(COUNT):
    quant_cuda.vecquant8matmul(vec, mat, mul, scales, zeros, M)
    torch.cuda.synchronize()
print('8bit:', (time.time() - tick) / COUNT)
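
# Since the matvec is memory-bound, lower bit-widths stream proportionally
# less weight data per call, so the quantized kernels should beat the FP16
# baseline roughly in line with the compression ratio.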

print('Verifying kernel correctness ...')
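
# Each check below fake-quantizes a Linear layer's weights ('Simu', simulated
# quantization evaluated with the dense fp kernel) and packs the same
# quantized weights into a QuantLinear ('Kern', the custom CUDA kernel). The
# two printed outputs should agree up to kernel-level floating-point noise.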

layer = nn.Linear(M, N)
vec = torch.randn(B, L, M).to(DEV)

quantizer = Quantizer()
quantizer.configure(2, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
    layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)

qlayer = QuantLinear(2, -1, layer.in_features, layer.out_features)
qlayer.pack(layer, quantizer.scale, quantizer.zero)

qlayer = qlayer.to(DEV)
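
# Assumed signature: QuantLinear(bits, groupsize, in_features, out_features),
# where groupsize=-1 means one scale/zero per output channel (no grouping).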

with torch.no_grad():
    print('2bit Simu:', layer.to(DEV)(vec))
    print('2bit Kern:', qlayer(vec))
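
# Optional sketch (not in the original), applicable to each bit-width below:
# compare the two outputs numerically instead of eyeballing printed tensors.
# with torch.no_grad():
#     err = (layer.to(DEV)(vec) - qlayer(vec)).abs().max().item()
#     print('2bit max abs err:', err)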

layer = nn.Linear(M, N)
vec = torch.randn(B, L, M).to(DEV)

quantizer = Quantizer()
quantizer.configure(3, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
    layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)

qlayer = QuantLinear(3, -1, layer.in_features, layer.out_features)
qlayer.pack(layer, quantizer.scale, quantizer.zero)

qlayer = qlayer.to(DEV)

with torch.no_grad():
    print('3bit Simu:', layer.to(DEV)(vec))
    print('3bit Kern:', qlayer(vec))

layer = nn.Linear(M, N)
vec = torch.randn(B, L, M).to(DEV)

quantizer = Quantizer()
quantizer.configure(4, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
    layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)

qlayer = QuantLinear(4, -1, layer.in_features, layer.out_features)
qlayer.pack(layer, quantizer.scale, quantizer.zero)

qlayer = qlayer.to(DEV)

with torch.no_grad():
    print('4bit Simu:', layer.to(DEV)(vec))
    print('4bit Kern:', qlayer(vec))

layer = nn.Linear(M, N)
vec = torch.randn(B, L, M).to(DEV)

quantizer = Quantizer()
quantizer.configure(8, perchannel=True, sym=False, mse=False)
quantizer.find_params(layer.weight.data, weight=True)
layer.weight.data = quantize(
    layer.weight.data, quantizer.scale, quantizer.zero, quantizer.maxq
)

qlayer = QuantLinear(8, -1, layer.in_features, layer.out_features)
qlayer.pack(layer, quantizer.scale, quantizer.zero)

qlayer = qlayer.to(DEV)

with torch.no_grad():
    print('8bit Simu:', layer.to(DEV)(vec))
    print('8bit Kern:', qlayer(vec))