quant.py
import numpy as np
import torch
import torch.nn as nn
import math

def quantize(x, scale, zero, maxq):
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)
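
# Worked example for quantize() (illustrative only, not part of the original file):
# with bits = 4 we have maxq = 15; for scale = 0.1 and zero = 8, the input x = 0.33
# maps to clamp(round(0.33 / 0.1) + 8, 0, 15) = 11, which dequantizes back to
# 0.1 * (11 - 8) = 0.3. Note that quantize() returns the dequantized value,
# not the integer code.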

# Computes and stores affine quantization parameters (scale and zero-point),
# either for the whole tensor or per channel.
class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
            self,
            bits, perchannel=False, sym=True,
            mse=False, norm=2.4, grid=100, maxshrink=.8
        ):
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        # Collapse x to [channels, samples] so ranges can be taken per row
        # (or to a single row when quantizing the whole tensor at once).
        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        # For symmetric quantization make the range symmetric around zero.
        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Avoid a zero-width range for channels that are entirely zero.
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
        else:
            self.zero = torch.round(-xmin / self.scale)

        # Optionally grid-search a range-shrink factor p and keep, per row,
        # the parameters with the lowest p-norm reconstruction error.
        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        # Reshape scale / zero so they broadcast against the original tensor layout.
        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)

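# Example usage of Quantizer (illustrative only, not part of the original file):
#
#     quantizer = Quantizer()
#     quantizer.configure(bits=4, perchannel=True, sym=False, mse=False)
#     W = torch.randn(128, 64)               # [out_features, in_features]
#     quantizer.find_params(W, weight=True)  # one scale / zero per output row
#     W_hat = quantizer.quantize(W)          # dequantized approximation of W
#     err = (W - W_hat).pow(2).mean()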

try:
    import quant_cuda
except:
    print('CUDA extension not installed.')

# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        if groupsize != -1 and (groupsize < 32 or groupsize != int(math.pow(2, int(math.log2(groupsize))))):
            raise NotImplementedError("groupsize must be -1 or a power of 2 that is at least 32 (e.g. 32, 64, 128).")
        groupsize = groupsize if groupsize != -1 else infeatures
        self.groupsize = groupsize
        # Quantized values are bit-packed into int32 buffers;
        # `outfeatures // 256 * (bits * 8)` equals `outfeatures * bits // 32`
        # when the layer divides evenly into 256 x 256 blocks.
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        self.register_buffer(
            'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
        )
        self._initialized_quant_state = False


    def pack(self, linear, scales, zeros):
        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

        # Map the float weights to integer codes, one quantization group at a time.
        intweight = []
        for idx in range(self.infeatures):
            g_idx = idx // self.groupsize
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                # 2/4/8-bit case: exactly 32 // bits values fit in each 32-bit word.
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            elif self.bits == 3:
                # 3-bit case: every 32 values fill three consecutive 32-bit words;
                # the 11th and 22nd values straddle a word boundary.
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        # Zero-points are stored offset by one, in the same bit-packed layout as the weights.
        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
                i += 10
                qzeros[:, col] |= zeros[:, i] << 30
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
                i += 10
                qzeros[:, col] |= zeros[:, i] << 31
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)
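
    # Illustrative calibration-and-pack flow (added sketch, not part of the
    # original file); assumes a Quantizer configured with perchannel=True and
    # groupsize == -1, so there is exactly one quantization group:
    #
    #     quantizer.find_params(linear.weight.data, weight=True)
    #     qlayer = QuantLinear(4, -1, linear.in_features, linear.out_features)
    #     qlayer.pack(linear, quantizer.scale, quantizer.zero)
    #
    # pack() expects `scales` / `zeros` with one row per output feature and one
    # column per quantization group; it transposes them internally.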

    def forward(self, x):
        intermediate_dtype = torch.float32

        if not self._initialized_quant_state:
            # Do we even have a bias? Check for at least one non-zero element.
            if self.bias is not None and bool(torch.any(self.bias != 0)):
                # Then make sure it's the right type.
                self.bias.data = self.bias.data.to(intermediate_dtype)
            else:
                self.bias = None
            self._initialized_quant_state = True

        outshape = list(x.shape)
        outshape[-1] = self.outfeatures
        x = x.reshape(-1, x.shape[-1])
        if self.bias is None:
            y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        else:
            y = self.bias.clone().repeat(x.shape[0], 1)

        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        # The fused kernels accumulate the quantized matmul into y in place.
        if self.bits == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)

# Recursively replace every submodule whose qualified name appears in `names`
# with a QuantLinear of the same shape.
def make_quant(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            setattr(
                module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
            )
    for name1, child in module.named_children():
        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
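
# Minimal smoke test (added sketch, not part of the original file): swap a single
# nn.Linear for a QuantLinear via make_quant. The layer sizes are multiples of 256,
# matching the packing assumption above; running the quantized forward pass
# additionally requires the quant_cuda extension.
if __name__ == '__main__':
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(256, 256)

        def forward(self, x):
            return self.fc(x)

    model = TinyModel()
    make_quant(model, {'fc'}, bits=4, groupsize=-1)
    print(model.fc)                # now a QuantLinear
    print(model.fc.qweight.shape)  # packed int32 weights: (32, 256)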
