DeepSpeed
Mirror of https://github.com/microsoft/DeepSpeed
171 lines · 7.2 KB
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels are implemented for
# fusing GeMM with dequantization of
# fp8 weight data when using bit-16
# activation.
###################################

import torch
import triton
import triton.language as tl
16
17
@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
                           stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                           BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                           quantization_group_size: tl.constexpr):
    # C[M, N] = A[M, K] @ dequant(B[K, N]) where A holds bf16 activations and
    # B holds fp8 weights stored as bytes; one scale (from scale_ptr) covers
    # every `quantization_group_size` weight elements. Accumulation is fp32,
    # the stored output is bf16.
    #
    # Program-id swizzle (the standard Triton grouped-ordering pattern):
    # programs are grouped GROUP_SIZE_M wide along M so that neighbouring
    # programs touch the same B column tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Tile offsets. The % M / % N wrap keeps out-of-range rows/columns
    # in-bounds; the final store is masked, so wrapped lanes never land in
    # the output.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # Scale index: one scale per quantization_group_size weights, laid out
    # row-major alongside the weight tensor.
    weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
        (pid_n * BLOCK_SIZE_N) // quantization_group_size)

    # Software pipelining: pre-load tile 0 (and its scales) so each loop
    # iteration dequantizes the current tile while fetching the next one.
    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
    scale = tl.load(scale_ptr + weight_ptrs_offset)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Dequantize weight (fp8 -> bf16): place the sign bit (0x80) at bit 15
        # and the remaining 7 payload bits at bits 10..4 of a 16-bit word,
        # then add 0x3C00 to rebias the exponent before applying the group
        # scale. (Assumes the fp8 bit layout produced by the matching
        # quantization kernel — verify against it.)
        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
        w = (w + 0x3C00).to(tl.uint16)
        w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

        # Advance pointers and pre-fetch the next weight tile + scales.
        inp_data += BLOCK_SIZE_K * stride_ak
        weight_data += BLOCK_SIZE_K * stride_bk
        weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
        weight = tl.load(weight_data, mask=weight_mask, other=0.0)
        # NOTE(review): this scale load is masked, unlike the one in
        # matmul_kernel_fp8_fp16 — the mask keeps the tail-iteration loads
        # in-bounds.
        scale = tl.load(scale_ptr + (weight_ptrs_offset +
                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
                        mask=weight_mask,
                        other=0.0)

        accumulator += tl.dot(inp, w)

    out = accumulator.to(tl.bfloat16)

    # Masked write-back of this program's output tile.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
70
71
@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
                           stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                           BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                           quantization_group_size: tl.constexpr):
    # fp16 twin of matmul_kernel_fp8_bf16: C[M, N] = A[M, K] @ dequant(B[K, N])
    # with fp16 activations, byte-stored fp8 weights, one scale per
    # `quantization_group_size` weight elements, fp32 accumulation and an
    # fp16 output.
    #
    # Program-id swizzle (standard Triton grouped-ordering pattern) so
    # programs in the same group reuse B column tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Tile offsets; % M / % N wrap out-of-range lanes (final store is masked).
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # One scale per quantization_group_size weight elements.
    weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
        (pid_n * BLOCK_SIZE_N) // quantization_group_size)

    # Software pipelining: pre-load tile 0 and its scales before the loop.
    weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
    scale = tl.load(scale_ptr + weight_ptrs_offset)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        # Dequantize weight (fp8 -> fp16): sign bit (0x80) to bit 15, the 7
        # payload bits to bits 13..7 of a 16-bit word, then add 0x2000 to
        # rebias the exponent before applying the group scale. (Assumes the
        # fp8 bit layout produced by the matching quantization kernel —
        # verify against it.)
        w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
        w = (w + 0x2000).to(tl.uint16)
        w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

        # Advance pointers and pre-fetch the next weight tile + scales.
        inp_data += BLOCK_SIZE_K * stride_ak
        weight_data += BLOCK_SIZE_K * stride_bk

        weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
        # NOTE(review): unlike the bf16 kernel, this scale load is unmasked —
        # on the tail iteration it may read past the scale buffer. Confirm
        # whether the mask was intentionally omitted here.
        scale = tl.load(scale_ptr + (weight_ptrs_offset +
                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

        accumulator += tl.dot(inp, w)

    out = accumulator.to(tl.float16)

    # Masked write-back of this program's output tile.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
122
123
def matmul_fp8(inp, weight, scale, quantization_group_size):
    """Run a GeMM with on-the-fly dequantization of fp8 weights.

    Selects the bf16 or fp16 Triton kernel based on ``inp.dtype`` and
    launches it with heuristically chosen tile sizes.

    Args:
        inp: activation tensor of shape (M, K), bf16 or fp16.
        weight: quantized weight tensor of shape (K, N), fp8 stored as bytes.
        scale: dequantization scales, one per ``quantization_group_size``
            weight elements.
        quantization_group_size: number of weight values sharing one scale.

    Returns:
        Output tensor of shape (M, N) with the same dtype and device as
        ``inp``.
    """
    assert inp.shape[1] == weight.shape[0], \
        f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

    M, K = inp.shape
    K, N = weight.shape

    out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

    # GEMM tuning parameters!
    # TODO: Add a more configurable tuning for selecting the best GeMM
    if M >= 256:
        # Large-batch configuration: bigger tiles, more warps, fewer stages.
        block_m = 256
        block_n = 128
        block_k = max(128, quantization_group_size)
        num_stages = 3
        num_warps = 8
    else:
        if M <= 16:
            block_m = 16
        elif M <= 32:
            block_m = 32
        elif M <= 64:
            block_m = 64
        else:
            block_m = 128
        block_n = 64
        block_k = max(64, quantization_group_size)
        num_stages = 4
        num_warps = 4
    group_m = 8

    def grid(meta):
        # 1-D launch: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        return (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), )

    if inp.dtype == torch.bfloat16:
        kernel = matmul_kernel_fp8_bf16
    else:
        kernel = matmul_kernel_fp8_fp16

    kernel[grid](inp,
                 weight,
                 out,
                 scale,
                 M,
                 N,
                 K,
                 inp.stride(0),
                 inp.stride(1),
                 weight.stride(0),
                 weight.stride(1),
                 out.stride(0),
                 out.stride(1),
                 quantization_group_size=quantization_group_size,
                 BLOCK_SIZE_M=block_m,
                 BLOCK_SIZE_N=block_n,
                 BLOCK_SIZE_K=block_k,
                 GROUP_SIZE_M=group_m,
                 num_stages=num_stages,
                 num_warps=num_warps)
    return out
172