DeepSpeed

Mirror of https://github.com/microsoft/DeepSpeed
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels fuse the GeMM with
# dequantization of fp8 weight data
# when using 16-bit activations.
###################################

import torch
import triton
import triton.language as tl


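# Note (added for clarity; not in the upstream file): both kernels below tile the
# output with the standard grouped (pid_m, pid_n) ordering, load the packed fp8
# weights as raw bytes, and rebuild a 16-bit float from each byte with bit shifts:
#   bf16 path: w16 = ((b & 0x80) << 8) | ((b & 0x7f) << 4); w16 += 0x3C00
#   fp16 path: w16 = ((b & 0x80) << 8) | ((b & 0x7f) << 7); w16 += 0x2000
# The sign bit moves to bit 15, the 7 payload bits land in the exponent/mantissa
# field, and the additive constant adjusts the exponent field before the per-group
# scale is applied. Accumulation is done in fp32 and cast back to the activation
# dtype on store.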
@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
                           stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                           BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                           quantization_group_size: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
        (pid_n * BLOCK_SIZE_N) // quantization_group_size)

    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
    scale = tl.load(scale_ptr + weight_ptrs_offset)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Dequantize weight (fp8 -> bf16)
        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
        w = (w + 0x3C00).to(tl.uint16)
        w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

        inp_data += BLOCK_SIZE_K * stride_ak
        weight_data += BLOCK_SIZE_K * stride_bk
        weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
        weight = tl.load(weight_data, mask=weight_mask, other=0.0)
        scale = tl.load(scale_ptr + (weight_ptrs_offset +
                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
                        mask=weight_mask,
                        other=0.0)

        accumulator += tl.dot(inp, w)

    out = accumulator.to(tl.bfloat16)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
                           stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                           BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                           quantization_group_size: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
        (pid_n * BLOCK_SIZE_N) // quantization_group_size)

    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
    scale = tl.load(scale_ptr + weight_ptrs_offset)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Dequantize weight (fp8 -> fp16)
        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
        w = (w + 0x2000).to(tl.uint16)
        w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

        inp_data += BLOCK_SIZE_K * stride_ak
        weight_data += BLOCK_SIZE_K * stride_bk

        weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
        scale = tl.load(scale_ptr + (weight_ptrs_offset +
                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

        accumulator += tl.dot(inp, w)

    out = accumulator.to(tl.float16)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp, weight, scale, quantization_group_size):

    assert inp.shape[1] == weight.shape[0], \
        f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

    M, K = inp.shape
    K, N = weight.shape

    out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

    # GEMM tuning parameters!
    # TODO: Add a more configurable tuning for selecting the best GeMM
    BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = max(64, quantization_group_size)
    GROUP_SIZE_M = 8
    num_stages = 4
    num_warps = 4
    if M >= 256:
        BLOCK_SIZE_M = 256
        BLOCK_SIZE_N = 128
        BLOCK_SIZE_K = max(128, quantization_group_size)
        num_stages = 3
        num_warps = 8

    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
    kernel[grid](inp,
                 weight,
                 out,
                 scale,
                 M,
                 N,
                 K,
                 inp.stride(0),
                 inp.stride(1),
                 weight.stride(0),
                 weight.stride(1),
                 out.stride(0),
                 out.stride(1),
                 quantization_group_size=quantization_group_size,
                 BLOCK_SIZE_M=BLOCK_SIZE_M,
                 BLOCK_SIZE_N=BLOCK_SIZE_N,
                 BLOCK_SIZE_K=BLOCK_SIZE_K,
                 GROUP_SIZE_M=GROUP_SIZE_M,
                 num_stages=num_stages,
                 num_warps=num_warps)
    return out
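

# Hedged usage sketch (illustration only, not part of the upstream file). It assumes
# `weight` is the packed fp8 data as a uint8 tensor of shape (K, N) and `scale` holds
# one dequantization scale per `quantization_group_size` consecutive weights along N,
# laid out to match the scale offsets computed in the kernels above; the real layout
# produced by DeepSpeed's quantizer may differ.
if __name__ == "__main__":
    M, K, N, group = 32, 4096, 4096, 128
    inp = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    weight = torch.randint(0, 256, (K, N), dtype=torch.uint8, device="cuda")
    scale = torch.ones(K * (N // group), dtype=torch.bfloat16, device="cuda")
    out = matmul_fp8(inp, weight, scale, quantization_group_size=group)
    print(out.shape, out.dtype)  # torch.Size([32, 4096]) torch.bfloat16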