pytorch

Форк
0
/
_semi_structured_conversions.py 
356 строк · 13.7 Кб
1
# mypy: allow-untyped-defs
2
import torch
3

4

5
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
6
    """
7
    This is PyTorch implementation of main part of reorder_meta()
8
    function, from tools/util/include/cutlass/util/host_reorder.h file
9
    of CUTLASS source tree.  Furthermore, CUTLASS template for sparse
10
    GEMM decides upon layout of this matrix, and at the moment for the
11
    sparse GEMM executed on tensor cores, this is layout described by
12
    ColumnMajorInterleaved<2> data structure, in
13
    include/cutlass/layout/matrix.h of CUTLASS source tree.  The
14
    reordering of meta matrix into meta_reordered matrix calculated
15
    according to these segments of CUTLASS code is re-implemented here.
16
    Note that this calculation produces offsets for scattering metadata
17
    matrix elements into reordered metadata matrix elements (or,
18
    equivalently, for gathering reordered metadata matrix element back
19
    into metadata matrix elements).
20
    """
21
    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
22
    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
23

24
    # Reorder the rows, then swizzle the 2x2 blocks.
25
    group = 32 if meta_dtype.itemsize == 2 else 16
26
    interweave = 4 if meta_dtype.itemsize == 2 else 2
27
    dst_rows = (
28
        dst_rows // group * group
29
        + (dst_rows % 8) * interweave
30
        + (dst_rows % group) // 8
31
    )
32

33
    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
34
    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
35
    dst_rows += topright - bottomleft
36
    dst_cols -= topright - bottomleft
37

38
    # Assumed that meta tensor is to be stored in CUTLASS
39
    # InterleavedColumnMajor layout, and reverse engineered
40
    # corresponding code to store values into this tensor.
41
    interleave = 2
42
    cols_maj = dst_cols // interleave
43
    cols_min = dst_cols % interleave
44
    return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
45

46

47
def sparse_semi_structured_from_dense_cutlass(dense):
48
    """
49
    This function converts dense matrix into sparse semi-structured
50
    representation, producing "compressed" matrix, in the layout used by
51
    CUTLASS backend, and corresponding metadata matrix.
52
    """
53
    if dense.dim() != 2:
54
        raise RuntimeError(
55
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
56
        )
57

58
    m, k = dense.shape
59
    device = dense.device
60

61
    meta_dtype = torch.int8
62
    if dense.dtype == torch.int8:
63
        meta_dtype = torch.int32
64
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
65
        meta_dtype = torch.int16
66
    else:
67
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
68
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
69
    if quadbits_per_meta_elem not in (4, 8):
70
        raise RuntimeError("Invalid number of elements per meta element calculated")
71

72
    if meta_dtype == torch.int32:
73
        if m % 16 != 0:
74
            raise RuntimeError(
75
                f"Number of rows of dense matrix {m} must be divisible by 16"
76
            )
77
    else:
78
        if m % 32 != 0:
79
            raise RuntimeError(
80
                f"Number of rows of dense matrix {m} must be divisible by 32"
81
            )
82
    if k % (4 * quadbits_per_meta_elem) != 0:
83
        raise RuntimeError(
84
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
85
        )
86

87
    if dense.dtype != torch.float:
88
        ksparse = 4
89
        dense_4 = dense.view(-1, k // ksparse, ksparse)
90
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
91
    else:
92
        ksparse = 2
93
        dense_2 = dense.view(-1, k // ksparse, ksparse)
94
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
95
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)
96

97
    # Encoding quadruples of True/False values as follows:
98
    #     [True,  True,  False, False] -> 0b0100
99
    #     [True,  False, True,  False] -> 0b1000
100
    #     [False, True,  True,  False] -> 0b1001
101
    #     [True,  False, False, True ] -> 0b1100
102
    #     [False, True,  False, True ] -> 0b1101
103
    #     [False, False, True,  True ] -> 0b1110
104
    # Thus, lower two bits in the encoding are index of the True value
105
    # at the lowest index in the quadruple, and the higher two bits in
106
    # the encoding are index of the other True value in the quadruple.
107
    # In case there are less than two True values, than False value or
108
    # values at some index or indices are considered True for the
109
    # encoding.  In case there are more than two True values, then the
110
    # excess True value(s) at some indices are considered False for
111
    # the encoding.  The exact encodings used for these cases are as
112
    # follows:
113
    #     [False, False, False, False] -> 0b1110
114
    #     [False, False, False, True ] -> 0b1110
115
    #     [False, False, True,  False] -> 0b1110
116
    #     [False, True,  False, False] -> 0b1001
117
    #     [False, True,  True,  True ] -> 0b1101
118
    #     [True,  False, False, False] -> 0b1000
119
    #     [True,  False, True,  True ] -> 0b1100
120
    #     [True,  True,  False, True ] -> 0b0100
121
    #     [True,  True,  True,  False] -> 0b0100
122
    #     [True,  True,  True,  True ] -> 0b0100
123
    # These particular encodings are chosen, with the help of Espresso
124
    # logic minimizer software, for the purpose of minimization of
125
    # corresponding Boolean functions, that translate non-zero flags
126
    # into encoding bits.  Note also possible choices for the first
127
    # and last of these encodings were limited only to (0b0100,
128
    # 0b1110), in order to produce valid encodings for 1:2 sparsity
129
    # case.
130

131
    expr0 = m0 & m1
132
    expr1 = ~m0 & m1
133
    expr2 = ~m0 & ~m1
134
    bit0 = expr1
135
    bit1 = expr2
136
    bit2 = expr0 | expr2 | m3
137
    bit3 = expr1 | ~m1
138
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
139
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)
140

141
    if dense.dtype != torch.float:
142
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]
143
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
144
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
145
    else:
146
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)  # type: ignore[possibly-undefined]
147

148
    meta_4 = idxs0 | (idxs1 << 2)
149
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
150

151
    if quadbits_per_meta_elem == 4:
152
        meta = (
153
            meta_n[:, :, 0]
154
            | (meta_n[:, :, 1] << 4)
155
            | (meta_n[:, :, 2] << 8)
156
            | (meta_n[:, :, 3] << 12)
157
        )
158
    elif quadbits_per_meta_elem == 8:
159
        meta = (
160
            meta_n[:, :, 0]
161
            | (meta_n[:, :, 1] << 4)
162
            | (meta_n[:, :, 2] << 8)
163
            | (meta_n[:, :, 3] << 12)
164
            | (meta_n[:, :, 4] << 16)
165
            | (meta_n[:, :, 5] << 20)
166
            | (meta_n[:, :, 6] << 24)
167
            | (meta_n[:, :, 7] << 28)
168
        )
169

170
    # Reorder meta tensor elements.
171
    meta_reordered = meta.new_empty((m * meta_ncols,))  # type: ignore[possibly-undefined]
172
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
173
        m, meta_ncols, meta_dtype, device
174
    )
175
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
176

177
    return (sparse, meta_reordered.view(m, meta_ncols))
178

179

180
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
181
    """
182
    This function performs reverse of the function above - it
183
    reconstructs dense matrix from a pair of "compressed" matrix, given
184
    in the layout used by CUTLASS backend, and accompanying metadata
185
    matrix.
186
    """
187
    if sparse.dim() != 2:
188
        raise RuntimeError(
189
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
190
        )
191

192
    m, k = sparse.shape
193
    device = sparse.device
194

195
    if meta_reordered.dim() != 2:
196
        raise RuntimeError(
197
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
198
        )
199
    if meta_reordered.device != device:
200
        raise RuntimeError(
201
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
202
        )
203

204
    meta_dtype = meta_reordered.dtype
205
    if meta_dtype not in (torch.int16, torch.int32):
206
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
207
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
208

209
    if sparse.dtype != torch.float:
210
        ksparse = 4
211
    else:
212
        ksparse = 2
213

214
    meta_nrows, meta_ncols = meta_reordered.shape
215
    if meta_nrows != m:
216
        raise RuntimeError(
217
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}"
218
        )
219
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
220
        raise RuntimeError(
221
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
222
            "expected according to the number of columns of meta matrix"
223
        )
224

225
    # Undo meta tensor elements reordering.
226
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
227
        m, meta_ncols, meta_dtype, device
228
    )
229
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
230

231
    # Unpack sparse tensor back to original dense tensor, using
232
    # information provided by meta tensor.  Note that torch.float
233
    # datatype is handled pretty much the same as
234
    # torch.half/torch.bfloat16, as metadata for a pair of torch.float
235
    # value is encoded as if underlying 8 bytes contain four
236
    # torch.half/torch.bfloat16 values, where either first two or last
237
    # two are zeros.
238
    meta_2 = torch.empty(
239
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
240
        dtype=meta_dtype,
241
        device=device,
242
    )
243
    if quadbits_per_meta_elem == 4:
244
        meta_2[:, :, 0] = meta & 0b11
245
        meta_2[:, :, 1] = (meta >> 2) & 0b11
246
        meta_2[:, :, 2] = (meta >> 4) & 0b11
247
        meta_2[:, :, 3] = (meta >> 6) & 0b11
248
        meta_2[:, :, 4] = (meta >> 8) & 0b11
249
        meta_2[:, :, 5] = (meta >> 10) & 0b11
250
        meta_2[:, :, 6] = (meta >> 12) & 0b11
251
        meta_2[:, :, 7] = (meta >> 14) & 0b11
252
    elif quadbits_per_meta_elem == 8:
253
        meta_2[:, :, 0] = meta & 0b11
254
        meta_2[:, :, 1] = (meta >> 2) & 0b11
255
        meta_2[:, :, 2] = (meta >> 4) & 0b11
256
        meta_2[:, :, 3] = (meta >> 6) & 0b11
257
        meta_2[:, :, 4] = (meta >> 8) & 0b11
258
        meta_2[:, :, 5] = (meta >> 10) & 0b11
259
        meta_2[:, :, 6] = (meta >> 12) & 0b11
260
        meta_2[:, :, 7] = (meta >> 14) & 0b11
261
        meta_2[:, :, 8] = (meta >> 16) & 0b11
262
        meta_2[:, :, 9] = (meta >> 18) & 0b11
263
        meta_2[:, :, 10] = (meta >> 20) & 0b11
264
        meta_2[:, :, 11] = (meta >> 22) & 0b11
265
        meta_2[:, :, 12] = (meta >> 24) & 0b11
266
        meta_2[:, :, 13] = (meta >> 26) & 0b11
267
        meta_2[:, :, 14] = (meta >> 28) & 0b11
268
        meta_2[:, :, 15] = (meta >> 30) & 0b11
269

270
    dense_offsets = meta_2.view(-1) + (
271
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
272
    ).view(-1, 1).repeat(1, 2).view(-1)
273

274
    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
275
    if sparse.dtype != torch.float:
276
        dense.scatter_(0, dense_offsets, sparse.view(-1))
277
    else:
278
        dense.view(torch.half).scatter_(
279
            0, dense_offsets, sparse.view(torch.half).view(-1)
280
        )
281

282
    return dense.view(m, 2 * k)
283

284

285
def _sparse_semi_structured_tile(dense):
286
    """
287
    This function computes a 2:4 sparse tile by greedily taking the largest values.
288

289
    Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
290
    the ultimate sparsity pattern.
291

292
    Note that this function does not have the same sorting semantics as our CUDA backend,
293
    which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
294
    """
295

296
    def greedy_prune_tile(tile):
297
        num_kept_row = [0, 0, 0, 0]
298
        num_kept_col = [0, 0, 0, 0]
299

300
        for x in tile.flatten().sort(descending=True, stable=True).indices:
301
            r, c = x // 4, x % 4
302
            if num_kept_row[r] < 2 and num_kept_col[c] < 2:
303
                num_kept_row[r] += 1
304
                num_kept_col[c] += 1
305
            else:
306
                tile[r, c] = 0
307

308
    for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
309
        for tile in batch:
310
            greedy_prune_tile(tile)
311

312
    return dense
313

314

315
def _compute_compressed_swizzled_bitmask(dense):
316
    """
317
    Calculates the compressed swizzled bitmask from a dense tensor
318
    """
319

320
    # first we need to convert the dense tensor to a bitmask
321
    int_bitmask = dense.bool().to(torch.uint8)
322

323
    # Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
324
    # A, B, C and D, as displayed in the following schema:
325
    # +---+---+
326
    # | A | B |
327
    # +---+---+
328
    # | C | D |
329
    # +---+---+
330

331
    # we first need to split into the 8x8 tiles
332
    bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
333

334
    # then we unfold again to get our indivdual 4x4 tiles
335
    bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
336

337
    # Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
338
    # of that tile. Note that the least siginificant bit is stored first.
339
    # [1 1 0 0]
340
    # [1 1 0 0]  ->  0011 0011 ->   51
341
    # [0 0 1 1]      1100 1100      204
342
    # [0 0 1 1]
343

344
    # reshape tensor to expand tiles into 8-bit vectors
345
    bitmask_binary_representation = bitmask_4x4_chunks.reshape(
346
        *bitmask_4x4_chunks.shape[:2], 4, 2, 8
347
    )
348

349
    # to convert from binary representaiton, we can do a matmul with powers of two
350
    powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
351
    # To run on GPU: cast to float to do matmul and then cast back
352
    compressed_swizzled_bitmask = (
353
        bitmask_binary_representation.to(torch.float) @ powers_of_two
354
    ).to(torch.uint8)
355

356
    return compressed_swizzled_bitmask
357

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.