pytorch

Форк
0
/
unpack_mixed_mm.py 
82 строки · 2.8 Кб
1
import logging
2
from typing import List
3

4
from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate
5
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
6

7
log = logging.getLogger(__name__)
8

9
uint4x2_mixed_mm_template = TritonTemplate(
10
    name="uint4x2_mixed_mm",
11
    grid=mm_grid,
12
    source=r"""
13
{{def_kernel("A", "B")}}
14
    M = {{size("A", 0)}}
15
    N = {{size("B", 1)}}
16
    K = {{size("A", 1)}}
17
    stride_am = {{stride("A", 0)}}
18
    stride_ak = {{stride("A", 1)}}
19
    stride_bk = {{stride("B", 0)}}
20
    stride_bn = {{stride("B", 1)}}
21

22
    # based on triton.ops.matmul
23
    pid = tl.program_id(0)
24
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
25
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
26

27
    # re-order program ID for better L2 performance
28
    width = GROUP_M * grid_n
29
    group_id = pid // width
30
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
31
    pid_m = group_id * GROUP_M + (pid % group_size)
32
    pid_n = (pid % width) // (group_size)
33

34
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
35
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
36
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
37
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
38
    rk = tl.arange(0, BLOCK_K)
39
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
40
    B = B + (rk[:, None]//2 * stride_bk + rbn[None, :] * stride_bn)
41
    b_shifts = 4*(rk%2)
42
    b_subs = 8*(1-(rk%2))
43

44
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
45
    for k in range(K, 0, -BLOCK_K):
46
        if EVEN_K:
47
            a = tl.load(A)
48
            b = tl.load(B)
49
        else:
50
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
51
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
52
        b = ((b >> b_shifts[:, None]) & 0xF) - 8
53
        b = b.to(B_PROLOGUE_CAST_TYPE)
54
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
55
        A += BLOCK_K * stride_ak
56
        B += BLOCK_K//2 * stride_bk
57

58
    # rematerialize rm and rn to save registers
59
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
60
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
61
    idx_m = rm[:, None]
62
    idx_n = rn[None, :]
63
    mask = (idx_m < M) & (idx_n < N)
64

65
    # inductor generates a suffix
66
    {{store_output(("idx_m", "idx_n"), "acc", "mask")}}
67
""",
68
)
69

70

71
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
72
    m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
73
    choices: List[ChoiceCaller] = []
74
    b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
75
    for config in mm_configs(m, n, k):
76
        uint4x2_mixed_mm_template.maybe_append_choice(
77
            choices,
78
            input_nodes=(mat1, mat2),
79
            layout=layout,
80
            **mm_options(config, m, n, k, layout, b_prologue_cast_type),
81
        )
82
    return autotune_select_algorithm("uint4x2_mixed_mm", choices, [mat1, mat2], layout)
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.