// autogptq — GPTQ -> Marlin weight repacking (original listing: 93 lines, 2.5 KB)
#include <cuda_runtime.h>

#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include "marlin_repack.cuh"
// Repacks GPTQ-packed 4-bit weights into the layout expected by the Marlin
// kernel.
//
// Expected launch config (see gptq_repack): grid = (m / 2, n / 64),
// block = 32 threads. Each thread block consumes a 2 x 64 region of the
// GPTQ-packed matrix `in` (GPTQ packs 8 4-bit values per uint32 along the
// row dimension, so this is a 16 x 64 region of unpacked values) and emits
// 128 repacked uint32 words into `out`.
//
// in:  GPTQ-packed weights, m x n uint32 words (read-only).
// out: Marlin-packed weights, (m / 2) x (n * 2) uint32 words.
__global__ void gptq_repack_kernel(
    const uint32_t* __restrict__ in,
    uint32_t* __restrict__ out,
    int m,
    int n
) {
  uint32_t row = blockIdx.x * 2;   // first GPTQ-packed row of this tile
  uint32_t col = blockIdx.y * 64;  // first column of this tile
  uint32_t t = threadIdx.x;        // lane id, 0..31

  // marlin packs 4 16x16 blocks one time;
  // inner dimension padded 16 -> 18 (presumably to spread shared-memory bank
  // accesses — NOTE(review): confirm the intended padding rationale).
  const int pad_len = 18;
  __shared__ uint8_t block[4][16][pad_len];

  // Unpack phase: each 16x16 sub-block is handled by 8 lanes; each lane
  // unpacks two uint32 words (one per packed row) for each column it owns.
  int block_idx = t / 8;
  int block_offset = t % 8;
  for (int offset = block_offset; offset < 16; offset += 8) {
    uint32_t v1 = in[row * n + col + block_idx * 16 + offset];
    uint32_t v2 = in[(row + 1) * n + col + block_idx * 16 + offset];
#pragma unroll
    for (int i = 0; i < 8; i += 1) {
      block[block_idx][i][offset] = v1 & 0xf;       // nibble i of packed row 0
      v1 >>= 4;
      block[block_idx][i + 8][offset] = v2 & 0xf;   // nibble i of packed row 1
      v2 >>= 4;
    }
  }

  // The repack phase below reads shared-memory entries written by OTHER
  // lanes (every thread reads all four sub-blocks). With independent thread
  // scheduling (Volta+) even a single-warp block needs an explicit barrier
  // between the write and read phases.
  __syncthreads();

  // Repack phase: gather the 8 nibbles each lane contributes per sub-block.
  // ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
  uint32_t srow = (t % 4) * 2;
  uint32_t scol = t / 4;

  uint32_t idx[8][2];
  idx[0][0] = srow;     idx[0][1] = scol;
  idx[1][0] = srow + 8; idx[1][1] = scol;
  idx[2][0] = srow;     idx[2][1] = scol + 8;
  idx[3][0] = srow + 8; idx[3][1] = scol + 8;

  idx[4][0] = srow + 1; idx[4][1] = scol;
  idx[5][0] = srow + 9; idx[5][1] = scol;
  idx[6][0] = srow + 1; idx[6][1] = scol + 8;
  idx[7][0] = srow + 9; idx[7][1] = scol + 8;

#pragma unroll
  for (int i = 0; i < 4; i += 1) {
    uint32_t v[8];
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      v[j] = block[i][idx[j][0]][idx[j][1]];
    }

    // Pack 8 nibbles little-end first: v[0] lands in bits 0..3.
    uint32_t pack = (v[7] << 28) | (v[6] << 24) | (v[5] << 20) | (v[4] << 16) |
                    (v[3] << 12) | (v[2] << 8) | (v[1] << 4) | v[0];

    // Each block writes 128 contiguous words; out row stride is n * 2.
    out[blockIdx.x * n * 2 + blockIdx.y * 128 + t * 4 + i] = pack;
  }
}
67
// Repacks a GPTQ-packed int32 weight tensor W of shape (m, n) into the
// Marlin layout, returning a new tensor of shape (m / 2, n * 2) on the same
// device (same total element count).
//
// Preconditions (checked): W is a contiguous int32 CUDA tensor with
// m % 2 == 0 and n % 64 == 0 (the kernel consumes 2 x 64 packed tiles).
torch::Tensor gptq_repack(
    torch::Tensor W
) {
  // TORCH_CHECK instead of assert: assert() compiles away under NDEBUG,
  // which would let bad inputs reach the kernel and read/write out of bounds.
  TORCH_CHECK(W.is_cuda(), "gptq_repack: W must be a CUDA tensor");
  TORCH_CHECK(W.is_contiguous(), "gptq_repack: W must be contiguous");
  TORCH_CHECK(W.dtype() == at::kInt, "gptq_repack: W must be int32");
  TORCH_CHECK(W.dim() == 2, "gptq_repack: W must be 2-D");

  int m = W.sizes()[0];
  int n = W.sizes()[1];
  TORCH_CHECK(m % 2 == 0, "gptq_repack: m must be divisible by 2, got ", m);
  TORCH_CHECK(n % 64 == 0, "gptq_repack: n must be divisible by 64, got ", n);

  auto result = at::empty(
      {m / 2, n * 2}, at::TensorOptions().dtype(at::kInt).device(W.device()));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
  const dim3 threads(32);
  // marlin packs 16 x 64 block and gptq packs 8 x 1
  const dim3 blocks(m / 2, n / 64);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  gptq_repack_kernel<<<blocks, threads, 0, stream>>>(
      (uint32_t*)W.data_ptr(),
      (uint32_t*)result.data_ptr(),
      m,
      n
  );
  // Surface launch-configuration errors; without this a bad launch fails
  // silently and the error only appears at some later, unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return result;
}