5
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
7
This is PyTorch implementation of main part of reorder_meta()
8
function, from tools/util/include/cutlass/util/host_reorder.h file
9
of CUTLASS source tree. Furthermore, CUTLASS template for sparse
10
GEMM decides upon layout of this matrix, and at the moment for the
11
sparse GEMM executed on tensor cores, this is layout described by
12
ColumnMajorInterleaved<2> data structure, in
13
include/cutlass/layout/matrix.h of CUTLASS source tree. The
14
reordering of meta matrix into meta_reordered matrix calculated
15
according to these segments of CUTLASS code is re-implemented here.
16
Note that this calculation produces offsets for scattering metadata
17
matrix elements into reordered metadata matrix elements (or,
18
equivalently, for gathering reordered metadata matrix element back
19
into metadata matrix elements).
21
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
22
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
25
group = 32 if meta_dtype.itemsize == 2 else 16
26
interweave = 4 if meta_dtype.itemsize == 2 else 2
28
dst_rows // group * group
29
+ (dst_rows % 8) * interweave
30
+ (dst_rows % group) // 8
33
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
34
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
35
dst_rows += topright - bottomleft
36
dst_cols -= topright - bottomleft
42
cols_maj = dst_cols // interleave
43
cols_min = dst_cols % interleave
44
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
def sparse_semi_structured_from_dense_cutlass(dense):
    """
    This function converts dense matrix into sparse semi-structured
    representation, producing "compressed" matrix, in the layout used by
    CUTLASS backend, and corresponding metadata matrix.
    """
    if dense.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor"
        )

    m, k = dense.shape
    device = dense.device

    # Metadata element width depends on the dense dtype: int8 data uses
    # int32 metadata, floating-point data uses int16 metadata.
    meta_dtype = torch.int8
    if dense.dtype == torch.int8:
        meta_dtype = torch.int32
    elif dense.dtype in [torch.half, torch.bfloat16, torch.float]:
        meta_dtype = torch.int16
    else:
        raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
    # Each metadata element stores one 4-bit entry per group of four
    # dense elements.
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
    if quadbits_per_meta_elem not in (4, 8):
        raise RuntimeError("Invalid number of elements per meta element calculated")

    if meta_dtype == torch.int32:
        if m % 16 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 16"
            )
    else:
        if m % 32 != 0:
            raise RuntimeError(
                f"Number of rows of dense matrix {m} must be divisible by 32"
            )
    if k % (4 * quadbits_per_meta_elem) != 0:
        raise RuntimeError(
            f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}"
        )

    # torch.float inputs are handled as pairs of half values, so the
    # group size along columns is 2 instead of 4.
    if dense.dtype != torch.float:
        ksparse = 4
        dense_4 = dense.view(-1, k // ksparse, ksparse)
        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
    else:
        ksparse = 2
        dense_2 = dense.view(-1, k // ksparse, ksparse)
        m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
    meta_ncols = k // (ksparse * quadbits_per_meta_elem)

    # Encode each group of four elements by the 2-bit indices of the two
    # kept values: idxs0 is the index of the first kept value, idxs1 of
    # the second.  NOTE(review): these boolean expressions were
    # reconstructed to match the CUTLASS reference encoding (only the
    # bit2 line was legible in the garbled source); verified by hand
    # against all six valid 2:4 patterns.
    expr0 = m0 & m1
    expr1 = ~m0 & m1
    expr2 = ~m0 & ~m1
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 | expr2 | m3
    bit3 = expr1 | ~m1
    idxs0 = bit0 | (bit1.to(torch.int64) << 1)
    idxs1 = bit2 | (bit3.to(torch.int64) << 1)

    # Gather the kept values into the "compressed" matrix of shape
    # (m, k // 2).
    if dense.dtype != torch.float:
        sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1))
        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
    else:
        sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2)

    # Pack pairs of 2-bit indices into 4-bit metadata entries, then pack
    # 4 or 8 such entries into each metadata element.
    meta_4 = idxs0 | (idxs1 << 2)
    meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)

    if quadbits_per_meta_elem == 4:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
        )
    elif quadbits_per_meta_elem == 8:
        meta = (
            meta_n[:, :, 0]
            | (meta_n[:, :, 1] << 4)
            | (meta_n[:, :, 2] << 8)
            | (meta_n[:, :, 3] << 12)
            | (meta_n[:, :, 4] << 16)
            | (meta_n[:, :, 5] << 20)
            | (meta_n[:, :, 6] << 24)
            | (meta_n[:, :, 7] << 28)
        )

    # Scatter metadata elements into the interleaved layout expected by
    # CUTLASS (offsets form a permutation, see helper above).
    meta_reordered = meta.new_empty((m * meta_ncols,))
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

    return (sparse, meta_reordered.view(m, meta_ncols))
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
    """
    This function performs reverse of the function above - it
    reconstructs dense matrix from a pair of "compressed" matrix, given
    in the layout used by CUTLASS backend, and accompanying metadata
    matrix.
    """
    if sparse.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor"
        )

    m, k = sparse.shape
    device = sparse.device

    if meta_reordered.dim() != 2:
        raise RuntimeError(
            f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor"
        )
    if meta_reordered.device != device:
        raise RuntimeError(
            f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device"
        )

    meta_dtype = meta_reordered.dtype
    if meta_dtype not in (torch.int16, torch.int32):
        raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4

    # Group size along columns: 4 in general, 2 for torch.float (which
    # is processed as pairs of half values, see scatter below).
    if sparse.dtype != torch.float:
        ksparse = 4
    else:
        ksparse = 2

    meta_nrows, meta_ncols = meta_reordered.shape
    if meta_nrows != m:
        # Fixed error message: it compares against the number of ROWS of
        # the sparse matrix (and "spase" typo corrected).
        raise RuntimeError(
            f"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}"
        )
    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
        raise RuntimeError(
            f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, "
            "expected according to the number of columns of meta matrix"
        )

    # Undo the metadata reordering: the same offsets used for scattering
    # are used here for gathering.
    meta_offsets = _calculate_meta_reordering_scatter_offsets(
        m, meta_ncols, meta_dtype, device
    )
    meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)

    # Unpack every metadata element into its 2-bit index fields.
    meta_2 = torch.empty(
        (m, meta_ncols, 2 * quadbits_per_meta_elem),
        dtype=meta_dtype,
        device=device,
    )
    if quadbits_per_meta_elem == 4:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
    elif quadbits_per_meta_elem == 8:
        meta_2[:, :, 0] = meta & 0b11
        meta_2[:, :, 1] = (meta >> 2) & 0b11
        meta_2[:, :, 2] = (meta >> 4) & 0b11
        meta_2[:, :, 3] = (meta >> 6) & 0b11
        meta_2[:, :, 4] = (meta >> 8) & 0b11
        meta_2[:, :, 5] = (meta >> 10) & 0b11
        meta_2[:, :, 6] = (meta >> 12) & 0b11
        meta_2[:, :, 7] = (meta >> 14) & 0b11
        meta_2[:, :, 8] = (meta >> 16) & 0b11
        meta_2[:, :, 9] = (meta >> 18) & 0b11
        meta_2[:, :, 10] = (meta >> 20) & 0b11
        meta_2[:, :, 11] = (meta >> 22) & 0b11
        meta_2[:, :, 12] = (meta >> 24) & 0b11
        meta_2[:, :, 13] = (meta >> 26) & 0b11
        meta_2[:, :, 14] = (meta >> 28) & 0b11
        meta_2[:, :, 15] = (meta >> 30) & 0b11

    # Convert per-group 2-bit indices into absolute flat offsets into
    # the dense matrix, then scatter compressed values back.
    dense_offsets = meta_2.view(-1) + (
        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
    ).view(-1, 1).repeat(1, 2).view(-1)

    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
    if sparse.dtype != torch.float:
        dense.scatter_(0, dense_offsets, sparse.view(-1))
    else:
        # torch.float values are scattered as pairs of half values.
        dense.view(torch.half).scatter_(
            0, dense_offsets, sparse.view(torch.half).view(-1)
        )

    return dense.view(m, 2 * k)
285
def _sparse_semi_structured_tile(dense):
287
This function computes a 2:4 sparse tile by greedily taking the largest values.
289
Since we take the largest values greedily, how the sorting algorithm handles duplicates affects
290
the ultimate sparsity pattern.
292
Note that this function does not have the same sorting semantics as our CUDA backend,
293
which is exposed via `torch._sparse_semi_structured_tile` and thus returns a different pattern.
296
def greedy_prune_tile(tile):
297
num_kept_row = [0, 0, 0, 0]
298
num_kept_col = [0, 0, 0, 0]
300
for x in tile.flatten().sort(descending=True, stable=True).indices:
302
if num_kept_row[r] < 2 and num_kept_col[c] < 2:
308
for batch in dense.unfold(0, 4, 4).unfold(1, 4, 4):
310
greedy_prune_tile(tile)
315
def _compute_compressed_swizzled_bitmask(dense):
317
Calculates the compressed swizzled bitmask from a dense tensor
321
int_bitmask = dense.bool().to(torch.uint8)
332
bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
335
bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
345
bitmask_binary_representation = bitmask_4x4_chunks.reshape(
346
*bitmask_4x4_chunks.shape[:2], 4, 2, 8
350
powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
352
compressed_swizzled_bitmask = (
353
bitmask_binary_representation.to(torch.float) @ powers_of_two
356
return compressed_swizzled_bitmask