# Owner(s): ["module: sparse"]

import itertools
import pickle
import random
import unittest

import torch
from contextlib import redirect_stderr
from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
    (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU)
from torch.testing._internal.common_device_type import \
    (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
     precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
     largeTensorTest)
from torch.testing._internal.common_methods_invocations import \
    (op_db, sparse_csr_unary_ufuncs, ReductionOpInfo)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUDA
from torch.testing._internal.common_dtype import (
    floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
    all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED

if TEST_SCIPY:
    import scipy.sparse as sp

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

no_mkl_sparse = IS_WINDOWS or not TEST_MKL

def _check_cusparse_triangular_solve_available():
    version = _get_torch_cuda_version()
    # cusparseSpSM was added in 11.3.1 but we don't have access to patch version
    min_supported_version = (11, 4)
    return version >= min_supported_version


def _check_cusparse_spgemm_available():
    # cusparseSpGEMM was added in 11.0
    return not TEST_WITH_ROCM


def _check_cusparse_sddmm_available():
    version = _get_torch_cuda_version()
    # cusparseSDDMM was added in 11.2.1 but we don't have access to patch version
    min_supported_version = (11, 3)
    return version >= min_supported_version
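
# Note (illustrative, not from the original file): _get_torch_cuda_version() returns a
# (major, minor) tuple, so the availability helpers above reduce to tuple comparison:
#   (11, 6) >= (11, 4)  # True  -> the cusparseSpSM path is available
#   (11, 2) >= (11, 3)  # False -> the fallback path is taken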


_sparse_csr_ops = list(filter(lambda op: op.supports_sparse_csr, op_db))
_sparse_compressed_ops = list(filter(lambda op: (op.supports_sparse_csr or op.supports_sparse_csc
                                                 or op.supports_sparse_bsr or op.supports_sparse_bsc), op_db))
binary_functions_with_dense_output = ['mm', 'mv', ]
binary_ops_with_dense_output = list(filter(lambda op: op.name in binary_functions_with_dense_output, op_db))

UNARY_EWISE_CSR_ALLOW_AUTOGRAD = [
]

# This should be just an import from test_linalg instead of code duplication,
# but see https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
def _test_addmm_addmv(test_case, f, t, m, v, alpha=None, beta=None, transpose_out=False, layout=torch.strided, mode=None):
    """
    Unified test for checking `f(t, m, v, alpha=alpha, beta=beta)` computation,
    where f is `torch.addmv` or `torch.addmm`.
    `transpose_out` controls whether the out argument is in column-major order.
    `layout` controls whether `m` is converted to the specified layout or not.
    Custom behaviour is implemented only for the torch.sparse_csr layout.
    """
    dtype = t.dtype
    numpy_dtype = dtype
    if dtype in {torch.bfloat16}:
        numpy_dtype = torch.float
    if dtype.is_complex:
        alpha = 0.9 + 0.3j if alpha is None else alpha
        beta = 0.5 + 0.6j if beta is None else beta
    else:
        alpha = 1.2 if alpha is None else alpha
        beta = 0.8 if beta is None else beta

    def convert_layout(mat):
        if layout == torch.sparse_csr:
            return mat.to_sparse_csr()
        elif layout == torch.sparse_csc:
            return mat.to_sparse_csc()
        else:
            assert mat.layout == layout
            return mat

    if mode == "all_sparse":
        res1 = f(*map(convert_layout, (t, m, v)), alpha=alpha, beta=beta)
        test_case.assertEqual(res1.layout, layout)
        res1 = res1.to_dense()
    elif mode == "dense_result":
        res1 = f(t, convert_layout(m), convert_layout(v), alpha=alpha, beta=beta)
    else:
        res1 = f(t, convert_layout(m), v, alpha=alpha, beta=beta)
    res2 = torch.full_like(res1, float('nan'))
    if transpose_out:
        res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
    f(t, convert_layout(m), v, alpha=alpha, beta=beta, out=res2)
    res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
    res3 += (beta * t).to(numpy_dtype).cpu().numpy()
    res3 = torch.from_numpy(res3).to(dtype)
    test_case.assertEqual(res1, res2)
    test_case.assertEqual(res1, res3)
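
# Sketch of a typical call (the argument values here are assumptions for illustration):
#   _test_addmm_addmv(self, torch.addmm, t, m, v, layout=torch.sparse_csr, mode="dense_result")
# checks a CSR-converted `m` against the dense numpy reference res3 computed above.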


class TestSparseCSRSampler(TestCase):

    def test_make_crow_indices(self):
        # Here we test the correctness of the crow_indices algorithm;
        # testing it on CPU and with the int32 dtype will be
        # sufficient.
        device = torch.device('cpu')
        index_dtype = torch.int32
        for n_rows in range(1, 10):
            for n_cols in range(1, 10):
                for nnz in range(0, n_rows * n_cols + 1):
                    crow_indices = self._make_crow_indices(
                        n_rows, n_cols, nnz,
                        device=device, dtype=index_dtype)
                    self.assertEqual(len(crow_indices), n_rows + 1)
                    counts = crow_indices[1:] - crow_indices[:-1]
                    self.assertEqual(counts.sum(), nnz)
                    self.assertGreaterEqual(counts.min(), 0)
                    self.assertLessEqual(counts.max(), n_cols)
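
# For illustration (not part of the test): a 3 x 4 CSR matrix whose rows hold 2, 0,
# and 1 specified elements has crow_indices [0, 2, 2, 3] -- the exclusive prefix sum
# of the per-row counts. Its elementwise difference recovers the counts and its last
# entry equals nnz, which is exactly what the assertions above verify.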


def all_sparse_compressed_layouts(test_name='layout'):
    return parametrize(test_name, [
        subtest(torch.sparse_csr, name='SparseCSR'),
        subtest(torch.sparse_csc, name='SparseCSC'),
        subtest(torch.sparse_bsr, name='SparseBSR'),
        subtest(torch.sparse_bsc, name='SparseBSC')])


def sparse_compressed_nonblock_layouts(test_name='layout'):
    return parametrize(test_name, [
        subtest(torch.sparse_csr, name='SparseCSR'),
        subtest(torch.sparse_csc, name='SparseCSC')])


sparse_compressed_indices_methods = {
    torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
    torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
    torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
    torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
}
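
# Illustrative usage (mirrors how the tests below consume this table):
#   comp_mth, plain_mth = sparse_compressed_indices_methods[t.layout]
#   comp_mth(t), plain_mth(t)  # crow/col indices for CSR+BSR, ccol/row for CSC+BSC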


def batched_nonbatched(test_name='batched'):
    return parametrize(test_name, [
        subtest(True, name="Batched"),
        subtest(False, name="NonBatched")
    ])


def hybrid_nonhybrid(test_name='hybrid'):
    return parametrize(test_name, [
        subtest(True, name="Hybrid"),
        subtest(False, name="NonHybrid")
    ])


class TestSparseCompressed(TestCase):
    """Testing sparse compressed (CSR, CSC, BSR, BSC) tensor generic features.
    """

    def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_dtype=torch.int64):
        if device is None:
            device = self.device_type
        return self.genSparseCompressedTensor(size, nnz, device=device, dtype=dtype, index_dtype=index_dtype, layout=layout)

    @all_sparse_compressed_layouts()
    def test_layout(self, layout):
        self.assertIn(str(layout), {'torch.sparse_csr', 'torch.sparse_csc', 'torch.sparse_bsr', 'torch.sparse_bsc'})
        self.assertEqual(type(layout), torch.layout)

    @parametrize('shape_and_device_inference', [subtest(False, name='_'), subtest(True, name='shape_and_device_inference')])
    @parametrize('use_factory_function', [subtest(False, name='_'), subtest(True, name='factory')])
    @parametrize('input_kind', [subtest('tensor', name='from_tensor'), subtest('list', name='from_list')])
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_compressed_constructor(self, layout, device, dtype,
                                           use_factory_function, shape_and_device_inference, input_kind):
        if input_kind == 'list' and shape_and_device_inference:
            if torch.device(device).type == 'cuda':
                # list inputs to the factory/constructor function without
                # specifying a device will result in a sparse compressed tensor
                # on CPU, so skip testing against the cuda device as unused.
                self.skipTest("nothing to test")
            if dtype not in {torch.float32, torch.complex64, torch.int64, torch.bool}:
                self.skipTest("dtype not supported with list values")

        expected_devices = [torch.device(device)]
        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2 and not shape_and_device_inference:
            expected_devices.append(torch.device('cuda:1'))

        factory_function = {
            torch.sparse_csr: torch.sparse_csr_tensor,
            torch.sparse_csc: torch.sparse_csc_tensor,
            torch.sparse_bsr: torch.sparse_bsr_tensor,
            torch.sparse_bsc: torch.sparse_bsc_tensor,
        }[layout]
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        if input_kind == 'list':
            index_dtypes = [torch.int64]
        else:
            index_dtypes = [torch.int32, torch.int64]
        if dtype.is_floating_point or dtype.is_complex:
            requires_grad_lst = [False, True]
        else:
            requires_grad_lst = [False]
        for index_dtype in index_dtypes:
            for expected_device in expected_devices:
                for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                        layout, device=expected_device, dtype=dtype, index_dtype=index_dtype,
                        # skip zero-sized tensors for list inputs:
                        enable_zero_sized=input_kind != 'list',
                        output_tensor=False):
                    size = kwargs['size']
                    if shape_and_device_inference and 0 in size:
                        # skip shape inference for zero-sized tensor
                        # inputs because (i) the shape determined from
                        # an empty list is ambiguous, and (ii) the
                        # size of the plain dimension defined as
                        # max(plain_indices) is undefined if
                        # plain_indices has no values
                        continue
                    compressed_indices_expect = compressed_indices
                    plain_indices_expect = plain_indices
                    values_expect = values

                    if input_kind == 'list':
                        compressed_indices = compressed_indices.tolist()
                        plain_indices = plain_indices.tolist()
                        values = values.tolist()

                    for requires_grad in requires_grad_lst:
                        if use_factory_function:
                            if shape_and_device_inference:
                                sparse = factory_function(
                                    compressed_indices, plain_indices, values, requires_grad=requires_grad)
                            else:
                                sparse = factory_function(
                                    compressed_indices, plain_indices, values, size,
                                    dtype=dtype, device=expected_device, requires_grad=requires_grad)
                        else:
                            if shape_and_device_inference:
                                sparse = torch.sparse_compressed_tensor(
                                    compressed_indices, plain_indices, values,
                                    layout=layout, requires_grad=requires_grad)
                            else:
                                sparse = torch.sparse_compressed_tensor(
                                    compressed_indices, plain_indices, values, size,
                                    dtype=dtype, layout=layout, device=expected_device, requires_grad=requires_grad)

                        self.assertEqual(layout, sparse.layout)
                        self.assertEqual(size, sparse.shape)
                        self.assertEqual(compressed_indices_expect, compressed_indices_mth(sparse))
                        self.assertEqual(plain_indices_expect, plain_indices_mth(sparse))
                        self.assertEqual(values_expect, sparse.values())
                        self.assertEqual(sparse.device, sparse.values().device)
                        self.assertEqual(sparse.device, expected_device)
                        self.assertEqual(sparse.values().requires_grad, requires_grad)
                        self.assertEqual(sparse.requires_grad, requires_grad)
                        self.assertFalse(compressed_indices_mth(sparse).requires_grad)
                        self.assertFalse(plain_indices_mth(sparse).requires_grad)
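
    # Minimal illustration of the constructor contract exercised above (values chosen
    # here for exposition only):
    #   torch.sparse_csr_tensor(torch.tensor([0, 2, 3]), torch.tensor([0, 1, 1]),
    #                           torch.tensor([1., 2., 3.]), (2, 3))
    # encodes the dense matrix [[1., 2., 0.], [0., 3., 0.]].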

    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_empty(self, layout, device, dtype):
        ns = [5, 2, 0]
        batch_shapes = [(), (2,), (2, 3)]
        compressed_dim = {
            torch.sparse_csr: -2,
            torch.sparse_csc: -1,
        }[layout]
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        for m, n, b in itertools.product(ns, ns, batch_shapes):
            shape = (*b, m, n)
            with torch.sparse.check_sparse_tensor_invariants(enable=False):
                # torch.empty may return invalid sparse compressed tensors
                result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
                self.assertEqual(result.shape, shape)
                self.assertEqual(result.dtype, dtype)
                self.assertEqual(result.device, torch.device(device))
                self.assertEqual(result.layout, layout)
                self.assertEqual(compressed_indices_mth(result).shape, (*b, shape[compressed_dim] + 1,))
                self.assertEqual(plain_indices_mth(result).shape, (*b, 0,))
                self.assertEqual(result.values().shape, (*b, 0,))
                self.assertEqual(result._nnz(), 0)
                self.assertEqual(compressed_indices_mth(result).device, torch.device(device))
                self.assertEqual(plain_indices_mth(result).device, torch.device(device))
                self.assertEqual(result.values().device, torch.device(device))
                self.assertEqual(compressed_indices_mth(result).dtype, torch.int64)
                self.assertEqual(plain_indices_mth(result).dtype, torch.int64)
                self.assertEqual(result.values().dtype, dtype)

    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_errors(self, layout, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    "torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"):
            torch.empty((5,), dtype=dtype, device=device, layout=layout)

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):

        def get_sparse_compressed_tensor_properties(s):
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
            else:
                compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
            values = s.values()
            return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
                        compressed_indices_shape=compressed_indices.shape,
                        compressed_indices_dtype=compressed_indices.dtype,
                        compressed_indices_device=compressed_indices.device,
                        plain_indices_shape=plain_indices.shape,
                        plain_indices_dtype=plain_indices.dtype,
                        plain_indices_device=plain_indices.device,
                        values_shape=values.shape,
                        values_dtype=values.dtype,
                        values_device=values.device)

        for index_dtype in [torch.int32, torch.int64]:
            for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
                dense_dim = t.dense_dim()
                sparse_dim = t.sparse_dim()
                batch_dim = t.ndim - sparse_dim - dense_dim
                nnz = t.values().shape[batch_dim]
                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                    blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
                else:
                    blocksize = ()

                e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
                                                                       dtype=dtype, layout=layout, device=device)

                e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
                for k, v in e_prop.items():
                    self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_clone(self, layout, device, dtype):
        for sparse in self.generate_simple_inputs(
                layout, device=device, dtype=dtype, index_dtype=torch.int32):
            cloned_sparse = sparse.clone()
            self.assertEqual(sparse, cloned_sparse)

    @all_sparse_compressed_layouts()
    def test_print(self, layout, device):
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        printed = []
        for enable_hybrid in [False, True]:
            # using local patterns for test_print stability
            patterns = [
                # 2 x 3 batch of 3 x 2 tensors, trivial blocksize, non-hybrid/hybrid:
                  [0, 2, 0]]]], [(2, 1)], [(), (4,)] if enable_hybrid else [()]),
                # tensor with non-trivial blocksize, non-hybrid/hybrid:
                ([[0, 1, 0, 2, 0, 2],
                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 2)] if enable_hybrid else [()]),
            ]
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in [torch.float32, torch.float64]:
                    for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                            layout, device=device, dtype=dtype, index_dtype=index_dtype, enable_hybrid=enable_hybrid,
                            enable_non_contiguous_indices=False, enable_non_contiguous_values=False,
                            enable_zero_sized=False, output_tensor=False, patterns=patterns):
                        size = tuple(kwargs['size'])
                        block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
                        base_ndim = 2
                        batch_ndim = compressed_indices.dim() - 1
                        dense_ndim = values.dim() - batch_ndim - block_ndim - 1
                        if enable_hybrid and dense_ndim == 0:
                            # non-hybrid cases are covered by the enable_hybrid==False loop
                            continue
                        batchsize = size[:batch_ndim]
                        basesize = size[batch_ndim:batch_ndim + base_ndim]
                        densesize = size[batch_ndim + base_ndim:]
                        assert len(densesize) == dense_ndim
                        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{basesize}+{densesize} ##########")
                        x = torch.sparse_compressed_tensor(compressed_indices,
                                                           plain_indices,
                                                           values, size, dtype=dtype, layout=layout, device=device)
                        printed.append("# sparse tensor")
                        printed.append(str(x))
                        printed.append(f"# _{compressed_indices_mth.__name__}")
                        printed.append(str(compressed_indices_mth(x)))
                        printed.append(f"# _{plain_indices_mth.__name__}")
                        printed.append(str(plain_indices_mth(x)))
                        printed.append("# _values")
                        printed.append(str(x.values()))

        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        try:
            self.assertExpected('\n'.join(printed))
            self.maxDiff = orig_maxDiff
        except Exception:
            self.maxDiff = orig_maxDiff
            raise

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy(self, layout, device, dtype):

        def run_test(shape, blocksize, nnz, index_dtype):
            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            a.copy_(b)
            self.assertEqual(a, b)

        ns = [(9, 3), (2, 1), (0, 0)]  # (dimension size, block size) pairs
        batch_shapes = [(), (2,), (2, 3)]
        for ((m, bm), (n, bn), b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
            blocksize = (bm, bn) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
            run_test((*b, m, n), blocksize, 0, index_dtype)
            run_test((*b, m, n), blocksize, m * n, index_dtype)

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy_errors(self, layout, device, dtype):
        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
        nnz = 6 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 1
        shape1 = (2 * 6, 3 * 6) if layout in {torch.sparse_bsr, torch.sparse_bsc} else (2, 3)
        for index_dtype in [torch.int32, torch.int64]:
            a = self.genSparseCompressedTensor(shape1, 0, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)

            with self.assertRaisesRegex(RuntimeError,
                                        "copy of sparse compressed tensors having different layouts is not supported."):
                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))

            b = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            assert a._nnz() != b._nnz(), (a._nnz(), b._nnz())
            with self.assertRaisesRegex(RuntimeError,
                                        "only sparse compressed tensors with the same number of specified elements are supported."):
                a.copy_(b)

            shape2 = tuple(reversed(shape1))
            c = self.genSparseCompressedTensor(shape2, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "expected shapes of self and src to match along dimension"):
                b.copy_(c)

            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                blocksize1 = tuple(reversed(blocksize))
                d = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
                                                   index_dtype=index_dtype, blocksize=blocksize1)
                with self.assertRaisesRegex(RuntimeError,
                                            "copy of sparse compressed tensors having different block sizes is not supported"):
                    b.copy_(d)

    def _smallest_divisor(self, n):
        for i in range(2, int(n ** 0.5) + 1):
            if n % i == 0:
                return i
        return n
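
    # e.g. _smallest_divisor(9) == 3 and _smallest_divisor(7) == 7; test_consistency
    # below uses this to derive a blocksize that evenly divides the first two
    # dimensions of a sample input.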

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @all_sparse_compressed_layouts()
    @ops(_sparse_compressed_ops)
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_consistency(self, layout, device, dtype, op):
        """Checks that the op on strided and on sparse tensors produces
        the same results.
        """
        if not op.supports_sparse_layout(layout):
            self.skipTest(f"{op.name} does not support input with {layout} layout")

        # FIXME: remove in a follow-up once integer support has landed for segment_reduce
        if (layout == torch.sparse_csr and not dtype.is_floating_point
                and op.name in ('masked.mean', 'masked.amax', 'masked.amin')):
            self.skipTest(f"{op.name} does not support input with {layout} layout and {dtype} dtype")

        require_mask = isinstance(op, ReductionOpInfo) and 'masked.' in op.name

        samples = []
        for sample in op.sample_inputs(device, dtype):
            if sample.input.ndim < 2:
                continue
            dense_dim = sample.input.ndim - 2
            blocksize = (tuple(map(self._smallest_divisor, sample.input.shape[:2]))
                         if layout in {torch.sparse_bsr, torch.sparse_bsc} else None)

            def _to_sparse(x):
                if isinstance(x, torch.Tensor):
                    if blocksize is None:
                        if x.ndim != sample.input.ndim:
                            return x
                    elif x.ndim != sample.input.ndim + 2 or x.shape[-3] % blocksize[0] or x.shape[-2] % blocksize[1]:
                        return x
                    return x.clone().to_sparse(layout=layout, blocksize=blocksize, dense_dim=dense_dim)
                return x

            sparse_sample = sample.transform(_to_sparse)
            # Some strided samples (with inf, nan elements) appear to share
            # storage, so we must clone:
            sample = sample.transform(lambda x: (x.clone() if isinstance(x, torch.Tensor) else x))

            if validate_sample_input_sparse(op, sparse_sample, check_validate=False) is not sparse_sample:
                # that is, the validation returns the sparse sample
                # wrapped within an ErrorInput instance
                continue
            samples.append((sample, sparse_sample))

        # Fail early to prevent silent success with this test
        if len(samples) == 0:
            raise ValueError("Expected at least one 2D or higher tensor in samples.")

        # Re-define atol and rtol for operations whose result values are random
        # (and hence non-comparable), but where we still want to check the
        # shape, dtype, etc. attributes of the results:
        atol = rtol = None
        if op.name == 'randn_like':
            # result values are random, so effectively disable the value comparison
            atol, rtol = float('inf'), float('inf')

        for sample, sparse_sample in samples:
            expected = op(sample.input, *sample.args, **sample.kwargs)
            assert torch.is_tensor(expected)
            output = op(sparse_sample.input, *sparse_sample.args, **sparse_sample.kwargs)
            assert torch.is_tensor(output)
            strided_output = output.to_dense()
            if require_mask and sample.kwargs.get('mask') is not None:
                output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
                expected.masked_fill_(~output_mask, 0)
            self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)

    @all_sparse_compressed_layouts()
    @all_sparse_compressed_layouts('layout2')
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_like(self, layout, layout2, device, dtype):
        for sparse in self.generate_simple_inputs(layout):
            if layout == layout2:
                result = torch.empty_like(sparse, layout=layout2)
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout]
                torch._validate_sparse_compressed_tensor_args(compressed_indices_mth(result),
                                                              plain_indices_mth(result),
                                                              result.values(),
                                                              result.shape,
                                                              result.layout)
                self.assertEqual(sparse.shape, result.shape)
            else:
                self.assertRaisesRegex(
                    RuntimeError,
                    "empty_like with different sparse layout is not supported",
                    lambda: torch.empty_like(sparse, layout=layout2)
                )

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_validate(self, layout, device, dtype):
        def make_zero_batched(t):
            return torch.empty(*((0,) + t.shape), dtype=t.dtype, device=t.device)

        for index_dtype in [torch.int32, torch.int64]:
            for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                    layout, device=device, dtype=dtype, index_dtype=index_dtype, output_tensor=False):
                size = kwargs['size']
                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)

                torch._validate_sparse_compressed_tensor_args(
                    *(make_zero_batched(t) for t in (compressed_indices, plain_indices, values)),
                    (0, *size), layout)

            compressed_indices = torch.tensor([0, 0], dtype=index_dtype)
            plain_indices = torch.tensor([], dtype=index_dtype)
            torch._validate_compressed_sparse_indices(layout in {torch.sparse_csr, torch.sparse_bsr},
                                                      compressed_indices, plain_indices, 1, 1, 0)

    def _generate_invalid_input(self, layout, device):
        from functools import partial

        def shape(shape, basedim=0):
            blocksize = (1, 1)
            if layout is torch.sparse_csc:
                shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsc:
                shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsr:
                shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:]
            return shape

        def values(lst, device=device):
            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                lst = [[[item]] for item in lst]
            return torch.tensor(lst, device=device)

        tensor = partial(torch.tensor, device=device)
        values = partial(values, device=device)

        yield ('incontiguous compressed_indices',
               tensor([0, -1, 2, -1, 4, -1])[::2],
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected compressed_indices to be a contiguous tensor per batch')

        yield ('incontiguous plain_indices',
               tensor([0, 2, 4]),
               tensor([0, -1, 1, -1, 0, -1, 2, -1])[::2],
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected plain_indices to be a contiguous tensor per batch')

        yield ('0-D compressed_indices',
               tensor(0),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices must have dimensionality >= 1 but got 0')

        yield ('compressed/plain_indices mismatch of dimensionalities',
               tensor([[0, 2, 4]]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively')

        if layout in {torch.sparse_csr, torch.sparse_csc}:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((1, 2, 3), 1),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1')
        else:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((1, 2, 3), 1),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 2\) but got 3')

        yield ('invalid size',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               (2,),
               r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1')

        yield ('invalid batchsize',
               tensor([[0, 2, 4]]),
               tensor([[0, 1, 0, 2]]),
               values([[1, 2, 3, 4]]),
               shape((2, 2, 3), 1),
               r'all batch dimensions of compressed_indices \(=\[1\]\), plain_indices \(=\[1\]\), '
               r'and values \(=\[1\]\) must be equal to tensor batch dimensions \(=\[2\]\)')

        if layout is torch.sparse_bsr:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   (2, 3),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        if layout is torch.sparse_bsc:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   (2, 3),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        yield ('invalid compressed_indices shape',
               tensor([0, 2, 3, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices.shape\[-1\] must be equal to the number of compressed_indices_names \+ 1 \(=3\), but got 4')

        yield ('invalid plain_indices shape',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 1, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'plain_indices.shape\[-1\] must be equal to nnz \(=4\) as defined by values.shape\[0\], but got 5')

        yield ('compressed/plain_indices mismatch of dtype',
               tensor([0, 2, 4], dtype=torch.int32),
               tensor([0, 1, 0, 2], dtype=torch.int64),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices must have the same dtype, bot got Int and Long, respectively')

        yield ('invalid compressed/plain_indices dtype',
               tensor([0, 2, 4], dtype=torch.int16),
               tensor([0, 1, 0, 2], dtype=torch.int16),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices dtype must be Int or Long, but got Short')

        # CUDA kernel asserts are not recoverable, so we skip these for now
        if torch.device(device).type == 'cpu':
            yield ('invalid compressed_indices[0]',
                   tensor([1, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')

            yield ('invalid compressed_indices[0] when nnz == 0',
                   tensor([1, 0], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([]),
                   shape((1, 1)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')

            yield ('invalid compressed_indices[-1]',
                   tensor([0, 2, 5]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            yield ('invalid compressed_indices[-1] when nnz == 0',
                   tensor([0, 1], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([]),
                   shape((1, 1)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, 5, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, -1, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            yield ('invalid min(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, -1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            yield ('invalid max(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            yield ('non-coalesced',
                   tensor([0, 2, 4]),
                   tensor([1, 0, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`plain_indices\[..., compressed_indices\[..., i - 1\]:compressed_indices\[..., i\]\] '
                   'for all i = 1, ..., compressed_dim '
                   'are sorted and distinct along the last dimension values` is not satisfied.')

        if TEST_CUDA and torch.device(device).type == 'cpu':
            yield ('indices and values mismatch of device',
                   torch.tensor([0, 2, 4]),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cpu\) must match device of values \(=cuda:0\)')
            yield ('compressed_indices and values mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')
            yield ('compressed/plain_indices mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')

        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2:
            yield ('indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:0'),
                   values([1, 2, 3, 4], device='cuda:1'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cuda:0\) must match device of values \(=cuda:1\)')
            yield ('compressed_indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:1'),
                   values([1, 2, 3, 4], device='cuda:0'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!')
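
    # For reference (an illustrative summary, not part of the original file): a valid
    # non-batched CSR triple for a 2 x 3 shape satisfies
    #   compressed_indices = [0, 2, 4]     # len == nrows + 1, starts at 0, ends at nnz
    #   plain_indices      = [0, 1, 0, 2]  # each in [0, ncols), sorted within a row
    #   values             = [1, 2, 3, 4]  # len == nnz
    # and each case generated above breaks exactly one of these invariants.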

    @all_sparse_compressed_layouts()
    @parametrize('target', [subtest('validate_sparse_compressed_tensor_args'),
                            subtest('sparse_compressed_tensor'),
                            subtest('sparse_compressed_tensor_no_size')])
    def test_invalid_input(self, layout, device, target):
        for label, compressed_indices, plain_indices, values, size, errmsg in self._generate_invalid_input(layout, device):
            if layout is torch.sparse_bsr:
                errmsg = errmsg.replace('compressed_indices_name', 'row block').replace('plain_indices_name', 'column block')
            elif layout is torch.sparse_bsc:
                errmsg = errmsg.replace('compressed_indices_name', 'column block').replace('plain_indices_name', 'row block')
            elif layout is torch.sparse_csr:
                errmsg = errmsg.replace('compressed_indices_name', 'row').replace('plain_indices_name', 'column')
            elif layout is torch.sparse_csc:
                errmsg = errmsg.replace('compressed_indices_name', 'column').replace('plain_indices_name', 'row')
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                errmsg = errmsg.replace('compressed_indices', 'crow_indices') \
                               .replace('plain_indices', 'col_indices') \
                               .replace('plain_dim', 'ncols') \
                               .replace('compressed_dim', 'nrows')
            else:
                errmsg = errmsg.replace('compressed_indices', 'ccol_indices') \
                               .replace('plain_indices', 'row_indices') \
                               .replace('plain_dim', 'nrows') \
                               .replace('compressed_dim', 'ncols')

            if target == 'sparse_compressed_tensor_no_size' and label in {
                    'invalid size', 'invalid batchsize', 'invalid compressed_indices shape', 'invalid max(plain_indices)',
                    'invalid blocksize'}:
                # Skip the invalid-size inputs, as a valid size is estimated for other inputs
                continue

            with self.assertRaisesRegex(RuntimeError, errmsg):
                if target == 'validate_sparse_compressed_tensor_args':
                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)
                elif target == 'sparse_compressed_tensor':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
                elif target == 'sparse_compressed_tensor_no_size':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=layout)
                else:
                    raise NotImplementedError(target)

    @largeTensorTest("30GB", "cpu")
    def test_invalid_input_csr_large(self):
        rows = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in row dimension'):
            torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int32) // rows,
                                    torch.tensor([0], dtype=torch.int32),
                                    torch.tensor([1]), (rows, 1))
        torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int64) // rows,
                                torch.tensor([0], dtype=torch.int64),
                                torch.tensor([1]), (rows, 1))

        cols = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in column dimension'):
            torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int32),
                                    torch.tensor([0], dtype=torch.int32),
                                    torch.tensor([1]), (1, cols))
        torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int64),
                                torch.tensor([0], dtype=torch.int64),
                                torch.tensor([1]), (1, cols))

        nnz = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in nnz'):
            # nnz cannot be stored in int32 crow_indices, but the
            # `crow_indices[..., -1] == nnz` check happens after the overflow
            # validation, so we can use `nnz - 1` here to avoid a
            # `value cannot be converted to type int32 without overflow` error
            # during the construction of crow_indices
            torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz - 1], dtype=torch.int32),
                                    torch.arange(nnz // 2, dtype=torch.int32).repeat(2),
                                    torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))
        torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz], dtype=torch.int64),
                                torch.arange(nnz // 2, dtype=torch.int64).repeat(2),
                                torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))
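
    # The threshold in the cases above follows from int32 arithmetic: the largest
    # representable value is 2**31 - 1, so a row count, column count, or nnz of 2**31
    # cannot be tracked by int32 indices, while the int64 variants construct fine.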

    @all_sparse_compressed_layouts()
    def test_dim(self, layout):
        for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(layout, output_tensor=False):
            size = kwargs['size']
            batch_dim = compressed_indices.dim() - 1
            sparse_dim = 2
            block_dim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
            dense_dim = values.dim() - batch_dim - block_dim - 1
            sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
            self.assertEqual(sparse.sparse_dim(), sparse_dim)
            self.assertEqual(sparse.dense_dim(), dense_dim)

    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_to_dtype(self, layout, device, dtype):
        # to_dense does not support hybrid inputs
        for sparse in self.generate_simple_inputs(layout, dtype=dtype, device=device, enable_hybrid=False):
            for to_dtype in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16):
                sparse_to_dtype = sparse.to(to_dtype)
                dense_to_dtype = sparse.to_dense().to(to_dtype)
                self.assertEqual(sparse_to_dtype.to_dense(), dense_to_dtype)

    @all_sparse_compressed_layouts()
    @dtypes(torch.double)
    def test_pickle(self, layout, dtype, device):
        for sparse in self.generate_simple_inputs(layout, device=device, dtype=dtype):
            serialized = pickle.dumps(sparse)
            sparse_loaded = pickle.loads(serialized)

            self.assertEqual(sparse, sparse_loaded)

    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select_copy(self, device, dtype, index_dtype, layout):

        def is_view_of(base, other):
            # a shameless copy of TestViewOps.is_view_of
            if (not other._is_view() or
                    other is base or
                    other._base is not base or
                    base.device != other.device):
                return False
            if base.device.type in ('cpu', 'cuda'):
                if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                    return False
            return True

        kwargs = dict(device=device, dtype=dtype, index_dtype=index_dtype)
        for sparse, dense in zip(self.generate_simple_inputs(layout, **kwargs),
                                 self.generate_simple_inputs(torch.strided, **kwargs)):
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                n_batchdim = sparse.crow_indices().ndim - 1
            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
                n_batchdim = sparse.ccol_indices().ndim - 1
            else:
                assert 0  # unreachable
            self.assertEqual(sparse, dense)
            for dim in range(sparse.ndim):
                if sparse.shape[dim] == 0:
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(sparse, dim, 0)
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(dense, dim, 0)
                elif n_batchdim and dim >= n_batchdim and dim < n_batchdim + 2:
                    with self.assertRaisesRegex(
                            RuntimeError,
                            "selecting sparse dimensions is not supported for batched sparse compressed tensors"):
                        torch.select_copy(sparse, dim, 0)
                else:
                    for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}:
                        dense_select = torch.select_copy(dense, dim, index)
                        sparse_select = torch.select_copy(sparse, dim, index)
                        self.assertEqual(sparse_select, dense_select)
                        self.assertFalse(is_view_of(sparse_select.values(), sparse.values()))


def _npref_block_addmm_addmv(c, a, b, alpha, beta):
    return alpha * (a @ b) + beta * c
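
# For context: this matches torch.addmm semantics, which compute
# beta * c + alpha * (a @ b); with alpha=1 and beta=0 it reduces to a plain a @ b.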


class TestSparseCSR(TestCase):

    def test_csr_stride(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
            a.stride()

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
            a.stride(-1)

    def test_csr_storage(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Cannot access storage of SparseCsrTensorImpl"):
            a.storage()

    def test_csr_is_contiguous(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous"):
            a.is_contiguous()

    @largeTensorTest("20GB", "cpu")
    def test_csr_nnz(self):
        # Tests the limits of the number of specified elements in CSR tensors, see gh-102520.
        for nnz in [0, 2**31]:
            rows, cols = 1, max(nnz, 1)
            crow_indices = torch.tensor([0, nnz], dtype=torch.int64)
            col_indices = torch.arange(nnz, dtype=torch.int64)
            values = torch.ones(nnz, dtype=torch.int8)
            a = torch.sparse_csr_tensor(crow_indices, col_indices, values, (rows, cols))
            self.assertEqual(a._nnz(), nnz)

    def test_csr_double_to_sparse_csr(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
        a.to_sparse_csr().to_sparse_csr()

    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select(self, device, dtype, index_dtype, layout):
        compressed_indices_mth = {
            torch.sparse_csr: torch.Tensor.crow_indices,
            torch.sparse_bsr: torch.Tensor.crow_indices,
            torch.sparse_csc: torch.Tensor.ccol_indices,
            torch.sparse_bsc: torch.Tensor.ccol_indices,
        }[layout]
        plain_indices_mth = {
            torch.sparse_csr: torch.Tensor.col_indices,
            torch.sparse_bsr: torch.Tensor.col_indices,
            torch.sparse_csc: torch.Tensor.row_indices,
            torch.sparse_bsc: torch.Tensor.row_indices,
        }[layout]
        create_tensor_mth = {
            torch.sparse_csr: torch.sparse_csr_tensor,
            torch.sparse_bsr: torch.sparse_bsr_tensor,
            torch.sparse_csc: torch.sparse_csc_tensor,
            torch.sparse_bsc: torch.sparse_bsc_tensor,
        }[layout]

        shape = (2, 3, 6, 10)
        nnz = 6
        blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
        sparse = self.genSparseCompressedTensor(
            shape, nnz, device=device, layout=layout, dtype=dtype, index_dtype=index_dtype, blocksize=blocksize)
        comp_indices = compressed_indices_mth(sparse)
        plain_indices = plain_indices_mth(sparse)
        values = sparse.values()

        # select from batch dimensions
        sparse_selected12 = sparse.select(1, 2)
        expected_sparse_selected12 = create_tensor_mth(comp_indices.select(1, 2).contiguous(),
                                                       plain_indices.select(1, 2).contiguous(),
                                                       values.select(1, 2).contiguous(),
                                                       size=shape[2:])
        self.assertEqual(expected_sparse_selected12, sparse_selected12)

        # selecting rows/cols with batch dims is not allowed,
        # so select the batch entries first
        sparse_non_batched = sparse[0, 0]
        # select from sparse dimensions
        for select_args in [(0, 0), (1, 1)]:
            sparse_selected = sparse_non_batched.select(*select_args)
            dense_selected = sparse_non_batched.to_dense().select(*select_args)
            self.assertEqual(dense_selected, sparse_selected)

        self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0])
        # assigning to sparse through indexing is disabled
        with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
            sparse[0, 0, 0, 0] = 99.0

        # select from sparse dimensions without removing batch dims
        msg = "selecting sparse dimensions is not supported for batched sparse compressed tensors."
        with self.assertRaisesRegex(RuntimeError, msg):
            sparse.select(-2, 0)

        with self.assertRaisesRegex(RuntimeError, msg):
            sparse.select(-1, 0)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize(self, device, dtype):

        def numel(tensor):
            r = 1
            for s in tensor.shape:
                r = r * s
            return r

        batch_shapes = [(), (2,), (2, 3)]
        for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
            shape = (*b, 2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 4, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to a larger shape doesn't add specified elements
            self.assertEqual(a._nnz(), nnz)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 1, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to a smaller shape trims specified elements
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))

            # trim batched dimensions
            a.resize_(new_shape[-2], new_shape[-1])
            self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))

    @dtypes(torch.float, torch.bool)
    @all_sparse_compressed_layouts()
    def test_resize_as_sparse_compressed(self, device, dtype, layout):

        def _check_resize_b_as_a(b, a):
            br = b.clone()
            br.resize_as_sparse_(a)

            # shape is inherited from a
            self.assertEqual(a.shape, br.shape)
            # other metadata is not affected
            self.assertEqual(b.layout, br.layout)
            self.assertEqual(b.device, br.device)
            self.assertEqual(b.dtype, br.dtype)

            def _get_compressed_plain_inds(t):
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[t.layout]
                return compressed_indices_mth(t), plain_indices_mth(t)

            br_compressed_indices, br_plain_indices = _get_compressed_plain_inds(br)
            br_values = br.values()

            b_compressed_indices, b_plain_indices = _get_compressed_plain_inds(b)
            a_compressed_indices, a_plain_indices = _get_compressed_plain_inds(a)
            self.assertEqual(a_plain_indices.shape, br_plain_indices.shape)
            self.assertEqual(a_compressed_indices.shape, br_compressed_indices.shape)
            # We don't check the content of br_plain_indices and br_compressed_indices
            # because it is not well-defined (the content depends on the original
            # shape of `b` that `resize_as` ought to discard) nor needed (the
            # subsequent operation likely updates the indices and values of `b` anyway).
            # the device/dtype of indices should always be unaffected
            self.assertEqual(b_plain_indices.dtype, br_plain_indices.dtype)
            self.assertEqual(b_plain_indices.device, br_plain_indices.device)
            self.assertEqual(b_compressed_indices.dtype, br_compressed_indices.dtype)
            self.assertEqual(b_compressed_indices.device, br_compressed_indices.device)
            # values are generated empty, shape is updated
            self.assertEqual(a.values().shape, br_values.shape)
            # the device/dtype of values should always be unaffected
            b_values = b.values()
            self.assertEqual(b_values.dtype, br_values.dtype)
            self.assertEqual(b_values.device, br_values.device)
            # nnz will be picked up from a via the new shape of values
            self.assertEqual(a._nnz(), br._nnz())

            # post resize the invariants of the layout are respected
            torch._validate_sparse_compressed_tensor_args(br_compressed_indices, br_plain_indices, br_values, br.shape,
                                                          br.layout)

        block_sparse = layout in (torch.sparse_bsr, torch.sparse_bsc)
        shape = (2, 1, 6, 4)
        nnz = 4
        blocksize = (2, 1) if block_sparse else ()
        for index_dtype in [torch.int32, torch.int64]:
            a = self.genSparseCompressedTensor(shape,
                                               nnz,
                                               dtype=dtype,
                                               index_dtype=index_dtype,
                                               layout=layout,
                                               device=device,
                                               blocksize=blocksize)

            # same size, resize should not trigger
            b = self.genSparseCompressedTensor(shape,
                                               nnz,
                                               dtype=dtype,
                                               index_dtype=index_dtype,
                                               layout=layout,
                                               device=device,
                                               blocksize=blocksize)

            # This test will not always trigger a resize; if the layouts are the same, nothing should happen to b.
            # The invariants of the function as checked should still hold.
            _check_resize_b_as_a(b, a)

            # same ndim, but bigger, more nnz, different dtype, different blocksize if blocked
            b = self.genSparseCompressedTensor(tuple(s * 2 for s in shape),
                                               nnz * 2,
                                               dtype=torch.float64 if dtype == torch.float else torch.float,
                                               index_dtype=torch.int64 if index_dtype == torch.int32 else torch.int32,
                                               layout=layout,
                                               device=device,
                                               blocksize=tuple(2 * bi for bi in blocksize))
            _check_resize_b_as_a(b, a)

            # different device, only check on cuda pass as we know we are testing in an environment
            # that has multiple devices
            # TODO: .cpu() does not seem to work correctly for sparse. Causes a call to `copy_` which
            # complains about incompatible nnz between src and self?
            if torch.device(device).type == 'cuda' and (layout not in (torch.sparse_bsc, torch.sparse_bsr)):
                a_cpu = self.genSparseCompressedTensor(shape,
                                                       nnz,
                                                       dtype=dtype,
                                                       index_dtype=index_dtype,
                                                       layout=layout,
                                                       device='cpu',
                                                       blocksize=blocksize)
                _check_resize_b_as_a(b, a_cpu)

            # error on a strided
            a_strided = a.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: src expected sparse compressed tensor layout'):
                b.resize_as_sparse_(a_strided)

            # error on b strided
            b_strided = b.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: self expected sparse compressed tensor layout'):
                b_strided.resize_as_sparse_(a)

            # error if layout does not match; transpose induces a layout flip
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.transpose(-2, -1).resize_as_sparse_(a)
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.resize_as_sparse_(a.transpose(-2, -1))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize_errors(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            shape = (2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            new_shape = (4,)
            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
                a.resize_(new_shape)

            # resizing of columns to a smaller size is not implemented
            new_shape = (2, 2)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
            ):
                a.resize_(new_shape)

    @skipIfTorchDynamo()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_from_dense(self, device, dtype):
        dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([4, 5, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([1, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values())
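
    # Spelling out the first case above: the rows of [[4, 5, 0], [0, 0, 0], [1, 0, 0]]
    # hold 2, 0, and 1 nonzeros, so crow_indices is the running total [0, 2, 2, 3],
    # col_indices lists the column of each nonzero in row order ([0, 1, 0]), and
    # values stores the nonzeros themselves ([4, 5, 1]).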

    def _test_sparse_compressed_to_dense(self, device, dtype, layout):
        compressed_format_str = str(layout)[-3:]

        def to_compressed(t):
            return getattr(t, f"to_sparse_{compressed_format_str}")()

        def compressed_constructor(*input, **kwargs):
            constructor = getattr(torch, f"sparse_{compressed_format_str}_tensor")
            return constructor(*input, **kwargs)

        def get_dense_shape(shape, batch_ndim):
            if layout is torch.sparse_csc:
                compressed_dims_slice = slice(batch_ndim + 1, batch_ndim - 1, -1)
            else:
                compressed_dims_slice = slice(batch_ndim, batch_ndim + 2)
            return shape[:batch_ndim] + shape[compressed_dims_slice] + shape[batch_ndim + 2:]

        def transpose(t, batch_ndim):
            if layout is torch.sparse_csc:
                return t.transpose(batch_ndim, batch_ndim + 1)
            return t

        mn = [5, 2, 0]
        for (m, n) in itertools.product(mn, mn):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            sparse = to_compressed(dense)
            self.assertEqual(sparse.to_dense(), dense)

        batch_shape = (2, 3)
        compressed_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        plain_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
        sparse = compressed_constructor(compressed_indices, plain_indices, values, dtype=dtype, device=device)
        dense_shape = get_dense_shape(sparse.shape, len(batch_shape))
        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(dense_shape)
        self.assertEqual(sparse.to_dense(), transpose(dense, len(batch_shape)))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_to_dense(self, device, dtype):
        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csr)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csc_to_dense(self, device, dtype):
        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csc)

    @skipCPUIfNoMklSparse
    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(
                torch.randint(100, (5, 5), device=device),
                size=100)

        size = (5, 5)
        sparse_dim = 2
        nnz = 10
        sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz, coalesced, device, dtype)
        sparse_csr = sparse_coo.to_sparse_csr()

        self.assertTrue(sparse_csr.is_sparse_csr)
        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())

        vec = torch.randn((5, 1), dtype=dtype, device=device)
        coo_product = sparse_coo.matmul(vec)
        csr_product = sparse_csr.matmul(vec)

        self.assertEqual(coo_product, csr_product)

        vec = torch.randn((100, 1), dtype=dtype, device=device)
        index = torch.tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ], dtype=torch.int32)
        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
        coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]), dtype=dtype, device=device)
        csr = coo.to_sparse_csr()

        self.assertEqual(coo.matmul(vec), csr.matmul(vec))

        col_indices = torch.tensor([
            31, 92, 65, 50, 34, 62, 22, 56, 74, 89
        ], dtype=torch.int64, device=device)
        self.assertEqual(csr.col_indices(), col_indices)

        values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
        self.assertEqual(csr.values(), values)
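
    # The expected col_indices/values above follow from CSR's ordering contract: COO
    # entries are regrouped by ascending row (row 0 holds column 31, row 1 holds
    # column 92, and so on), which permutes the values from [1, 2, ..., 10] into
    # [2, 1, 6, 4, 10, 3, 5, 9, 8, 7].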

    @parametrize("blocksize", [2, 4])
    @dtypes((torch.double, torch.int32), (torch.double, torch.int64))
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_csr_to_block_csr(self, device, dtypes, blocksize):
        for shape in [(24, 24), (12, 24)]:
            dtype, index_dtype = dtypes
            m, k = shape
            nnz = random.randint(0, m * k)
            t = self.genSparseCSRTensor((m * blocksize, k * blocksize), nnz, dtype=dtype,
                                        device=device, index_dtype=index_dtype)
            st = sp.csr_matrix((t.values().cpu(), t.col_indices().cpu(), t.crow_indices().cpu()), shape=tuple(t.size()))
            block_t = t.to_sparse_bsr((blocksize, blocksize))
            self.assertEqual(block_t.values().dim(), 3)
            self.assertTrue(block_t.layout == torch.sparse_bsr)
            block_st = st.tobsr(blocksize=(blocksize, blocksize))
            block_st.sort_indices()
            self.assertEqual(block_t.values().cpu(), block_st.data)
            self.assertEqual(block_t.col_indices().cpu(), torch.tensor(block_st.indices).to(index_dtype))
            self.assertEqual(block_t.crow_indices().cpu(), torch.tensor(block_st.indptr).to(index_dtype))

    @dtypes(torch.double)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_csr_to_block_csr_errors(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            nnz = 15
            t = self.genSparseCSRTensor((16, 16), nnz, dtype=dtype,
                                        device=device, index_dtype=index_dtype)

            with self.assertRaisesRegex(RuntimeError,
                                        r"tensor sparse size \(.*,.*\) must be divisible by given blocksize \(.*,.*\)"):
                block_t = t.to_sparse_bsr((5, 5))

    # TODO: Support auto generation of device check for sparse tensors
    # See: https://github.com/pytorch/pytorch/issues/59058

    @onlyCUDA
    @dtypes(torch.double)
    def test_matmul_device_mismatch(self, device, dtype):
        cpu = torch.rand((10, 10))
        cuda = cpu.cuda()
        for s, m1, m2 in itertools.product((cpu, cuda), repeat=3):
            csr = m1.to_sparse()
            if s.device == csr.device == m2.device:
                torch.addmm(s, csr, m2)
            else:
                with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                    torch.addmm(s, csr, m2)

    @skipCPUIfNoMklSparse
    @skipCUDAIfNoSparseGeneric
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater else [],
                  *[torch.bfloat16] if SM80OrLater else []))
    def test_csr_matvec(self, device, dtype):

        if TEST_WITH_ROCM and (dtype == torch.half or dtype == torch.bfloat16):
            self.skipTest("ROCm doesn't work with half dtypes correctly.")

        side = 100
        for index_dtype in [torch.int32, torch.int64]:
            csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype)
            vec = torch.randn(side, dtype=dtype, device=device)

            res = csr.matmul(vec)
            expected = csr.to_dense().matmul(vec)

            self.assertEqual(res, expected)

            bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
            err_msg = "size mismatch, got"
            with self.assertRaisesRegex(RuntimeError, err_msg):
                csr.matmul(bad_vec)

    # Note: the test passes on CUDA when ROCm is not available:
    @skipCUDAIfRocmVersionLessThan((5, 2))
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_baddbmm(self, device, dtype):

        # TODO: disable the invariant checks within torch.baddbmm that
        # constructs unconventional csr tensors leading to
        # RuntimeError: tensor dimensionality must be sum of batch,
        # base, and dense dimensionalities (=0 + 2 + 0) but got 3
        # when invariant checking is enabled. When done, undecorate run_test.
        @torch.sparse.check_sparse_tensor_invariants(enable=False)
        def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
            alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
            beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
            b = b.mH if (op_b and a.shape == b.shape) else b

            actual = torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta)

            out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
            torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta, out=out)

            expected = [torch.addmm(c[i], a, b[i], alpha=alpha, beta=beta) for i in range(c.shape[0])]
            expected = torch.stack(expected, 0)

            self.assertEqual(actual, out)
            self.assertEqual(actual, expected)

        for index_dtype in [torch.int32, torch.int64]:
            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
                nnz = random.randint(0, m * k)
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

                # a_batched is a regular CSR tensor but with a batch dimension in the shape
                a_batched = torch.sparse_csr_tensor(
                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_b, op_out in itertools.product([True, False], repeat=2):
                    run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)

    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
    @skipCUDAIfNoSparseGeneric
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_bmm(self, device, dtype):
        def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
            b = b.mH if (op_b and a.shape == b.shape) else b

            actual = torch.bmm(a_batched, b)

            out = torch.empty_like(actual.mH if op_out and a.shape == b.shape else actual)
            torch.bmm(a_batched, b, out=out)

            expected = [torch.mm(a, b[i]) for i in range(b.shape[0])]
            expected = torch.stack(expected, 0)

            self.assertEqual(actual, out)
            self.assertEqual(actual, expected)

        for index_dtype in [torch.int32, torch.int64]:
            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
                nnz = random.randint(0, m * k)
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

                # a_batched is a regular CSR tensor but with a batch
                # dimension in the shape. It is unorthodox in PyTorch
                # to represent a batch sparse tensor in this way,
                # hence checking the tensor invariants is locally
                # disabled.
                a_batched = torch.sparse_csr_tensor(
                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_b, op_out in itertools.product([True, False], repeat=2):
                    run_test(a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
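
    def _demo_batched_csr_broadcast(self):
        # Illustrative sketch (hypothetical helper, not exercised above): reusing
        # one CSR matrix's indices and values under a batched shape makes
        # torch.bmm treat it as the same matrix repeated along the batch dim.
        a = torch.eye(2).to_sparse_csr()
        a_batched = torch.sparse_csr_tensor(
            a.crow_indices(), a.col_indices(), a.values(), (3, 2, 2), check_invariants=False)
        b = torch.randn(3, 2, 2)
        assert torch.allclose(torch.bmm(a_batched, b), b)  # identity @ b[i] == b[i]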

    def run_test_block_addmm_addmv(self,
                                   addmv_addmm,
                                   c,
                                   a,
                                   b,
                                   op_b=False,
                                   op_out=False,
                                   *,
                                   dtype=None,
                                   device=None,
                                   ref=_npref_block_addmm_addmv):
        alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
        beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
        b = b.mH if (op_b and a.shape == b.shape) else b

        actual = addmv_addmm(c, a, b, alpha=alpha, beta=beta)

        out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
        addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)

        expected = ref(c, a, b, alpha, beta)

        self.assertEqual(actual, out)
        self.assertEqual(actual, expected, lambda msg: f"{msg}\na={a}\nc={c}\nb={b}\nalpha={alpha} beta={beta}")

    # TODO: block_size 1 is broken
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @skipCPUIfNoMklSparse
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @skipIfTorchDynamo("raises 'sparse matrix length is ambiguous; use getnnz()'")
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater else [],
                  *[torch.bfloat16] if SM80OrLater else []))
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-5, torch.complex128: 1e-5,
                        torch.float16: 1e-3, torch.bfloat16: 1e-3})
    def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous):

        def make_transposed_addmm_op(f):

            def tt(t):
                if isinstance(t, torch.Tensor):
                    return t.transpose(-2, -1)
                else:
                    # assume numpy/scipy spmatrix
                    return t.transpose()

            def wrapper(c, a, b, alpha=None, beta=None, out=None):
                if out is not None:
                    # the ref takes no out kwarg
                    assert isinstance(out, torch.Tensor)
                    # transpose inplace to propagate out to checking context
                    out.transpose_(-2, -1)
                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta, out=out)
                else:
                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta)

            return wrapper

        def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None):

            def prep_input(t):

                def to_sp_block_compressed(t):
                    if t.layout is torch.sparse_bsc:
                        tt = t.transpose(-1, -2)
                    else:
                        tt = t

                    t_sp_bsr = sp.bsr_matrix(
                        (
                            tt.values().cpu().numpy(),
                            tt.col_indices().cpu().numpy(),
                            tt.crow_indices().cpu().numpy(),
                        ),
                        shape=tt.shape,
                    )

                    if t.layout is torch.sparse_bsc:
                        return t_sp_bsr.transpose()
                    else:
                        return t_sp_bsr

                if t.layout is not torch.strided:
                    return to_sp_block_compressed(t)
                else:
                    return t.cpu().resolve_conj().numpy()

            res = _npref_block_addmm_addmv(
                *(prep_input(t) for t in (c, a, b)),
                alpha,
                beta
            )

            if out is not None:
                out.copy_(res)
                return out
            else:
                return res

        def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
            res = alpha * (a.to_dense().to(torch.float32) @ b.to_dense().to(torch.float32)).to(a.dtype) + beta * c
            if out is not None:
                out.copy_(res)
                return out
            else:
                return res

        if dtype in (torch.half, torch.bfloat16):
            ref = ref_half_bfloat16
        else:
            ref = ref_sp_numpy

        for (m, n, k) in itertools.product([2, 5], repeat=3):
            nnz = random.randint(0, m * k)
            a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
            a_data = a_data.mT if noncontiguous else a_data
            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                        a_data, (m * block_size, k * block_size), check_invariants=False)
            b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
            c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
            for op_b, op_out in itertools.product([True, False], repeat=2):
                self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref)
                self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm),
                                                c,
                                                a,
                                                b,
                                                op_b,
                                                op_out,
                                                dtype=dtype,
                                                device=device,
                                                ref=make_transposed_addmm_op(ref))

    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
        # TODO: Explicitly disable block size 1 support
        # if (TEST_WITH_ROCM or not TEST_CUSPARSE_GENERIC) and block_size == 1:
        #     return
        def ref_block_addmv(c, a, b, alpha, beta):
            return _npref_block_addmm_addmv(c, a.to_dense(), b, alpha, beta)

        for (m, k) in itertools.product([2, 5], repeat=2):
            nnz = random.randint(0, m * k)
            if not noncontiguous:
                a = self.genSparseCSRTensor((m * block_size, k * block_size), nnz,
                                            dtype=dtype, device=device, index_dtype=index_dtype)
                a = a.to_sparse_bsr((block_size, block_size))
            else:
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                            a_data, (m * block_size, k * block_size), check_invariants=False)
            b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
            c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
            self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device, ref=ref_block_addmv)

    @parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)], name_fn=lambda x: "shape_{}x{}".format(*x))
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_addmv(self, device, dtype, matrix_shape):
        mat = torch.randn(matrix_shape, dtype=dtype, device=device)
        mat[mat.real < 0] = 0
        sparse_mat = mat.to_sparse_csr()
        mvec = torch.randn((mat.size(1),), dtype=dtype, device=device)
        avec = torch.randn((mat.size(0),), dtype=torch.float64, device=device)
        ref_output = torch.addmv(avec, mat, mvec)
        output = torch.addmv(avec, sparse_mat, mvec)
        self.assertEqual(ref_output, output)

    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @skipCPUIfNoMklSparse
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_block_triangular_solve(self, device, dtype, index_dtype, block_size, noncontiguous):
        def run_test(a, b, upper, transpose, unitriangular, op_out):
            if unitriangular and self.device_type == 'cpu':
                # TODO: When unitriangular=True results are not correct on CPU
                return

            if not upper and self.device_type == 'cpu':
                # TODO: When upper=False some generated inputs might crash on CPU
                return

            actual = torch.triangular_solve(b, a, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            self.assertTrue(actual_A_clone.numel() == 0)
            if a._nnz() == 0:
                self.assertTrue(actual_X.isnan().all())
                return

            # TODO: replace with a torch method once to_dense() is implemented for block sparse tensors
            a_bsr = sp.bsr_matrix(
                (
                    a.values().cpu().numpy(),
                    a.col_indices().cpu().numpy(),
                    a.crow_indices().cpu().numpy(),
                ),
                shape=a.shape,
            )
            expected_X, _ = torch.triangular_solve(
                b,
                torch.tensor(a_bsr.todense(), device=device),
                upper=upper,
                transpose=transpose,
                unitriangular=unitriangular)

            if expected_X.isnan().any():
                # TODO: zeros on the diagonal are not handled for CPU path
                # there's no way to query this info from MKL
                if self.device_type == 'cuda' and not TEST_WITH_ROCM:
                    self.assertTrue(actual_X.isnan().any() or actual_X.isinf().any())
                return

            self.assertEqual(actual_X, expected_X)

            out = torch.empty_like(b.mH if op_out and a.shape == b.shape else b)
            torch.triangular_solve(
                b, a,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, actual_X)
            self.assertEqual(out, expected_X)

        for (m, k) in itertools.product([2, 3], [1, 3]):
            nnz = random.randint(0, m * m)
            if not noncontiguous:
                a = self.genSparseCSRTensor((m * block_size, m * block_size), nnz,
                                            dtype=dtype, device=device, index_dtype=index_dtype)
                a = a.to_sparse_bsr((block_size, block_size))
            else:
                a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                            a_data, (m * block_size, m * block_size), check_invariants=False)
            b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)

            for (upper, transpose, unitriangular, op_out) in itertools.product([True, False], repeat=4):
                run_test(a, b, upper, transpose, unitriangular, op_out)
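
    def _demo_triangular_solve_identity(self):
        # Illustrative sketch (hypothetical helper, not exercised above): with a
        # triangular identity coefficient, the solution of A X = B is B itself.
        A = torch.eye(3)
        B = torch.randn(3, 2)
        X, _ = torch.triangular_solve(B, A, upper=True)
        assert torch.allclose(X, B)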

    @skipCPUIfNoMklSparse
    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
    @dtypes(torch.double)
    def test_mm(self, device, dtype):
        def test_shape(di, dj, dk, nnz0=None, nnz1=None):
            for index_dtype in [torch.int32, torch.int64]:
                alpha = random.random()
                beta = random.random()

                def _test_addmm(t, x, y):
                    # TODO: addmm doesn't support strided result for sparse inputs.
                    # res = beta * t + alpha * (x @ y)
                    res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
                    expected = torch.addmm(t, x.to_dense(), y.to_dense(), beta=beta, alpha=alpha)
                    self.assertEqual(res, expected)

                    res = torch.addmm(t, x, y)
                    expected = torch.addmm(t, x.to_dense(), y.to_dense())
                    self.assertEqual(res, expected)

                def _test_mm(x, y):
                    res = torch.mm(x, y)
                    expected = torch.mm(x.to_dense(), y.to_dense())
                    if x.layout is torch.strided or y.layout is torch.strided:
                        self.assertEqual(res.layout, torch.strided)
                    else:
                        self.assertEqual(res.layout, torch.sparse_csr)
                    self.assertEqual(res.to_dense(), expected)

                def _test(t, x, y):
                    _test_addmm(t, x, y)
                    _test_mm(x, y)

                if nnz0 is None:
                    nnz0 = random.randint(di * dk // 2, di * dk)
                t = torch.randn(di, dj, dtype=dtype, device=device)
                x = self.genSparseCSRTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
                y = torch.randn(dk, dj, dtype=dtype, device=device)
                _test(t, x, y)

                t = torch.randn(di, dj, dtype=dtype, device=device)
                x = self.genSparseCSCTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
                y = torch.randn(dk, dj, dtype=dtype, device=device)
                _test(t, x, y)

                if nnz1 is None:
                    nnz1 = random.randint(dk * dj // 2, dk * dj)
                t = torch.randn(di, dj, dtype=dtype, device=device)
                x = torch.randn(di, dk, dtype=dtype, device=device)
                y = self.genSparseCSRTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
                _test(t, x, y)

                t = torch.randn(di, dj, dtype=dtype, device=device)
                x = torch.randn(di, dk, dtype=dtype, device=device)
                y = self.genSparseCSCTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
                _test(t, x, y)

                x_shape, y_shape = x.shape, y.shape

                gen_csr_csc = [self.genSparseCSRTensor, self.genSparseCSCTensor]

                # Test mm({CSR, CSC}, {CSR, CSC})
                for gen_x, gen_y in itertools.product(gen_csr_csc, gen_csr_csc):
                    x = gen_x(x_shape, nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
                    y = gen_y(y_shape, nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
                    _test_mm(x, y)

        def test_empty_inputs(lhs_layout, rhs_layout):
            xd = torch.rand(10, 0, device=device, dtype=dtype)
            yd = xd.transpose(-2, -1)
            zd = torch.rand(0, 0, device=device, dtype=dtype)

            xls, yls, zls = (t.to_sparse(layout=lhs_layout) for t in (xd, yd, zd))
            xrs, yrs, zrs = (t.to_sparse(layout=rhs_layout) for t in (xd, yd, zd))

            for ls, rs, ld, rd in [(xls, yrs, xd, yd), (xls, zrs, xd, zd), (zls, yrs, zd, yd), (zls, zrs, zd, zd)]:
                res_sparse = ls @ rs
                res_dense = ld @ rd
                self.assertEqual(res_sparse.to_dense(), res_dense)

        def test_orthogonal_inputs(lhs_layout, rhs_layout):
            ones = torch.ones(2, 2, device=device, dtype=dtype)
            zeros = torch.zeros(2, 2, device=device, dtype=dtype)
            x = torch.cat((ones, zeros), -1).to_sparse(layout=lhs_layout)
            y = torch.cat((zeros, ones), -2).to_sparse(layout=rhs_layout)
            res = x @ y
            res_expected = torch.zeros(*res.shape, device=device, dtype=dtype, layout=res.layout)
            self.assertEqual(res, res_expected)

        for lhs_layout, rhs_layout in itertools.product([torch.sparse_csr, torch.sparse_csc], repeat=2):
            test_empty_inputs(lhs_layout, rhs_layout)
            test_orthogonal_inputs(lhs_layout, rhs_layout)

        for i in [2, 4]:
            for j in [2, 4, 7]:
                for k in [2, 3, 7]:
                    test_shape(i, j, k)
        test_shape(4, 4, 4, 0, 0)
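
    def _demo_mm_result_layout(self):
        # Illustrative sketch (hypothetical helper, not exercised above): mm of
        # two sparse compressed operands stays sparse (CSR), while mixing in a
        # strided operand produces a strided result.
        x = torch.eye(2).to_sparse_csr()
        y = torch.eye(2).to_sparse_csc()
        assert (x @ y).layout == torch.sparse_csr
        assert (x @ torch.eye(2)).layout == torch.strided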

    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_mm(self, device, dtype):
        def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
            if transposed:
                D = torch.randn(d3, d2, dtype=dtype, device=device).t_()
            else:
                D = torch.randn(d2, d3, dtype=dtype, device=device)
            S = self.genSparseCSRTensor((d1, d2), nnz, device=device, dtype=dtype, index_dtype=index_dtype)
            S_dense = S.to_dense()
            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))

        for index_dtype in [torch.int32, torch.int64]:
            test_shape(7, 8, 9, 20, False, index_dtype)
            test_shape(7, 8, 9, 20, True, index_dtype)

    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_addmm(self, device, dtype):
        def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
            if alpha_beta is None:
                alpha = random.random()
                beta = random.random()
            else:
                alpha, beta = alpha_beta
            if broadcast:
                D1 = make_tensor((), dtype=dtype, device=device)
            else:
                D1 = make_tensor([n, p], dtype=dtype, device=device)
            D2 = make_tensor([m, p], dtype=dtype, device=device)
            S = self.genSparseCSRTensor([n, m], nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            S_dense = S.to_dense()
            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
            self.assertEqual(Y, Y_dense)

        for index_dtype in [torch.int32, torch.int64]:
            test_shape(7, 8, 9, 20, False, index_dtype, None)
            test_shape(7, 8, 9, 20, True, index_dtype, None)
            test_shape(7, 8, 9, 20, False, index_dtype, (1, 0))
            test_shape(7, 8, 9, 20, True, index_dtype, (1, 0))
            test_shape(7, 8, 9, 20, False, index_dtype, (1, 1))
            test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))
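
    def _demo_sparse_addmm_formula(self):
        # Illustrative sketch (hypothetical helper, not exercised above):
        # torch.sparse.addmm(D, S, M, beta=b, alpha=a) computes b * D + a * (S @ M).
        S = torch.eye(2).to_sparse_csr()
        D = torch.ones(2, 2)
        M = torch.full((2, 2), 3.0)
        res = torch.sparse.addmm(D, S, M, beta=2.0, alpha=0.5)
        assert torch.allclose(res, 2.0 * D + 0.5 * M)  # S is the identity, so S @ M == M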

    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                      *[torch.bfloat16] if SM80OrLater else [],
                                      *[torch.half] if SM53OrLater else [],
                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
    @sparse_compressed_nonblock_layouts()
    @skipCUDAIf(
        not _check_cusparse_spgemm_available(),
        "cuSparse Generic API SpGEMM is not available"
    )
    def test_addmm_all_sparse_csr(self, device, dtype, layout):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse")

        # Test 0-strided
        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse")

        # Test beta=0, M=nan
        M = torch.full((10, 25), float('nan'), device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="all_sparse")

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="all_sparse")

    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @sparse_compressed_nonblock_layouts()
    def test_addmm_dense_result(self, device, dtype, layout):
        M = torch.randn(10, 25, device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result")

        # Test 0-strided
        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result")

        # Test beta=0, M=nan
        M = torch.full((10, 25), float('nan'), device=device).to(dtype)
        m1 = torch.randn(10, 50, device=device).to(dtype)
        m2 = torch.randn(50, 25, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="dense_result")

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            def maybe_transpose(cond, m):
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="dense_result")

    @parametrize("k", [0, 1, 8])
    @parametrize("n", [0, 1, 10])
    @parametrize("m", [0, 1, 25])
    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                      *[torch.bfloat16] if SM80OrLater else [],
                                      *[torch.half] if SM53OrLater else [],
                                      *[torch.complex128]
                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
                                      else []))
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
        if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0):
            self.skipTest("Skipped on ROCm")
        M = torch.randn(n, m, device=device).to(dtype)
        m1 = torch.randn(n, k, device=device).to(dtype)
        m2 = torch.randn(k, m, device=device).to(dtype)
        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, mode="all_sparse")

        M = torch.randn(n, m, device=device).to(dtype).to_sparse_csr()
        m1 = torch.randn(n, k + 1, device=device).to(dtype).to_sparse_csr()
        m2 = torch.randn(k, m, device=device).to(dtype).to_sparse_csr()
        self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
        self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))

    @skipCPUIfNoMklSparse
    @dtypes(torch.float)
    def test_addmm_errors(self, device, dtype):
        # test that the errors are the same for dense and sparse versions
        import re

        def test1(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a, a_sparse, a)
            else:
                return torch.addmm(a, a, a)

        def test2(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a, a_sparse, a.unsqueeze(0))
            else:
                return torch.addmm(a, a, a.unsqueeze(0))

        def test3(*, is_sparse):
            # the first input needs to be 1D or 2D
            a = make_tensor((3, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.addmm(a.unsqueeze(0), a_sparse, a)
            else:
                return torch.addmm(a.unsqueeze(0), a, a)

        for test in (test1, test2, test3):
            try:
                test(is_sparse=False)
            except RuntimeError as msg:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
                    test(is_sparse=True)

    @skipCPUIfNoMklSparse
    @dtypes(torch.float)
    def test_mm_errors(self, device, dtype):
        # test that the errors are the same for dense and sparse versions
        import re

        def test1(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.mm(a_sparse, a)
            else:
                return torch.mm(a, a)

        def test2(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            if is_sparse:
                a_sparse = a.to_sparse_csr()
                return torch.mm(a_sparse, a.unsqueeze(0))
            else:
                return torch.mm(a, a.unsqueeze(0))

        for test in (test1, test2):
            try:
                test(is_sparse=False)
            except RuntimeError as msg:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
                    test(is_sparse=True)

    @sparse_compressed_nonblock_layouts()
    @dtypes(torch.float, torch.double)
    def test_add(self, device, layout, dtype):
        def _test_spadd_shape(nnz, shape):
            # sparse.to_dense() uses torch.add internally so if torch.add is wrong,
            # the dense tensor will be wrong but this test would still pass
            # there's a separate test that checks for the correctness of the .to_dense() call
            x = self.genSparseCompressedTensor(shape, nnz,
                                               dtype=dtype,
                                               layout=layout,
                                               index_dtype=torch.int32,
                                               blocksize=(),
                                               device=device)
            y = torch.randn(*shape, dtype=dtype, device=device)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()
            self.assertEqual(res, expected)
            res_perm = torch.add(x, y, alpha=r)
            self.assertEqual(res_perm, expected)

            # Non contiguous dense tensor
            s = list(shape)
            s[0], s[-1] = s[-1], s[0]
            y = torch.randn(*s, dtype=torch.double, device=device)
            y.transpose_(0, len(s) - 1)
            r = random.random()

            res = torch.add(y, x, alpha=r)
            expected = y + r * x.to_dense()
            res_perm = torch.add(x, y, alpha=r)

            self.assertEqual(res, expected)
            self.assertEqual(res_perm, expected)

        ns = [2, 5]
        batch_shapes = [(), (2,), (2, 3)]
        for b, m, n in itertools.product(batch_shapes, ns, ns):
            _test_spadd_shape(0, (*b, m, n))
            _test_spadd_shape(m * n // 2, (*b, m, n))
            _test_spadd_shape(m * n, (*b, m, n))
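
    def _demo_add_alpha(self):
        # Illustrative sketch (hypothetical helper, not exercised above):
        # torch.add(dense, sparse, alpha=r) equals dense + r * sparse.to_dense().
        x = torch.eye(2).to_sparse_csr()
        y = torch.zeros(2, 2)
        assert torch.equal(torch.add(y, x, alpha=2.0), 2.0 * torch.eye(2))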

    @dtypes(torch.float, torch.double)
    def test_mul(self, device, dtype):
        # TODO: This whole test should be migrated to OpInfos
        def _test_spadd_shape(fn, nnz, shape):
            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)

            # Forward comparison
            res_sparse_sparse = fn(y, x)
            res_dense_sparse = fn(y.to_dense(), x)
            res_sparse_dense = fn(y, x.to_dense())
            expected = fn(y.to_dense(), x.to_dense())
            self.assertEqual(res_sparse_sparse, expected)
            # TODO: While result of mul(dense, csr) is csr, it is not fully compressed.
            # That means it may contain materialized zeros, since the dense argument
            # is converted according to the sparsity pattern of csr. In the future
            # we might require the result to be fully compressed.
            self.assertEqual(res_dense_sparse, expected)
            self.assertEqual(res_sparse_dense, expected)

            # Grad comparison
            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
            z = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)

            # csr * csr -> csr with csr, csr gradients
            x_a = x.clone().requires_grad_()
            y_a = y.clone().requires_grad_()

            fn(y_a, x_a).backward(z)

            x_dense_a = x.to_dense().requires_grad_()
            y_dense_a = y.to_dense().requires_grad_()

            fn(y_dense_a, x_dense_a).backward(z.to_dense())

            self.assertEqual(x_a.grad.layout, torch.sparse_csr)
            self.assertEqual(y_a.grad.layout, torch.sparse_csr)

            self.assertEqual(x_a.grad.to_dense(), x_dense_a.grad)
            self.assertEqual(y_a.grad.to_dense(), y_dense_a.grad)

            # TODO: Currently strided Tensors cannot have csr gradients
            # dense * csr -> csr with csr, dense gradients
            x_a = x.clone().requires_grad_()
            y_a = y.to_dense().clone().requires_grad_()
            err_msg = "Function MulBackward0 returned an invalid gradient at index 0 - expected layout Strided but got SparseCsr"
            with self.assertRaisesRegex(RuntimeError, err_msg):
                fn(y_a, x_a).backward(z)

            # csr * dense -> csr with dense, csr gradients
            x_a = x.to_dense().clone().requires_grad_()
            y_a = y.clone().requires_grad_()
            err_msg = "Function MulBackward0 returned an invalid gradient at index 1 - expected layout Strided but got SparseCsr"
            with self.assertRaisesRegex(RuntimeError, err_msg):
                fn(y_a, x_a).backward(z)

        _test_spadd_shape(torch.mul, 100, [100, 100])
        _test_spadd_shape(torch.mul, 0, [100, 100])
        _test_spadd_shape(torch.mul, 100, [100, 1])
        _test_spadd_shape(torch.mul, 100, [1, 100])

    # TODO: enable hybrid once to_dense supports it
    @parametrize('enable_hybrid', [False])
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_mul_scalar(self, layout, device, dtype, enable_hybrid):
        for sparse in self.generate_simple_inputs(
                layout, device=device, dtype=dtype, index_dtype=torch.int32, enable_hybrid=enable_hybrid):
            for scalar_dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
                # ComplexHalf is experimental
                if dtype is torch.half and scalar_dtype.is_complex:
                    continue

                scalar_t = torch.tensor(2, dtype=scalar_dtype)
                for scalar in (scalar_t, scalar_t.item()):
                    res_out = sparse.mul(scalar)
                    self.assertEqual(res_out, scalar * sparse)

                    res_dense_out = sparse.to_dense().mul(scalar)
                    # BUG: dispatcher ignores mul.Scalar(Tensor, Scalar).
                    # This issue is circumvented in the mul(Tensor, Tensor) kernel.
                    self.assertEqual(res_out, res_dense_out)

                    if dtype == torch.result_type(sparse, scalar):
                        res_in_dense = sparse.to_dense().mul_(scalar)
                        res_in = sparse.clone().mul_(scalar)
                        self.assertEqual(res_in, res_in_dense)
                        self.assertEqual(res_out, res_in)

    @skipCPUIfNoMklSparse
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sparse_add(self, device, dtype):
        def run_test(m, n, index_dtype):

            alpha = random.random()
            nnz1 = random.randint(0, m * n)
            nnz2 = random.randint(0, m * n)
            nnz3 = random.randint(0, m * n)

            if TEST_WITH_ROCM:
                # ROCm fails when nnz = 0
                nnz1, nnz2, nnz3 = max(1, nnz1), max(1, nnz2), max(1, nnz3)

            S1 = self.genSparseCSRTensor([m, n], nnz1, dtype=dtype, device=device, index_dtype=index_dtype)
            S2 = self.genSparseCSRTensor([m, n], nnz2, dtype=dtype, device=device, index_dtype=index_dtype)
            S3 = self.genSparseCSRTensor([m, n], nnz3, dtype=dtype, device=device, index_dtype=index_dtype)
            sparse_args = [S1, S2, S3]
            dense_args = [t.to_dense() for t in sparse_args]
            arg_idx = list(range(len(sparse_args)))
            out_idx = arg_idx + [None]

            for idx1, idx2, idx3 in itertools.product(arg_idx, arg_idx, out_idx):
                s1 = sparse_args[idx1]
                s2 = sparse_args[idx2]
                s3 = None if idx3 is None else sparse_args[idx3]
                d1 = dense_args[idx1]
                d2 = dense_args[idx2]
                d3 = None if idx3 is None else dense_args[idx3]

                expected = torch.add(d1, d2, alpha=alpha, out=d3)
                actual = torch.add(s1, s2, alpha=alpha, out=s3)
                self.assertEqual(actual.crow_indices().dtype, index_dtype)
                self.assertEqual(actual.col_indices().dtype, index_dtype)
                self.assertEqual(actual, expected)
                self.assertEqual(s3, d3)
                if s3 is not None:
                    self.assertEqual(s3.crow_indices().dtype, index_dtype)
                    self.assertEqual(s3.col_indices().dtype, index_dtype)

        for index_dtype in [torch.int32, torch.int64]:
            for m, n in itertools.product([3, 5], [3, 5]):
                run_test(m, n, index_dtype)

    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sparse_add_errors(self, device, dtype):
        def run_test(index_dtype):
            a = self.genSparseCSRTensor((2, 2), 3, dtype=dtype, device=device, index_dtype=index_dtype)
            b = self.genSparseCSRTensor((2, 1), 2, dtype=dtype, device=device, index_dtype=index_dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected input tensors to have the same shape"):
                torch.add(a, b)

        for index_dtype in [torch.int32, torch.int64]:
            run_test(index_dtype)

    @skipCPUIfNoMklSparse
    @skipCUDAIf(
        not _check_cusparse_triangular_solve_available(),
        "cuSparse Generic API SpSV is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sparse_triangular_solve(self, device, dtype):

        def run_test(n, k, upper, unitriangular, transpose, zero):
            if not unitriangular:
                triangle_function = torch.triu if upper else torch.tril
            else:
                # Make sure diagonal elements are not materialized.
                # This is to exercise `unitriangular=True` not relying on
                # explicit presence of these indices.
                if upper:
                    def remove_diagonal(t):
                        return t.triu(1)
                else:
                    def remove_diagonal(t):
                        return t.tril(-1)

                triangle_function = remove_diagonal

            make_A = torch.zeros if zero else make_tensor
            A = make_A((n, n), dtype=dtype, device=device)
            A = triangle_function(A)
            A_sparse = A.to_sparse_csr()
            B = make_tensor((n, k), dtype=dtype, device=device)

            expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose)
            expected_X = expected.solution

            actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            self.assertTrue(actual_A_clone.numel() == 0)
            if A_sparse._nnz() == 0:
                self.assertTrue(actual_X.isnan().all())
                return

            self.assertEqual(actual_X, expected_X)

            # test out with C contiguous strides
            out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)

            # test out with F contiguous strides
            out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), (1, n))

            # test out with discontiguous strides
            out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2]

            self.assertFalse(out.is_contiguous())
            self.assertFalse(out.t().is_contiguous())
            before_stride = out.stride()
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), before_stride)

        ks = [3, 1, 8]
        ns = [5, 3, 0]
        for (k, n), (upper, unitriangular, transpose, zero) in itertools.product(itertools.product(ks, ns),
                                                                                 itertools.product([True, False], repeat=4)):
            run_test(n, k, upper, unitriangular, transpose, zero)

    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sampled_addmm(self, device, dtype):
        def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
            if dtype.is_complex:
                alpha = random.random() + 0.3j if alpha is None else alpha
                beta = random.random() + 0.6j if beta is None else beta
            else:
                alpha = random.random() if alpha is None else alpha
                beta = random.random() if beta is None else beta

            if op_a and a.shape == b.shape:
                a = a.mH
            if op_b and a.shape == b.shape:
                b = b.mH

            actual = torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta)

            out = torch.sparse_csr_tensor(
                *map(torch.clone, (actual.crow_indices(), actual.col_indices())),
                torch.empty_like(actual.values()),
                size=actual.shape
            )
            torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta, out=out)

            spy_c = torch.sparse_csr_tensor(c.crow_indices(), c.col_indices(), torch.ones_like(c.values()), size=c.shape)
            expected = alpha * (a @ b) * spy_c.to_dense() + beta * c.to_dense()
            self.assertEqual(actual.to_dense(), out.to_dense())
            self.assertEqual(actual.to_dense(), expected)

        mnk = list(itertools.product([2, 5], repeat=3))

        # Add a test case for size 0 a and b tensors
        mnk = mnk + [(5, 5, 0)]

        batch_shapes = [(), (2,), (2, 3)]
        tf = [True, False]
        for index_dtype in [torch.int32, torch.int64]:
            for (m, n, k), b, noncontiguous, bcast_c in itertools.product(mnk, batch_shapes, tf, tf):
                if bcast_c and len(b) == 0:
                    continue
                nnz = random.randint(0, m * n)
                c_batch = () if bcast_c else b
                c = self.genSparseCSRTensor((*c_batch, m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a = make_tensor((*b, m, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
                b = make_tensor((*b, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_a, op_b in itertools.product([True, False], repeat=2):
                    run_test(c, a, b, op_a, op_b)
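
    def _demo_sampled_addmm_pattern(self):
        # Illustrative sketch (hypothetical helper, not exercised above):
        # sampled_addmm evaluates alpha * (a @ b) + beta * c only at the stored
        # entries of the CSR input c; other positions stay implicit zeros.
        c = torch.eye(2).to_sparse_csr()
        a = torch.ones(2, 2)
        b = torch.ones(2, 2)
        res = torch.sparse.sampled_addmm(c, a, b, alpha=1.0, beta=0.0)
        # (a @ b) is all twos, but only the diagonal is sampled.
        assert torch.equal(res.to_dense(), 2.0 * torch.eye(2))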

    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sampled_addmm_autograd(self, device, dtype):
        from torch.testing._internal.common_methods_invocations import sample_inputs_sparse_sampled_addmm

        samples = list(sample_inputs_sparse_sampled_addmm(None, device, dtype, requires_grad=True))

        for sample, dense_covector in zip(samples, [True, False]):
            c = sample.input
            a = sample.args[0]
            b = sample.args[1]

            # Compute sparse result
            output = torch.sparse.sampled_addmm(c, a, b, **sample.kwargs)
            covector = torch.randn_like(output).to_dense() if dense_covector else torch.randn_like(output)
            output.backward(covector)

            # Compute dense result and compare with sparse result
            c1, a1, b1 = (x.detach().to_dense().requires_grad_(True) for x in [c, a, b])
            dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1
            self.assertEqual(output, dense_output)
            dense_covector = covector.to_dense()
            dense_output.backward(dense_covector)
            self.assertEqual(c.grad, c1.grad)
            self.assertEqual(a.grad, a1.grad)
            self.assertEqual(b.grad, b1.grad)

    # It works on ROCm while the CUDA issue is still active:
    @skipCUDAIf(not TEST_WITH_ROCM, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177")
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sampled_addmm_zero_sized(self, device, dtype):
        def run_test(c, a, b):
            actual = torch.sparse.sampled_addmm(c, a, b)
            self.assertEqual(actual.shape, c.shape)

        for m, n, k in itertools.product([0, 5], repeat=3):
            c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
            a = make_tensor((m, k), dtype=dtype, device=device)
            b = make_tensor((k, n), dtype=dtype, device=device)
            run_test(c, a, b)

    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sampled_addmm_errors(self, device, dtype):
        # test that the errors are the same for dense and sparse sampled versions

        # shapes must be compatible for matrix multiplication
        a = make_tensor((2, 3), dtype=dtype, device=device)
        a_sparse = a.to_sparse_csr()
        with self.assertRaisesRegex(RuntimeError, r"cannot be multiplied"):
            torch.sparse.sampled_addmm(a_sparse, a, a)

        # mat1 must be a matrix
        with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to be a matrix"):
            torch.sparse.sampled_addmm(a_sparse, a[..., 0, :], a)

        # mat2 must be a matrix
        with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to be a matrix"):
            torch.sparse.sampled_addmm(a_sparse, a, a[..., 0, :])

        a = make_tensor((2, 2), dtype=dtype, device=device)
        b = make_tensor((3, 3), dtype=dtype, device=device)
        b_sparse = b.to_sparse_csr()
        with self.assertRaisesRegex(RuntimeError, r"self.shape\[-2\] must match mat1.shape\[-2\]"):
            torch.sparse.sampled_addmm(b_sparse, a, a)

        b = make_tensor((2, 3), dtype=dtype, device=device)
        b_sparse = b.to_sparse_csr()
        with self.assertRaisesRegex(RuntimeError, r"self.shape\[-1\] must match mat2.shape\[-1\]"):
            torch.sparse.sampled_addmm(b_sparse, a, a)

        a = make_tensor((2, 2), dtype=dtype, device=device)
        a_sparse = a.to_sparse_csr()
        with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to have strided layout"):
            torch.sparse.sampled_addmm(a_sparse, a_sparse, a_sparse)

        with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"):
            torch.sparse.sampled_addmm(a_sparse, a, a_sparse)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    @precisionOverride({torch.bfloat16: 0.01})
    def test_sparse_mm_reduce_sum(self, device, dtype):
        def run_test(m, n, k, nnz, train):
            sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)
            dense = sparse.to_dense()

            mat = torch.randn(k, n, dtype=dtype)
            ref_mat = mat.clone()

            if train:
                sparse.requires_grad_()
                mat.requires_grad_()
                dense.requires_grad_()
                ref_mat.requires_grad_()

            ref_out = torch.mm(dense, ref_mat)
            out = torch.sparse.mm(sparse, mat, 'sum')

            self.assertEqual(out, ref_out)

            if train:
                ref_out.sum().backward()
                out.sum().backward()

                grad_input = sparse.grad
                ref_grad_input = dense.grad
                grad_mat = mat.grad
                ref_grad_mat = ref_mat.grad

                self.assertEqual(grad_input.to_dense(), ref_grad_input)
                self.assertEqual(grad_mat, ref_grad_mat)

        run_test(4, 5, 4, 10, False)
        run_test(4, 4, 4, 16, True)

    @skipIfTorchDynamo()
    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
    def test_sparse_mm_reduce(self, device, dtype):
        def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
            csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            mat = torch.randn(n, k, dtype=dtype)
            ref_mat = mat.clone()
            ref_values = csr.values().clone()

            out_int32 = index_dtype == torch.int32
            coo_indices = torch._convert_indices_from_csr_to_coo(
                csr.crow_indices(),
                csr.col_indices(),
                out_int32=out_int32)
            row, col = coo_indices[0], coo_indices[1]

            def ref(row, col, val, mat):
                out = torch.zeros([m, k], dtype=dtype)
                weight = mat.index_select(0, col)
                src = weight.mul(val.view(-1, 1))
                index = row.view(-1, 1).expand_as(weight)
                index = index.to(dtype=torch.int64)
                # scatter_reduce expects index to be int64
                out.scatter_reduce_(0, index, src, reduce=reduce_type, include_self=False)
                return out

            if train:
                csr.requires_grad_()
                mat.requires_grad_()
                ref_values.requires_grad_()
                ref_mat.requires_grad_()

            ref_out = ref(row, col, ref_values, ref_mat)
            out = torch.sparse.mm(csr, mat, reduce_type)
            self.assertEqual(out, ref_out)

            if train and dtype not in (torch.bfloat16, torch.float16):
                ref_out.sum().backward()
                out.sum().backward()

                grad_values = csr.grad.values()
                grad_weight = mat.grad
                ref_grad_values = ref_values.grad
                ref_grad_weight = ref_mat.grad
                self.assertEqual(grad_values, ref_grad_values)
                self.assertEqual(grad_weight, ref_grad_weight)

        for train in [False, True]:
            for index_dtype in [torch.int32, torch.int64]:
                for reduce_type in ["sum", "mean", "amax", "amin"]:
                    # by setting nnz < m, create empty rows
                    run_test(3, 4, 11, 1, reduce_type, index_dtype, train)
                    run_test(3, 4, 11, 6, reduce_type, index_dtype, train)
                    run_test(3, 4, 11, 12, reduce_type, index_dtype, train)
                    # we are doing blocking with 4x vector length in the kernel,
                    # so need to test when K > 4x vector length
                    run_test(4, 7, 33, 13, reduce_type, index_dtype, train)
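
    def _demo_mm_reduce_amax(self):
        # Illustrative sketch (hypothetical helper, not exercised above): with
        # reduce='amax', each output row takes the elementwise max over the
        # contributions of that row's stored columns instead of their sum.
        csr = torch.tensor([[1., 0.], [1., 1.]]).to_sparse_csr()
        mat = torch.tensor([[1., 2.], [3., 4.]])
        res = torch.sparse.mm(csr, mat, 'amax')
        assert torch.equal(res, torch.tensor([[1., 2.], [3., 4.]]))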

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_coo_csr_conversion(self, device, dtype):
        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            coo_sparse = dense.to_sparse()
            csr_sparse = coo_sparse.to_sparse_csr()

            self.assertEqual(csr_sparse.to_dense(), dense)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_csr_coo_conversion(self, device, dtype):
        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            csr_sparse = dense.to_sparse_csr()
            coo_sparse = csr_sparse.to_sparse()

            self.assertEqual(coo_sparse.to_dense(), dense)
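
    def _demo_coo_csr_round_trip(self):
        # Illustrative sketch (hypothetical helper, not exercised above):
        # COO -> CSR -> COO preserves both the values and the sparsity pattern.
        dense = torch.tensor([[0., 1.], [2., 0.]])
        coo = dense.to_sparse()
        assert torch.equal(coo.to_sparse_csr().to_sparse().to_dense(), dense)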

    # Currently, there is no rule in PyTorch for filling zeros in the outputs
    # from operations on Sparse CSR tensors. Hence only operators with a
    # 0->0 correspondence are supported, e.g. sin(0) = 0 and tan(0) = 0,
    # but not cos(0) = 1.
    # Note: here, we do this test only for unary operators
    @ops(sparse_csr_unary_ufuncs)
    def test_zero_to_zero_correspondence_unary(self, device, dtype, op):
        zero = torch.zeros((1, 2), dtype=dtype, device=device)
        tensor_explicit_zeros = torch.sparse_csr_tensor([0, 1], [1], [0], dtype=dtype, device=device)

        output_zero = op(zero)
        expected_zero = zero.to(output_zero.dtype)

        output_explicit_zeros = op(tensor_explicit_zeros).to_dense()
        expected_explicit_zeros = tensor_explicit_zeros.to_dense().to(output_explicit_zeros.dtype)

        for (output, expected) in [
                (output_zero, expected_zero),
                (output_explicit_zeros, expected_explicit_zeros)
        ]:
            self.assertEqual(output, expected, f"This operator ({op.name}) should not be supported for "
                             "Sparse CSR as it breaks 0->0 correspondence.")

        for inp in [zero.to_sparse_csr(), tensor_explicit_zeros]:
            self.assertEqual(op(inp).values().numel(), inp.values().numel(),
                             f"{op.name} fails to preserve sparsity pattern.")
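
    def _demo_zero_preserving_unary(self):
        # Illustrative sketch (hypothetical helper, not exercised above):
        # zero-preserving unary ops act on values() only, so the sparsity
        # pattern and the implicit zeros stay untouched.
        t = torch.tensor([[0., -1.], [4., 0.]]).to_sparse_csr()
        assert torch.equal(t.sin().to_dense(), t.to_dense().sin())
        assert t.sin().values().numel() == t.values().numel()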

    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_out(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)

        if not op.supports_out:
            self.skipTest("Skipped! Out not supported")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            out = self.genSparseCSRTensor(sample.input.size(), sample.input._nnz(),
                                          device=sample.input.device, dtype=expect.dtype,
                                          index_dtype=sample.input.crow_indices().dtype)
            op(sample.input, *sample.args, **sample.kwargs, out=out)

            self.assertEqual(out, expect)

    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_inplace(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)

        if op.inplace_variant is None:
            self.skipTest("Skipped! Inplace variant not supported!")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            if not torch.can_cast(expect.dtype, dtype):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            if sample.input.is_complex() and op.name == "abs":
                with self.assertRaisesRegex(RuntimeError, "not supported"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)

            self.assertIs(actual, sample.input)
            self.assertEqual(actual, expect)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @ops(sparse_csr_unary_ufuncs, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble])
    def test_autograd_sparse_csr_unary(self, device, dtype, op):
        if op.name not in UNARY_EWISE_CSR_ALLOW_AUTOGRAD:
            self.skipTest(f"Skipped! Unary op {op.name} not supported with CSR input and autograd")

        samples = list(op.sample_inputs(device, dtype))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.input.ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples.")

        for sample in samples:
            # We must skip samples of low dimensionality, we can't convert them to sparse compressed layouts
            if sample.input.ndim < 2:
                continue
            sparse_input = sample.input.to_sparse_csr().requires_grad_(True)

            def fn(input):
                output = op.gradcheck_wrapper(op.get_op(), input, *sample.args, **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            # Compute sparse result
            output = fn(sparse_input)
            covector = torch.randn_like(output)
            output.backward(covector)
            self.assertTrue(torch.is_tensor(sparse_input.grad))
            self.assertTrue(sparse_input.grad.is_sparse_csr)

            # Compute dense result and compare with sparse result
            dense_input = sparse_input.detach().to_dense().requires_grad_(True)
            dense_output = fn(dense_input)
            dense_covector = covector.to_dense()
            dense_output.backward(dense_covector)
            self.assertEqual(sparse_input.grad, dense_input.grad)

    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float64)
    def test_autograd_dense_output_addmm(self, device, dtype):
        from torch.testing._internal.common_methods_invocations import sample_inputs_addmm

        samples = list(sample_inputs_addmm(None, device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.args[0].ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")

        for sample in samples:
            a = sample.args[0].relu().to_sparse_csr()
            if sample.args[0].shape == sample.args[1].shape:
                import warnings
                warnings.warn("Broken for square matrices, see https://github.com/pytorch/pytorch/issues/116565")
                continue

            # This path tests the autograd path wrt dense inputs
            for addmm in [torch.addmm, torch.sparse.addmm]:

                def fn(c, b):
                    output = addmm(c, a, b, **sample.kwargs)
                    if sample.output_process_fn_grad is not None:
                        return sample.output_process_fn_grad(output)
                    return output

                self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True))

                c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
                b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
                self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True))

                # Now test the autograd path wrt sparse inputs
                for reverse in [True, False]:
                    c, b = sample.input, sample.args[1]
                    if reverse and a.shape != b.shape:
                        continue

                    def fn(a):
                        inputs = (c, b, a) if reverse else (c, a, b)
                        output = addmm(*inputs, **sample.kwargs)
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    # gradcheck doesn't work for sparse CSR yet, compare against dense path
                    # Compute sparse result
                    a = a.detach().requires_grad_(True)
                    output = fn(a)
                    covector = torch.randn_like(output)
                    output.backward(covector)
                    self.assertTrue(torch.is_tensor(a.grad))
                    if addmm == torch.sparse.addmm:
                        self.assertTrue(a.grad.is_sparse_csr)
                    else:
                        self.assertTrue(a.grad.layout == torch.strided)

                    # Compute dense result and compare with sparse result
                    dense_a = a.detach().to_dense().requires_grad_(True)
                    dense_output = fn(dense_a)
                    self.assertEqual(output, dense_output)
                    dense_covector = covector.to_dense()
                    dense_output.backward(dense_covector)

                    if addmm == torch.sparse.addmm:
                        self.assertEqual(a.grad, dense_a.grad.sparse_mask(a))
                    else:
                        self.assertEqual(a.grad, dense_a.grad)

    @skipCPUIfNoMklSparse
    @dtypes(torch.float64)
    def test_autograd_dense_output_addmv(self, device, dtype):
        from torch.testing._internal.common_methods_invocations import sample_inputs_addmv

        samples = list(sample_inputs_addmv(None, device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.args[0].ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")

        for sample in samples:
            # TODO: Remove detach once we have autograd support for CSR input
            a = sample.args[0].to_sparse_csr().detach()

            def fn(c, b):
                output = torch.addmv(c, a, b, **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True))

            c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
            b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
            self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True))

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @ops(binary_ops_with_dense_output, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, ])
    def test_autograd_dense_output(self, device, dtype, op):
        if op.name == "mv" and no_mkl_sparse and self.device_type == 'cpu':
            self.skipTest("MKL Sparse is not available")

        samples = list(op.sample_inputs(device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.input.ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples.")

        # Here we assume that the signature is op(sparse_input, dense_input) -> dense_output
        for sample in samples:
            # TODO: Remove detach once we have autograd support for CSR input
            sparse_input = sample.input.to_sparse_csr().detach()

            def fn(*args):
                output = op.gradcheck_wrapper(op.get_op(), sparse_input, *args, **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            self.assertTrue(torch.autograd.gradcheck(fn, sample.args, fast_mode=True))

            args = [make_tensor(a.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) for a in sample.args]
            self.assertTrue(torch.autograd.gradcheck(fn, args, fast_mode=True))
2954

    @dtypes(*all_types_and_complex())
    def test_direct_coo_csr_conversion(self, device, dtype):
        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            coo_sparse = dense.to_sparse_coo()

            self.assertEqual(coo_sparse.to_sparse_csr().to_sparse_coo(), coo_sparse)
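
    # A minimal round-trip sketch of the conversion exercised above (not
    # used by the tests; names are local to this sketch):
    def _sketch_coo_csr_round_trip(self):
        dense = torch.tensor([[0., 1.], [2., 0.]])
        coo = dense.to_sparse_coo()
        # COO -> CSR -> COO preserves both the values and the layout
        assert torch.equal(coo.to_sparse_csr().to_sparse_coo().to_dense(), dense)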

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sum(self, device, dtype):
        def run_test(shape, nnz, index_type):
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_type)
            self.assertEqual(a.sum(), a.values().sum())
            if dtype in floating_types():
                a.requires_grad_(True)
                a.sum().backward()
                self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device))
        for shape, index_dtype in itertools.product(
                [(10, 5), (10, 10)],
                [torch.int32, torch.int64]):
            run_test(shape, 0, index_dtype)
            run_test(shape, max(shape), index_dtype)
            run_test(shape, shape[0] * shape[1], index_dtype)
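
    # Why `a.sum()` can be compared with `a.values().sum()` above: elements
    # that are not stored are zero and contribute nothing to the total, so
    # summing the stored values is exact. Minimal sketch (names local):
    def _sketch_csr_sum(self):
        a = torch.tensor([[1., 0.], [0., 2.]]).to_sparse_csr()
        assert float(a.sum()) == float(a.values().sum()) == 3.0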

    @skipIfTorchDynamo()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    @all_sparse_compressed_layouts()
    def test_transpose(self, device, dtype, layout):

        def _check_transpose_view(subject, transpose):
            self.assertTrue(transpose.values()._is_view())
            self.assertTrue(transpose._is_view())
            self.assertTrue(transpose._base is subject)

        def _check_layout_invariants(transpose):
            self.assertEqual(transpose.device, torch.device(device))
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[transpose.layout]
            compressed_indices, plain_indices = compressed_indices_mth(transpose), plain_indices_mth(transpose)
            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, transpose.values(),
                                                          transpose.shape, transpose.layout)

        def check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout):
            transpose = subject.transpose(dim0, dim1)
            self.assertEqual(transpose.layout, expected_layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense.transpose(dim0, dim1))

            round_trip = transpose.transpose(dim0, dim1)
            self.assertEqual(round_trip.layout, subject.layout)
            # transpose must return a view
            _check_transpose_view(subject, round_trip)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(round_trip)
            self.assertEqual(round_trip.to_dense(), subject_dense)

        def check_same_dim_transpose(subject, subject_dense, dim):
            transpose = subject.transpose(dim, dim)
            self.assertEqual(transpose.layout, subject.layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense)

        def check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1):
            mismatch_name = f"{dim0}\\({name0}\\) and {dim1}\\({name1}\\)"
            err = r"transpose\(\): can only transpose dimensions of the same type \(Batch, Sparse, Dense\), got " + mismatch_name

            with self.assertRaisesRegex(RuntimeError, err):
                subject.transpose(dim0, dim1)

        def run_test(shape, nnz, index_type, n_dense, blocksize=()):
            subject = self.genSparseCompressedTensor(shape,
                                                     nnz,
                                                     layout=layout,
                                                     device=device,
                                                     index_dtype=index_type,
                                                     blocksize=blocksize,
                                                     dense_dims=n_dense,
                                                     dtype=dtype)

            sparse0 = len(shape) - n_dense - 1
            sparse1 = sparse0 - 1

            dense0 = sparse0 + 1 if n_dense > 0 else None
            dense1 = dense0 + 1 if n_dense > 1 else None

            n_batch = len(shape) - n_dense - 2
            batch0 = sparse1 - 1 if n_batch > 0 else None
            batch1 = 0 if n_batch > 1 else None

            sparse_dims = (sparse0, sparse1)
            dense_dims = (dense0, dense1)
            batch_dims = (batch0, batch1)

            named0 = [(name, d[0]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]
            named1 = [(name, d[1]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]

            flipped_layout = {
                torch.sparse_csr: torch.sparse_csc,
                torch.sparse_csc: torch.sparse_csr,
                torch.sparse_bsr: torch.sparse_bsc,
                torch.sparse_bsc: torch.sparse_bsr
            }[layout]

            if n_dense > 0:
                # expect all transposes to throw
                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
                    msg = r"transpose\(\): hybrid sparse compressed tensors with dense dimensions are not supported"
                    if (dim0 is not None) and (dim1 is not None):
                        with self.assertRaisesRegex(RuntimeError, msg):
                            subject.transpose(dim0, dim1)
            else:
                subject_dense = subject.to_dense()
                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
                    if dim0 is not None:
                        check_same_dim_transpose(subject, subject_dense, dim0)

                        if dim1 is not None:
                            if name0 == name1:
                                expected_layout = flipped_layout if name0 == "Sparse" else layout
                                check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout)
                            else:
                                check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1)

        # batch/sparse, sparse/dense only and full hybrid cases
        shape_ndense = list(itertools.product([(2, 4, 6, 2), (10, 6, 4, 2), (2, 4, 4, 2, 6)], [0, 1, 2]))
        # sparse-only cases
        shape_ndense += [[(4, 8), 0], [(2, 2), 0], [(8, 4), 0]]
        for (shape, n_dense), index_dtype in itertools.product(shape_ndense, [torch.int32, torch.int64]):
            n_batch = len(shape) - n_dense - 2
            sparse_shape = shape[n_batch: n_batch + 2]
            if layout in (torch.sparse_bsr, torch.sparse_bsc):
                # for blocked layouts, all combinations of 2 and 1 should be valid block sizes
                run_test(shape, 0, index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense, blocksize=(2, 2))
                # repeat the realistic sparsity case with varied block sizes
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 1))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 1))
            else:
                run_test(shape, 0, index_dtype, n_dense)
                run_test(shape, max(sparse_shape), index_dtype, n_dense)
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense)
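
    # Sketch of the layout flip verified above (not used by the tests):
    # transposing the two sparse dimensions of a CSR tensor yields a CSC
    # view of the same data, and vice versa; block layouts flip BSR <-> BSC.
    def _sketch_transpose_flips_layout(self):
        a = torch.eye(2).to_sparse_csr()
        t = a.transpose(0, 1)
        assert t.layout == torch.sparse_csc
        assert torch.equal(t.to_dense(), torch.eye(2))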

    # TODO: This is a stopgap for a rigorous extension of our autograd tests
    # to test the functionality of detach
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_exercise_detach(self, device, dtype):
        shape = (3, 3)
        nnz = 9
        for index_dtype in [torch.int32, torch.int64]:
            inp = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            detached_inp = inp.detach()
            self.assertEqual(inp, detached_inp)

    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
        if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
            tensor = tensor.to_dense()
        else:
            raise NotImplementedError(repr(tensor))
        if layout is torch.sparse_csr:
            return sp.csr_matrix(tensor.cpu().numpy())
        if layout is torch.sparse_csc:
            return sp.csc_matrix(tensor.cpu().numpy())
        if layout is torch.sparse_bsr:
            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
        if layout is torch.sparse_bsc:
            # SciPy doesn't have native BSC support - but our tests don't need the full
            # functionality, so fake it by using a transposed BSR matrix.
            class FakeBscMatrix:
                def __init__(self, matrix):
                    self._matrix = matrix
                    self.shape = tuple(reversed(matrix.shape))
                    self.indptr = matrix.indptr
                    self.indices = matrix.indices
                    self.data = [x.transpose() for x in matrix.data]

                @staticmethod
                def from_matrix(matrix, blocksize):
                    blocksize = tuple(reversed(blocksize))
                    matrix = matrix.transpose()
                    return FakeBscMatrix(sp.bsr_matrix(matrix, blocksize=blocksize))

                def sorted_indices(self):
                    sub = self._matrix.sorted_indices()
                    return FakeBscMatrix(sub)

            return FakeBscMatrix.from_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
        raise NotImplementedError(repr(tensor))
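
    # The comparisons built on `_construct_sp_matrix` rely on PyTorch's
    # CSR/CSC sharing SciPy's compressed format: `crow_indices()` and
    # `col_indices()` line up with `indptr` and `indices`. Minimal sketch
    # (requires SciPy; names are local to this sketch):
    def _sketch_csr_matches_scipy(self):
        dense = torch.tensor([[0., 1.], [2., 3.]])
        pt = dense.to_sparse_csr()
        sp_m = sp.csr_matrix(dense.numpy())
        assert pt.crow_indices().tolist() == sp_m.indptr.tolist()
        assert pt.col_indices().tolist() == sp_m.indices.tolist()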

    @all_sparse_compressed_layouts('to_layout')
    @all_sparse_compressed_layouts('from_layout')
    def test_compressed_layout_conversions_coverage(self, device, from_layout, to_layout):
        """This test performs a smoke test for covered conversions and verifies
        that an exception is thrown for unsupported conversions.

        TODO: This test covers a subset of
        TestSparseAny.test_to_sparse tests and can be
        eliminated. Keeping the test until the new
        `Tensor.to_sparse(*, layout, blocksize)` has landed.
        """
        allowed_pairwise_layouts_sets = {
            frozenset({torch.sparse_csc}),
            frozenset({torch.sparse_csr}),
            frozenset({torch.sparse_csc, torch.sparse_csr}),
            frozenset({torch.sparse_csc, torch.sparse_bsc}),
            frozenset({torch.sparse_csc, torch.sparse_bsr}),
            frozenset({torch.sparse_csr, torch.sparse_bsc}),
            frozenset({torch.sparse_csr, torch.sparse_bsr}),
            frozenset({torch.sparse_bsc}),
            frozenset({torch.sparse_bsr}),
            frozenset({torch.sparse_bsc, torch.sparse_bsr}),
        }
        block_layouts = (torch.sparse_bsr, torch.sparse_bsc)

        def _to_from_layout(layout_a, layout_b, a):
            expect_error = True
            if {layout_a, layout_b} in allowed_pairwise_layouts_sets:
                expect_error = False

            # BSR -> CSR is not yet supported
            if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csr):
                expect_error = True
            # BSR -> CSC is not yet supported
            if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csc):
                expect_error = True
            # BSC -> CSR is not yet supported
            if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csr):
                expect_error = True
            # BSC -> CSC is not yet supported
            if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csc):
                expect_error = True
            # CSR -> BSR only works for non-batched inputs
            if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsr):
                if a.dim() > 2:
                    expect_error = True
            # CSR -> BSC only works for non-batched inputs
            if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsc):
                if a.dim() > 2:
                    expect_error = True
            # CSC -> BSR only works for non-batched inputs
            if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsr):
                if a.dim() > 2:
                    expect_error = True
            # CSC -> BSC only works for non-batched inputs
            if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsc):
                if a.dim() > 2:
                    expect_error = True

            blocksize_a = (1, 1) if layout_a in {torch.sparse_bsr, torch.sparse_bsc} else None
            blocksize_b = (1, 1) if layout_b in {torch.sparse_bsr, torch.sparse_bsc} else None
            b = a.to_sparse(layout=layout_a, blocksize=blocksize_a)
            if expect_error:
                with self.assertRaises(RuntimeError):
                    b.to_sparse(layout=layout_b, blocksize=blocksize_b)
            else:
                c = b.to_sparse(layout=layout_b, blocksize=blocksize_b)
                self.assertEqual(a.to_dense(), c.to_dense())

                # change of blocksize upon conversion is not yet supported.
                if b.layout in block_layouts:
                    for block_layout in block_layouts:
                        with self.assertRaisesRegex(RuntimeError,
                                                    "conversion from.*to.*with blocksize changed from.*to.*is not supported"):
                            b.to_sparse(layout=block_layout, blocksize=(3, 3))

        batch_dims = [(), (2,), (2, 2), (2, 2, 2)]
        sparse_dims = (6, 12)
        for batch_dim in batch_dims:
            a = make_tensor(batch_dim + sparse_dims, dtype=torch.float, device=device)
            _to_from_layout(from_layout, to_layout, a)
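
    # The conversion API exercised above, in one place: `Tensor.to_sparse`
    # takes a target `layout` and, for blocked layouts, a `blocksize`.
    # Minimal sketch of an allowed pair (CSR -> BSR, non-batched input;
    # names are local to this sketch):
    def _sketch_layout_conversion(self):
        a = torch.eye(4)
        csr = a.to_sparse(layout=torch.sparse_csr)
        bsr = csr.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 2))
        assert bsr.layout == torch.sparse_bsr
        assert torch.equal(bsr.to_dense(), a)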

    @all_sparse_compressed_layouts()
    @batched_nonbatched()
    @hybrid_nonhybrid()
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout):
        """Tests conversion from dense to/from CSR and CSC
        by comparing to SciPy's implementation.

        Here we test only those conversion combinations that SciPy
        supports, to ensure that PyTorch conversions are on the same
        page with SciPy. Independently of SciPy, all conversion
        combinations are tested in TestSparseAny.test_to_sparse.
        """
        blocked_layouts = (torch.sparse_bsr, torch.sparse_bsc)

        def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs):
            # SciPy has no BSC layout, so we check against the BSR layout of the transposed dense
            if layout == torch.sparse_bsc:
                sp_matrix = self._construct_sp_matrix(dense.t(), layout=torch.sparse_bsr, blocksize=blocksize[::-1])
            else:
                sp_matrix = self._construct_sp_matrix(dense, layout=layout, blocksize=blocksize)

            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]

            self.assertEqual(layout, pt_matrix.layout)
            if layout == torch.sparse_bsc:
                self.assertEqual(sp_matrix.shape[::-1], pt_matrix.shape)
            else:
                self.assertEqual(sp_matrix.shape, pt_matrix.shape)

            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            if layout == torch.sparse_bsc:
                # we must transpose the blocks before comparing
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values().transpose(-2, -1))
            else:
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

        def _check_hybrid_matrix(pt_matrix, dense, blocksize, **kwargs):
            # Calculate COO indices for sparse matrix.
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
            compressed_indices = compressed_indices_mth(pt_matrix)
            plain_indices = plain_indices_mth(pt_matrix)
            coo_indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
            row_indices, col_indices = {
                torch.sparse_csr: (coo_indices[0, ], coo_indices[1, ]),
                torch.sparse_csc: (coo_indices[1, ], coo_indices[0, ]),
                torch.sparse_bsr: (coo_indices[0, ], coo_indices[1, ]),
                torch.sparse_bsc: (coo_indices[1, ], coo_indices[0, ]),
            }[layout]

            # If the sparse matrix layout is blocked, rearrange the dense
            # matrix so that the shape past the first two dimensions matches
            # the shape of the sparse matrix values.
            dense_to_check = dense
            if blocksize:
                dense_shape = dense.shape
                dense_to_check_shape = (dense.shape[0] // blocksize[0],
                                        blocksize[0],
                                        dense.shape[1] // blocksize[1],
                                        blocksize[1]) + dense.shape[2:]
                dense_to_check = dense_to_check.reshape(dense_to_check_shape).transpose(1, 2)

            # Verify that non-zero values of the sparse matrix are
            # equal to corresponding values of the dense matrix.
            self.assertEqual(pt_matrix.values(), dense_to_check[row_indices, col_indices])

            # Verify that the remaining elements of the dense matrix
            # are 0, i.e. that the dense and sparse matrices are fully
            # equal.
            mask = torch.ones_like(dense_to_check, dtype=torch.bool)
            mask[row_indices, col_indices] = False
            self.assertTrue(torch.all(torch.masked_select(dense_to_check, mask) == 0))

        def _check_batched(pt_tensor, dense, check_batch=None, batch_shape=(), blocksize=(), **kwargs):
            self.assertEqual(layout, pt_tensor.layout)
            self.assertEqual(pt_tensor.shape, dense.shape)
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
            for batch_index in np.ndindex(batch_shape):
                pt_matrix = pt_tensor[batch_index]
                dense_matrix = dense[batch_index]
                dense_dim = pt_matrix.dim() - 2
                dense_matrix_pt = dense_matrix.to_sparse(layout=layout,
                                                         blocksize=blocksize or None,
                                                         dense_dim=dense_dim)
                # sanity check: selecting a batch of the converted tensor and converting dense[batch] should agree
                self.assertEqual(pt_matrix, dense_matrix_pt)
                check_batch(pt_matrix, dense_matrix, blocksize, **kwargs)

        def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
            shape = batch_shape + sparse_shape + hybrid_shape
            n_batch_dim = len(batch_shape)
            n_hybrid_dim = len(hybrid_shape)
            # generate a dense tensor
            dense = make_tensor(shape, dtype=torch.float, device=device)

            # introduce some sparsity; the mask has the sparse shape, each element applies
            # to an entire dense sub-tensor (hybrid) and is applied to each batch
            mask = make_tensor(sparse_shape, dtype=torch.bool, device=device)
            # manually expand to match hybrid shape
            if hybrid:
                mask = mask.view(sparse_shape + tuple(1 for _ in range(n_hybrid_dim)))
                mask = mask.expand(sparse_shape + hybrid_shape)

            # mask will broadcast over the batch dims if present
            return dense * mask

        # note: order is important here, the hybrid-ness decides the inner content check which is used to build the
        # batched checker (if needed)
        check_content = _check_against_scipy_matrix
        if hybrid:
            check_content = _check_hybrid_matrix
        if batched:
            check_content = functools.partial(_check_batched, check_batch=check_content)

        sparse_sizes = [(6, 10), (0, 10), (6, 0), (0, 0)]
        blocksizes = [(2, 2), (1, 1), (1, 2)] if layout in blocked_layouts else [()]
        batch_sizes = [(3,), (1, 3), (2, 1, 3)] if batched else [()]
        hybrid_sizes = [(4, ), (2, 2)] if hybrid else [()]

        # general cases, always run
        for sparse_shape, blocksize, batch_shape, hybrid_shape in itertools.product(
                sparse_sizes, blocksizes, batch_sizes, hybrid_sizes):
            dense = _generate_subject(sparse_shape, batch_shape, hybrid_shape)
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)
            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

        # special cases for batched tensors
        if batched:
            # batched sparse tensors need only have the same number of non-zeros in each batch, not necessarily the
            # same sparsity pattern in each batch
            sparse_shape = sparse_sizes[0]
            hybrid_shape = hybrid_sizes[0]
            batch_shape = batch_sizes[0]
            shape = batch_shape + sparse_shape + hybrid_shape
            dense = make_tensor(shape, dtype=torch.float, device=device)
            blocksize = blocksizes[0]
            # number of elements/blocks in each batch (total, not nnz)
            batch_mask_shape = sparse_shape
            if layout in blocked_layouts:
                # if we are blocked, the mask is generated for the block-valued elements
                batch_mask_shape = sparse_shape[0] // blocksize[0], sparse_shape[1] // blocksize[1]

            # random bool vector w/ length equal to max possible nnz for the sparse_shape
            mask_source = make_tensor(batch_mask_shape, dtype=torch.bool, device=device).flatten()
            n_batch = functools.reduce(operator.mul, batch_shape, 1)

            # stack random permutations of the source for each batch
            mask = torch.stack([mask_source[torch.randperm(mask_source.numel())]
                               for _ in range(n_batch)], dim=0).reshape(batch_shape + batch_mask_shape)
            if layout in blocked_layouts:
                # for blocked layouts we need a bit of extra work to expand the mask from blocked-space to element-space
                mask_shape = mask.shape
                mask = mask.view(mask_shape + (1, 1))
                mask = mask.expand(mask_shape + blocksize)
                mask = mask.transpose(-3, -2)
                mask = mask.flatten(-4, -3).flatten(-2, -1)
            mask_shape = mask.shape
            mask = mask.view(mask_shape + (1,) * len(hybrid_shape))
            mask = mask.expand(mask_shape + hybrid_shape)
            dense = dense * mask
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

            # if batches have different nnz we expect the conversion to throw
            mask_0 = mask[0]
            mask_1 = mask[0].clone().fill_(True)
            mask_2 = mask[0].clone().fill_(False)
            mask_true = mask_source.clone().fill_(True)
            mask_false = mask_source.clone().fill_(False)
            mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
            dense = make_tensor(shape, dtype=torch.float, device=device)
            dense = dense * mask
            msg = "Expect the same number of specified elements per batch."
            with self.assertRaisesRegex(RuntimeError, msg):
                dense.to_sparse(layout=layout, blocksize=blocksize or None)

            # Should throw if there is a zero in the batch size
            dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
            layout_code = str(layout).split("_")[-1]
            msg = f"to_sparse_{layout_code}: Expected product of batch dimensions to be non-zero."
            with self.assertRaisesRegex(RuntimeError, msg):
                dense.to_sparse(layout=layout, blocksize=blocksize or None)
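
    # `_check_hybrid_matrix` above recovers per-element COO coordinates
    # from the compressed indices. The same expansion can be written with
    # public ops only; a minimal CSR sketch (names local to this sketch):
    def _sketch_crow_to_coo(self):
        a = torch.tensor([[1., 0.], [2., 3.]]).to_sparse_csr()
        crow, col = a.crow_indices(), a.col_indices()
        counts = crow[1:] - crow[:-1]  # nnz per row
        row = torch.repeat_interleave(torch.arange(len(counts)), counts)
        assert torch.equal(a.to_dense()[row, col], a.values())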

    @all_sparse_compressed_layouts()
    @coalescedonoff
    @dtypes(torch.double)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
        """
        Tests conversion from COO to CSR/CSC, and from CSC to CSR/CSC,
        by comparing to SciPy's implementation.

        Here we test only those conversion combinations that SciPy
        supports, to ensure that PyTorch conversions are on the same
        page with SciPy. Independently of SciPy, all conversion
        combinations are tested in TestSparseAny.test_to_sparse.
        """
        blocksize_kw = {}
        if layout in (torch.sparse_bsc, torch.sparse_bsr):
            blocksize_kw['blocksize'] = (2, 2)
            # block modes don't support 0 width/height
            shapes = [(6, 10)]
        elif layout in (torch.sparse_csc, torch.sparse_csr):
            shapes = [(0, 10), (6, 0), (6, 10), (0, 0)]
        else:
            raise NotImplementedError("unhandled layout")

        if layout in (torch.sparse_bsc, torch.sparse_csc):
            compressed_indices_mth = torch.Tensor.ccol_indices
            plain_indices_mth = torch.Tensor.row_indices
        elif layout in (torch.sparse_bsr, torch.sparse_csr):
            compressed_indices_mth = torch.Tensor.crow_indices
            plain_indices_mth = torch.Tensor.col_indices
        else:
            raise NotImplementedError("unhandled layout")

        for shape in shapes:
            sparse_dim = 2
            nnz = shape[0] * shape[1] // 2
            sparse, _, _ = self.genSparseTensor(shape, sparse_dim, nnz, coalesced, device, dtype)
            sp_matrix = self._construct_sp_matrix(sparse, layout)
            pt_matrix = sparse.to_sparse(layout=layout, **blocksize_kw)

            self.assertEqual(layout, pt_matrix.layout)
            self.assertEqual(sp_matrix.shape, pt_matrix.shape)
            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

            sparse_csc = sparse.to_sparse_csc()
            sp_matrix = self._construct_sp_matrix(sparse_csc, layout)
            pt_matrix = sparse_csc.to_sparse(layout=layout, **blocksize_kw)

            self.assertEqual(layout, pt_matrix.layout)
            self.assertEqual(sp_matrix.shape, pt_matrix.shape)
            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

    @unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss")
    @dtypes(*floating_types())
    def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
        # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py

        try:
            spd = torch.rand(4, 3)
            A = spd.T @ spd
            b = torch.rand(3).cuda()
            A = A.to_sparse_csr().cuda()
            x = torch.sparse.spsolve(A, b)
        except RuntimeError as e:
            if "Calling linear solver with sparse tensors requires compiling " in str(e):
                self.skipTest("PyTorch was not built with cuDSS support")

        samples = sample_inputs_linalg_solve(None, device, dtype)

        for sample in samples:
            if sample.input.ndim != 2:
                continue

            out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
            if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
                with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"):
                    out = torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
                continue
            if not sample.args[0].numel():
                with self.assertRaisesRegex(RuntimeError,
                                            "Expected non-empty other tensor, but found empty tensor"):
                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
                continue

            expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
            sample.input = sample.input.to_sparse_csr()
            if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1:
                expect = expect.squeeze(-1)
                sample.args = (sample.args[0].squeeze(-1), )
            out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
            self.assertEqual(expect, out)
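
    # Minimal sketch of the solver path probed above (assumes CUDA and a
    # cuDSS-enabled build; `A` must be sparse CSR and `b` a 1D tensor):
    def _sketch_spsolve(self):
        A = torch.eye(3, device='cuda').to_sparse_csr()
        b = torch.ones(3, device='cuda')
        x = torch.sparse.spsolve(A, b)
        assert torch.allclose(x, b)  # identity system, so x == b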

def skipIfNoTriton(cls):
    from torch.utils._triton import has_triton

    # no-op if triton is present
    if has_triton():
        return cls
    else:

        @functools.wraps(cls, updated=())
        class skipped_cls(cls):
            def setUp(self):
                self.skipTest("Triton is not available.")

        return skipped_cls


@skipIfNoTriton
class TestSparseCompressedTritonKernels(TestCase):

    def _to_block_triangular_inplace(self, d, row_block, col_block):
        """
        This function modifies `d` to become (upper/lower) block-triangular in-place.
        It is assumed that `d.shape[-2]` is divisible by `row_block` and
        `d.shape[-1]` is divisible by `col_block`.
        """

        from torch.sparse._triton_ops import tile_to_blocksize

        m, n = d.shape[-2:]
        d_tiled = tile_to_blocksize(d, (row_block, col_block))
        d_tiled = d_tiled.moveaxis(-4, -1).moveaxis(-4, -1)
        if m // row_block > n // col_block:
            d_tiled.tril_()
        else:
            d_tiled.triu_()

        return d
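
    # The tiling trick above, spelled out with public ops only: view the
    # matrix as a grid of blocks, zero out whole blocks on one side of the
    # block diagonal, and view it back. Minimal square sketch with 2x2
    # blocks (an assumption of this sketch; not used by the tests):
    def _sketch_block_triangular(self):
        d = torch.ones(4, 4)
        blocks = d.reshape(2, 2, 2, 2).permute(0, 2, 1, 3)  # (brow, bcol, 2, 2)
        keep = torch.tril(torch.ones(2, 2, dtype=torch.bool))
        blocks = blocks * keep[:, :, None, None]
        d = blocks.permute(0, 2, 1, 3).reshape(4, 4)
        assert torch.equal(d[:2, 2:], torch.zeros(2, 2))  # upper-right block is zeroed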

    @skipIfRocm(msg="test is too slow on ROCm stack")
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_softmax(self, device, dtype):
        from functools import partial
        from torch.sparse._triton_ops import bsr_softmax

        tensor = partial(make_tensor, device=device, dtype=dtype, low=1.0, high=3.0)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        size = [6, 12, 0]
        block_size = [2, 3]

        # General correctness
        for row_block, col_block, b, m, n in itertools.product(block_size, block_size, batches, size, size):
            input = tensor(b + (m, n))
            input.diagonal(dim1=-2, dim2=-1).fill_(m * n)
            input = self._to_block_triangular_inplace(input, row_block, col_block)

            bsr = input.to_sparse_bsr((row_block, col_block))
            coo = input.to_sparse().to(torch.float)

            res_tri = bsr_softmax(bsr)
            res_coo = torch.sparse.softmax(coo, -1)
            self.assertEqual(res_tri, res_coo.to(input.dtype))

        # Test long rows, which exceed Triton's max numel limit set to 2 ** 17
        input = tensor(b + (1, 150000))
        bsr = input.to_sparse_bsr(1)
        self.assertEqual(input.softmax(-1), bsr_softmax(bsr))
@parametrize("block_size", [16, 32, 64])
3604
@parametrize("index_dtype", [torch.int32, torch.int64])
3606
@dtypes(torch.half, torch.bfloat16, torch.float)
3607
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3608
@unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
3609
"Skipped for deploy and internal with remote GPUs")
3610
def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
3611
from functools import partial
3612
from torch.sparse._triton_ops import bsr_dense_mm
3614
def kernel_impl(*args, **kwargs):
3615
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
3617
kernel = torch._TritonLibrary.registerOp(
3618
"_triton_bsr_dense_mm_out",
3619
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
3624
# kernel != kernel_impl means dispatch was already registered.
3625
# This is exactly what we need!
3626
self.assertTrue(kernel is not kernel_impl)
3628
# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3629
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
3631
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3632
batches = [(), (2,), (2, 2)]
3633
size = [128, 256, 0]
3635
# Whether to make inputs orthogonal so that the product is zero
3636
make_orthogonal = [True, False]
3638
for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
3639
bsr = tensor(bs + (m, k))
3640
# NOTE: do not get confused, it will be transposed
3641
dense = tensor(bd + (n, k))
3644
bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
3645
dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
3647
bsr = bsr.to_sparse_bsr(block_size)
3649
if bsr.dim() == 2 and dtype != torch.float:
3650
# Test against linear to check dispatch
3651
# which takes place for torch.half and torch.bfloat16.
3652
res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
3653
res_tri_out = torch.empty_like(res_dense)
3654
res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)
3656
# Check dispatch worked with non-trivial outputs
3657
if m > 0 and n > 0 and k > 0:
3658
self.assertTrue(kernel.kernel_invoked)
3659
kernel.kernel_invoked = False
3661
# Otherwise check correctness against bmm
3662
# since nn.linear does not support bsr.dim() > 2.
3663
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
3664
res_tri_out = torch.empty_like(res_dense)
3665
res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)
3667
self.assertTrue(res_tri is res_tri_out)
3668
self.assertEqual(res_tri, res_dense)
3670
res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
3671
# check whether bsr_dense_mm handles different grid sizes
3672
# None means max possible grid size which is CUDA-dependent.
3673
grid_size = (None, 2, 4)
3674
grid_gen = itertools.product(grid_size, repeat=3)
3675
for grid in grid_gen:
3676
res_tri = torch.sparse._triton_ops.bsr_dense_mm(
3678
dense.transpose(-2, -1),
3681
self.assertEqual(res_tri, res_dense)
3685

    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
                     "Skipped for deploy and internal with remote GPUs")
    def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
        from torch.sparse._triton_ops import bsr_dense_mm

        rhs = torch.rand(32, 32, dtype=dtype, device=device)
        lhs = rhs.to_sparse_bsr(16)
        with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"):
            bsr_dense_mm(lhs.to_sparse_bsc(16), rhs)
        with self.assertRaisesRegex(ValueError, "on the same GPU device"):
            bsr_dense_mm(lhs, rhs.cpu())
        if torch.cuda.device_count() > 1:
            with self.assertRaisesRegex(ValueError, "on the same GPU device"):
                bsr_dense_mm(lhs.to("cuda:0"), rhs.to("cuda:1"))
        with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"):
            bsr_dense_mm(lhs, rhs.to(torch.float))
        with self.assertRaisesRegex(ValueError, r"and one of \(half, bfloat16, float32\)"):
            bsr_dense_mm(lhs.to(torch.double), rhs.to(torch.double))
        with self.assertRaisesRegex(ValueError, "all inputs involved in the matrix product are expected to be at least 2D"):
            bsr_dense_mm(lhs, torch.rand(1, dtype=dtype, device=device))
        with self.assertRaisesRegex(ValueError,
                                    "sizes involved in the matrix product are not compatible for matrix multiplication"):
            bsr_dense_mm(lhs, torch.rand(1, 1, dtype=dtype, device=device))
        with self.assertRaisesRegex(ValueError,
                                    r"dense.size\(-1\) == 15 should be divisible by 16"):
            bsr_dense_mm(lhs, torch.rand(32, 15, dtype=dtype, device=device))

        for blocksize in (15, 30):
            n = blocksize * 2
            rhs = torch.rand(n, n, dtype=dtype, device=device)
            lhs = rhs.to_sparse_bsr(blocksize)
            with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"):
                bsr_dense_mm(lhs, rhs)

        rhs = torch.rand(2, 32, 32, dtype=dtype, device=device)
        lhs = rhs.to_sparse_bsr(16)
        with self.assertRaisesRegex(ValueError, r"`out` argument has wrong shape"):
            out = torch.rand(2, 30, 30, dtype=dtype, device=device)
            bsr_dense_mm(lhs, rhs, out=out)
        with self.assertRaisesRegex(ValueError, r"only row-major/col-major `out`"):
            out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1)
            bsr_dense_mm(lhs, rhs, out=out)
@parametrize("block_size", [16, 32, 64])
3731
@dtypes(torch.half, torch.bfloat16, torch.float)
3732
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3733
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3734
@precisionOverride({torch.float16: 1e-3})
3735
def test_triton_scaled_dot_product_attention(self, device, dtype, block_size):
3736
from functools import partial
3737
from torch.sparse._triton_ops import _scaled_dot_product_attention
3739
# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3740
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)
3742
def broadcast_input(*ts):
3743
batch_dims = torch.broadcast_shapes(*(t.shape[:-2] for t in ts))
3744
yield from (torch.broadcast_to(t, batch_dims + t.shape[-2:]) for t in ts)
3746
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3747
batches = [(), (2,), (2, 2)]
3748
size = [128, 256, 0]
3750
for bam, bq, bk, bv, m, n, k in itertools.product(batches, batches, batches, batches, size, size, size):
3751
query = tensor(bq + (m, k))
3752
key = tensor(bk + (n, k))
3753
value = tensor(bv + (n, k))
3755
# We make attn_mask block lower/upper triangular so that BSR and Strided
3756
# function variants are directly comparable.
3757
attn_mask = torch.ones(bam + (m, n), device=device, dtype=torch.bool)
3758
attn_mask = self._to_block_triangular_inplace(attn_mask, block_size, block_size)
3759
attn_mask_bsr = attn_mask.to_sparse_bsr(block_size)
3761
# NOTE: only boolean mask is directly compatible with the Strided version
3762
# without any pre-/post-processing. Hence we test against a boolean mask.
3763
for scale in (None, 1. / 16):
3764
if scale is None and query.size(-1) == 0:
3766
expected = torch.nn.functional.scaled_dot_product_attention(
3767
*broadcast_input(query, key, value, attn_mask), scale=scale
3770
for mask_dtype in (torch.bool, dtype):
3771
res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr.to(mask_dtype), scale=scale)
3772
self.assertEqual(res, expected)
3775
@parametrize("block_size", [16, 32, 64])
3778
@dtypes(torch.half, torch.bfloat16, torch.float)
3779
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3780
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3781
def test_triton_sampled_addmm(self, device, dtype, block_size):
3782
from functools import partial
3783
from torch.sparse._triton_ops import sampled_addmm, broadcast_batch_dims_bsr
3785
# Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3786
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)
3788
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3789
batches = [(), (2,), (2, 2)]
3790
size = [128, 256, 0]
3793
for bi, bm1, bm2, m, n, k, dk in itertools.product(batches, batches, batches, size, size, size, delta_k):
3794
# Test not powers of 2 ks as well.
3796
# Non-trivial sparsity pattern.
3797
# Plus with tril inputs the result is also tril,
3798
# so we can compare BSR and CSR implementations.
3799
input = tensor(bi + (m, n)).tril_()
3800
bsr = input.to_sparse_bsr(block_size)
3801
mat1 = tensor(bm1 + (m, k)).tril_()
3802
mat2 = tensor(bm2 + (k, n)).tril_()
3804
batch_dim = torch.broadcast_shapes(input.shape[:-2], mat1.shape[:-2], mat2.shape[:-2])
3806
csr = input.broadcast_to(batch_dim + input.shape[-2:]).to_sparse_csr().to(torch.float)
3807
mat1csr = mat1.broadcast_to(batch_dim + mat1.shape[-2:]).to(torch.float)
3808
mat2csr = mat2.broadcast_to(batch_dim + mat2.shape[-2:]).to(torch.float)
3810
input_broadcasted_clone = broadcast_batch_dims_bsr(
3811
"test_triton_sampled_addmm",
3814
input_broadcasted_clone = torch.sparse_compressed_tensor(
3815
input_broadcasted_clone.crow_indices(),
3816
input_broadcasted_clone.col_indices(),
3817
# For testing `out=` let's make values to have "weird" strides
3818
# so that if the kernel modifies values to it's needs, the result
3819
# is being compied into out.values.
3820
input_broadcasted_clone.values().transpose(-3, -2).contiguous().transpose(-3, -2),
3821
layout=input_broadcasted_clone.layout,
3822
size=input_broadcasted_clone.shape
3825
scalars = (0.0, 2.0)
3826
for alpha, beta, out in itertools.product(scalars, scalars, (None, input_broadcasted_clone)):
3827
res_tri = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, out=out)
3829
self.assertTrue(res_tri is out)
3831
batch_broadcasted_shape = torch.broadcast_shapes(*(t.shape[:-2] for t in (input, mat1, mat2)))
3832
self.assertTrue(res_tri.shape == batch_broadcasted_shape + (m, n))
3834
res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta).to(input.dtype)
3835
self.assertEqual(res_tri.to_dense(), res_csr.to_dense())
3837
# Check different grid sizes to make sure that input slicing works
3838
# if this input is larger than the grid.
3839
grid_size = (3, None)
3840
grid_gen = itertools.product(grid_size, repeat=2)
3841
for grid in grid_gen:
3842
res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid)
3843
self.assertEqual(res_tri, res_tri_grid)
3847
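
    # The public reference point used above: `torch.sparse.sampled_addmm`
    # evaluates `beta * input + alpha * (mat1 @ mat2)` only at the locations
    # stored in the CSR `input`. Minimal sketch (CUDA assumed, matching the
    # Triton tests in this class; names are local to this sketch):
    def _sketch_sampled_addmm(self):
        eye = torch.eye(2, device='cuda')
        inp = eye.to_sparse_csr()
        ones = torch.ones(2, 2, device='cuda')
        out = torch.sparse.sampled_addmm(inp, ones, ones, beta=1.0, alpha=1.0)
        # ones @ ones == 2 everywhere, but only the diagonal pattern is kept
        assert torch.equal(out.to_dense(), 3 * eye)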

    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_scatter_mm(self, device, dtype):
        from torch.sparse._triton_ops import scatter_mm
        from functools import partial
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
        sizes = [8, 16]
        for m, k, n in itertools.product(sizes, sizes, sizes):
            blocks = torch.stack([tensor(m, k), tensor(m, k)])
            others = torch.stack([tensor(k, n), tensor(k, n)])

            expected = torch.stack([blocks[0] @ others[0] + blocks[1] @ others[0],
                                    blocks[0] @ others[1],
                                    blocks[1] @ others[1]])

            indices_data = (
                'scatter_mm',
                torch.tensor([0, 2, 3, 4], dtype=torch.int32, device=device),
                torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32, device=device))

            result = scatter_mm(blocks, others, indices_data=indices_data)

            self.assertEqual(result, expected)

            indices_data = (
                'bsr_strided_mm',
                torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device=device),
                torch.tensor([0, n, 2 * n * m, 2 * n * m + n], dtype=torch.int32, device=device),
                torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.int32, device=device),
                torch.tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n],
                             dtype=torch.int32, device=device),
                dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1)
            )

            for bsize in [(), (2,), (3, 4)]:
                other = tensor(*bsize, 2 * k, 2 * n)
                expected = torch.cat([
                    torch.cat([blocks[1], blocks[0]], dim=1),
                    torch.cat([torch.zeros_like(blocks[0]), blocks[1]], dim=1)], dim=0) @ other
                result = scatter_mm(blocks, other, indices_data=indices_data)
                self.assertEqual(result, expected)
@parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64])
3892
@dtypes(torch.half, torch.bfloat16, torch.float)
3893
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3894
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3895
def test_triton_bsr_scatter_mm(self, device, dtype, blocksize):
3897
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
3898
from functools import partial
3899
if isinstance(blocksize, str):
3900
blocksize = tuple(map(int, blocksize.split('x')))
3902
blocksize = (blocksize,) * 2
3903
# Note that each value in a non-zero block is in range blocksize * [low^2, high^2).
3904
tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
3906
# NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3907
batches = [(), (2,), (2, 2)]
3908
sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]]
3909
sizes_K = [blocksize[1], 2 * blocksize[1]]
3911
for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches[:1], sizes, sizes_K, sizes, (False, True)):
3912
bsr_dense = tensor(bs + (M, K))
3913
if has_zero_row_block:
3914
if M > blocksize[0]:
3915
bsr_dense[:blocksize[0]].zero_()
3918
bsr = bsr_dense.to_sparse_bsr(blocksize)
3919
dense = tensor(bd + (K, N))
3920
expected = bsr.to_dense() @ dense
3922
for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'):
3923
if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}:
3925
while SPLIT_N_list[-1] > 1:
3926
SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2))
3929
for SPLIT_N in SPLIT_N_list:
3930
indices_data = bsr_scatter_mm_indices_data(
3931
bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N)
3933
result = bsr_scatter_mm(bsr, dense, indices_data=indices_data)
3934
except triton.compiler.OutOfResources:
3935
# ensure that there was at least one succesful test:
3936
assert SPLIT_N < SPLIT_N_list[0]
3939
self.assertEqual(result, expected)
3940
torch.sparse._triton_ops._bsr_scatter_mm_indices_data.cache_clear()
3942

    def test_TensorAsKey(self, device):
        from torch.sparse._triton_ops import TensorAsKey
        assertEqualOptions = dict(exact_dtype=True, exact_device=True, exact_layout=True)

        t = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device=device)
        key = TensorAsKey(t)
        self.assertTrue(key == TensorAsKey(t))
        self.assertTrue(key.obj is t)

        t2 = t[:]
        key2 = TensorAsKey(t2)
        self.assertTrue(key == key2)
        self.assertEqual(key2.obj, t, **assertEqualOptions)
        # deleting the object leads to a dead key
        del t2
        self.assertTrue(key2.obj is None)
        self.assertTrue(key.obj is t)

        # key with different storage offset and shape:
        self.assertFalse(key == TensorAsKey(t[1:]))

        # key with different strides:
        self.assertFalse(key == TensorAsKey(t[::2]))

        # when the object dies, make sure that the key represents a dead
        # object as well:
        del t
        self.assertTrue(key.obj is None)

        # Storing a tensor as a dict key:
        d = {}
        t3 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key3 = TensorAsKey(t3)
        d[key3] = 123
        self.assertTrue(d.get(key3) == 123)
        t3_ = t3[:]
        self.assertTrue(d.get(TensorAsKey(t3_)) == 123)
        self.assertTrue(d.get(TensorAsKey(t3.clone())) is None)

        d[TensorAsKey(t3_)] = 567
        self.assertTrue(d.get(key3) == 567)

        # t3 and t3_ reference the same data, so the key does not become dead
        # (that is, its .obj property does not return None) until all
        # references are deleted:
        del t3
        self.assertTrue(key3.obj is not None)
        self.assertTrue(d.get(key3) == 567)
        del t3_
        self.assertTrue(key3.obj is None)
        self.assertTrue(d.get(key3) == 567)

        # Storing a tensor as a dict key and value:
        d = {}
        t4 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key4 = TensorAsKey(t4)
        d[key4] = (t4, 123)
        self.assertEqual(d.get(key4), (t4, 123), **assertEqualOptions)
        # when the object is deleted, the key still represents an alive
        # object because the object is referenced by the dict item value:
        del t4
        self.assertTrue(key4.obj is not None)
        # This also means that the lifetime of the tensor is the same as
        # the lifetime of the corresponding dict item:
        del d[key4]
        self.assertTrue(key4.obj is None)

        # Storing a tensor as a dict key and value wrapped with TensorAsKey:
        d = {}
        t5 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key5 = TensorAsKey(t5)
        d[key5] = (key5, 567)
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
        self.assertTrue(key5.obj is not None)
        # when the object is deleted, it will be dead since the wrapped value
        # holds the tensor instance only as a weakref:
        del t5
        self.assertTrue(key5.obj is None)
        # but the key is still valid:
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
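
    # The lifetime behaviour verified above can be reproduced with plain
    # `weakref`: a key that only weakly references the tensor, so `key.obj`
    # dies with the last strong reference. A generic sketch, not the actual
    # TensorAsKey implementation (names are local to this sketch):
    def _sketch_weakref_key(self):
        import weakref

        class WeakKey:
            def __init__(self, obj):
                self._ref = weakref.ref(obj)

            @property
            def obj(self):
                return self._ref()  # None once the referent is collected

        t = torch.tensor([1, 2, 3])
        key = WeakKey(t)
        assert key.obj is t
        del t
        assert key.obj is None  # assumes no other strong references survive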
@parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
4025
@parametrize("blocksize", [16, '16x32', 32])
4028
@dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
4029
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
4030
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4031
def test_triton_kernel(self, op, device, dtype, blocksize):
4032
from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
4033
from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
4034
optimize_bsr_dense_addmm, dump)
4036
def bsr_dense_linear(input, weights, bias=None):
4037
return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)
4039
operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
4040
_int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
4042
def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
4043
assert mat1.layout is torch.strided
4044
assert mat2.layout is torch.strided
4045
if dtype is torch.int8:
4046
if op == '_int_bsr_dense_addmm':
4047
return beta * input + alpha * torch._int_mm(mat1, mat2)
4048
# workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
4049
return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
4050
return beta * input + alpha * (mat1 @ mat2)
4052
if op == '_int_bsr_dense_addmm':
4053
# _int_bsr_dense_addmm is same as bsr_dense_addmm except
4054
# with int8 inputs, _int_bsr_dense_addmm returns int32
4055
# result. This is covered by operation and reference
4056
# definitions above and all other definitions below are
4057
# identical between _int_bsr_dense_addmm and
4059
op = 'bsr_dense_addmm'
4061
def nc_copy(t, axes=(-1,)):
4062
"""Return a copy of input.
4064
The returned copy will be a non-contiguous tensor.
4066
if t.layout is torch.strided:
4067
shape = list(t.shape)
4070
r = torch.empty(shape, dtype=t.dtype, device=t.device)
4071
s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
4074
elif t.layout is torch.sparse_bsr:
4075
compressed_indices = t.crow_indices()
4076
plain_indices = t.col_indices()
4077
return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
4078
t.shape, layout=t.layout)
4080
raise NotImplementedError(t.layout)
4082
if isinstance(blocksize, str):
4083
BM, BK = tuple(map(int, blocksize.split('x')))
4085
BM, BK = (blocksize,) * 2
4087
if op in {"bsr_dense_linear"} and BM != BK:
4088
# todo: eliminate this skip
4089
self.skipTest(f"{op} does not support non-square blocks")
4091
if op in {"bsr_dense_linear"} and dtype is torch.int8:
4092
# todo: eliminate this skip
4093
self.skipTest(f"{op} does not support int8")
4095
if dtype is torch.int8 and min(BM, BK) < 32:
4096
self.skipTest("triton kernel does not support support int8 blocks smaller than 32")
4098
beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
4099
alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
4100
sparsity_lst = [0, 0.5, 1]
4101
blocks_per_row_lst = [1, 2]
4102
blocks_per_col_lst = [1, 2]
4103
result_cols_lst = [16, 32, 64]
4104
for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
4105
beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
4106
M = BM * blocks_per_row
4107
K = BK * blocks_per_col
4108
mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
4109
bsr = mat1.to_sparse_bsr((BM, BK))
4110
mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
4111
input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
4113
if 0 and op == "bsr_dense_addmm":
4114
# Find optimal kernel parameters, the speed-up is
4115
# about 10x for running this test.
4117
# Enable this if-block when the test method is
4118
# updated, run the test, and finally, disable the
4120
key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
4121
meta = get_meta(op, key, version=(0, dtype, 0.5))
4123
optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
4124
meta = get_meta(op, key, version=(0, dtype, 0.5))
4125
assert meta is not None
4126
dump() # this will update torch/sparse/_triton_ops_meta.py
4128
expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
4129
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
4130
bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]
4132
args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
4133
bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
4134
result = operation(*args, **kwargs)
4135
self.assertEqual(result, expected)
4137
# Test non-contiguous input tensors:
4138
nc_mat2 = nc_copy(mat2)
4139
nc_input = nc_copy(input)
4140
nc_bsr = nc_copy(bsr)
4142
args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
4143
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
4144
result = operation(*args, **kwargs)
4145
self.assertEqual(result, expected)
4147
# todo: add bsr_dense_linear to the set below (currently,
4148
# nn.linear has unnecessarily restrictive arguments
4150
if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
4151
args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
4152
bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
4153
result = operation(*args, **kwargs)
4154
self.assertEqual(result, expected)
4156
if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
4157
args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
4158
bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
4159
kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
4160
bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
4161
result = operation(*args, **kwargs)
4162
self.assertEqual(result, expected)
4164
@parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
4167
@dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
4168
@dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
4169
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4170
def test_triton_tune(self, op, device, dtype):
4171
from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
4172
from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta)
4174
operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
4175
tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
4176
_int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]
4178
if op == '_int_bsr_dense_addmm':
4179
M, K, N = 32, 32, 32
4180
blocksize = (32, 32)
4182
M, K, N = 16, 16, 32
4183
blocksize = (16, 16)
4185
bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
4186
sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
4187
input = make_tensor(K, N, dtype=dtype, device=device)
4188
dense = make_tensor(K, N, dtype=dtype, device=device)
4190
if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
4191
args = (input, bsr, dense)
4193
def get_current_meta():
4194
version = (0, dtype, sparsity)
4195
meta_key = (M, K, N, *blocksize, False, True, True)
4196
return get_meta(op, meta_key, version=version, exact=True)
4198
raise NotImplementedError(op)
4200
self.assertEqual(get_current_meta(), None)
4202
meta = tuner(*args, **dict(store=True, verbose=False))
4203
self.assertEqual(get_current_meta(), meta)
4205
expected = operation(*args)
4206
result = operation(*args, **dict(meta=meta))
4207
self.assertEqual(result, expected)
4211
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4212
def test_triton_bsr_dense_addmm_meta(self, device):
4213
from torch.sparse._triton_ops import bsr_dense_addmm_meta
4214
from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta
4216
dtype = torch.float32
4221
def get_meta(M, K, N, sparsity=None):
4222
return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
4223
_version="test_triton_bsr_dense_addmm_meta")
4225
def update_meta(M, K, N, value, sparsity=0.5):
4226
key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
4227
update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(),
4228
("test_triton_bsr_dense_addmm_meta", dtype, sparsity),
4231
def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
4233
with redirect_stderr(f):
4234
result = get_meta(M, K, N, sparsity=sparsity)
4236
FileCheck().check_count(
4237
str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
4238
count=warn_count, exactly=True
4242
# Test warn_once when requesting non-existing tuned parameters multiple times
4244
with redirect_stderr(f):
4246
get_meta(16, 16, 16)
4248
get_meta(16, 16, 32)
4251
FileCheck().check_count(
4252
str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16", count=1, exactly=True
4254
FileCheck().check_count(
4255
str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32", count=1, exactly=True
4258
# Test warn_once when tuned parameters are missing
4259
default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
4260
self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)
4262
# Test (no)warn_once when tuned parameters are available
4263
update_meta(32, 32, 48, (2, 8, 5, 6))
4264
expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
4265
self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)
4267
# Test non-existing tuned parameters with non-default sparsity
4268
# while for default sparsity 0.5 the parameters are available
4269
self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)
4271
# Test non-existing tuned parameters while there exists
4272
# parameters with consistent N // SPLIT_N ratio:
4273
self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
4274
dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
4276
self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
4277
dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))
4280

# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
instantiate_device_type_tests(TestSparseCompressed, globals())
instantiate_device_type_tests(TestSparseCompressedTritonKernels, globals())

if __name__ == '__main__':
    run_tests()