autogptq

Форк
0
93 строки · 2.5 Кб
1
#include <cuda_runtime.h>
2
#include <ATen/core/Tensor.h>
3
#include <ATen/cuda/CUDAContext.h>
4
#include <c10/cuda/CUDAGuard.h>
5

6
#include "marlin_repack.cuh"
7

8
__global__ void gptq_repack_kernel(
9
  uint32_t* in,
10
  uint32_t* out,
11
  int m,
12
  int n
13
) {
14
  uint32_t row = blockIdx.x * 2;
15
  uint32_t col = blockIdx.y * 64;
16
  uint32_t t = threadIdx.x;
17

18
  // marlin packs 4 16x16 blocks one time;
19
  const int pad_len = 18;
20
  __shared__ uint8_t block[4][16][pad_len];
21

22
  // unpack
23
  int block_idx = t / 8;
24
  int block_offset = t % 8;
25
  for (int offset = block_offset; offset < 16; offset += 8) {
26
    uint32_t v1 = in[row * n + col + block_idx * 16 + offset];
27
    uint32_t v2 = in[(row + 1) * n + col + block_idx * 16 + offset];
28
#pragma unroll
29
    for (int i = 0; i < 8; i += 1) {
30
      block[block_idx][i][offset] = v1 & 0xf;
31
      v1 >>= 4;
32
      block[block_idx][i + 8][offset] = v2 & 0xf;
33
      v2 >>= 4;
34
    }
35
  }
36

37
  // repack
38
  // ref: _get_perms @ https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py
39
  uint32_t srow = (t % 4) * 2;
40
  uint32_t scol = t / 4;
41

42
  uint32_t idx[8][2];
43
  idx[0][0] = srow;     idx[0][1] = scol;
44
  idx[1][0] = srow + 8; idx[1][1] = scol;
45
  idx[2][0] = srow;     idx[2][1] = scol + 8;
46
  idx[3][0] = srow + 8; idx[3][1] = scol + 8;
47

48
  idx[4][0] = srow + 1; idx[4][1] = scol;
49
  idx[5][0] = srow + 9; idx[5][1] = scol;
50
  idx[6][0] = srow + 1; idx[6][1] = scol + 8;
51
  idx[7][0] = srow + 9; idx[7][1] = scol + 8;
52

53
#pragma unroll
54
  for (int i = 0; i < 4; i += 1) {
55
    uint32_t v[8];
56
#pragma unroll
57
    for (int j = 0; j < 8; ++j) {
58
      v[j] = block[i][idx[j][0]][idx[j][1]];
59
    }
60

61
    uint32_t pack = (v[7] << 28) | (v[6] << 24) | (v[5] << 20) | (v[4] << 16) |
62
        (v[3] << 12) | (v[2] << 8) | (v[1] << 4) | v[0];
63

64
    out[blockIdx.x * n * 2 + blockIdx.y * 128 + t * 4 + i] = pack;
65
  }
66
}
67

68
torch::Tensor gptq_repack(
69
    torch::Tensor W
70
) {
71
  int m = W.sizes()[0];
72
  int n = W.sizes()[1];
73

74
  assert(W.is_contiguous());
75
  assert(W.dtype() == at::kInt);
76
  assert(m % 2 == 0);
77
  assert(n % 64 == 0);
78
  auto result = at::empty(
79
      {m / 2, n * 2}, at::TensorOptions().dtype(at::kInt).device(W.device()));
80

81
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
82
  const dim3 threads(32);
83
  // marlin packs 16 x 64 block and gptq packs 8 x 1
84
  const dim3 blocks(m / 2, n / 64);
85
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
86
  gptq_repack_kernel<<<blocks, threads, 0, stream>>>(
87
    (uint32_t*)W.data_ptr(),
88
    (uint32_t*)result.data_ptr(),
89
    m,
90
    n
91
  );
92
  return result;
93
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.