# pytorch fork — test_sparse_csr.py (4286 lines, 206.9 KB)
# Owner(s): ["module: sparse"]

import functools
import io
import itertools
import operator
import random
import unittest
from contextlib import redirect_stderr

import torch
from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
    (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
     suppress_warnings)
from torch.testing._internal.common_device_type import \
    (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
     precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
     largeTensorTest)
from torch.testing._internal.common_methods_invocations import \
    (op_db, sparse_csr_unary_ufuncs, ReductionOpInfo)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUDA
from torch.testing._internal.common_dtype import (
    floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
    all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse

from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED

# Optional third-party modules are bound only when the test environment
# reports them as available.
if TEST_SCIPY:
    from scipy import sparse as sp

if TEST_NUMPY:
    import numpy as np

# load_tests (from torch.testing._internal.common_utils) is what filters tests
# for sharding on sandcastle; rebinding the name here silences the flake
# "imported but unused" warning.
load_tests = load_tests

# True on Windows or when the build has no MKL support.
no_mkl_sparse = not TEST_MKL or IS_WINDOWS

def _check_cusparse_triangular_solve_available():
    """Return whether the CUDA version is new enough for cusparseSpSM."""
    # cusparseSpSM first shipped in CUDA 11.3.1, but only (major, minor) is
    # known here, so 11.4 is the first version we can trust.
    return _get_torch_cuda_version() >= (11, 4)

def _check_cusparse_spgemm_available():
    """Return whether cusparseSpGEMM may be used: False on ROCm, True otherwise."""
    if TEST_WITH_ROCM:
        return False
    # cusparseSpGEMM has been available since CUDA 11.0.
    return True

def _check_cusparse_sddmm_available():
    """Return whether SDDMM is usable.

    ROCm builds always qualify. On CUDA, cusparseSDDMM was added in 11.2.1,
    but only (major, minor) is known here, so 11.3 is required.
    """
    # Short-circuits on ROCm, so the CUDA version is only queried otherwise.
    return TEST_WITH_ROCM or _get_torch_cuda_version() >= (11, 3)

# OpInfo entries from op_db that declare support for sparse CSR inputs ...
_sparse_csr_ops = [op for op in op_db if op.supports_sparse_csr]
# ... and for any of the sparse compressed layouts (CSR, CSC, BSR, BSC).
_sparse_compressed_ops = [
    op for op in op_db
    if (op.supports_sparse_csr or op.supports_sparse_csc
        or op.supports_sparse_bsr or op.supports_sparse_bsc)
]
# Binary ops (selected by name from op_db) grouped as having dense output.
binary_functions_with_dense_output = ['mm', 'mv']
binary_ops_with_dense_output = [op for op in op_db if op.name in binary_functions_with_dense_output]

# Names of unary elementwise ops for which CSR autograd is allowed.
UNARY_EWISE_CSR_ALLOW_AUTOGRAD = [
    'abs',
    'conj_physical',
    'deg2rad',
    'neg',
    'positive',
    'frac',
    'nn.functional.relu',
    'log1p',
    'rad2deg',
]

# This should be just an import from test_linalg instead of code duplication
# but https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
def _test_addmm_addmv(
    test_case,
    f,
    t,
    m,
    v,
    *,
    alpha=None,
    beta=None,
    transpose_out=False,
    layout=torch.strided,
    mode=None
):
    """Check ``f(t, m, v, alpha=alpha, beta=beta)`` for ``f`` = ``torch.addmm`` / ``torch.addmv``.

    The result is compared with the ``out=`` variant of ``f`` and with a
    NumPy reference ``alpha * (m @ v) + beta * t``.

    Args:
        test_case: object providing ``assertEqual`` (normally a ``TestCase``).
        f: the function under test.
        t, m, v: strided input tensors.
        alpha, beta: scalars; when None they default to 1.2 / 0.8 for real
            dtypes and 0.9+0.3j / 0.5+0.6j for complex dtypes.
        transpose_out: make the ``out=`` tensor column-major.
        layout: ``torch.sparse_csr`` or ``torch.sparse_csc`` converts the
            selected operands to that layout; for any other value the
            operands must already have ``layout``.
        mode: ``"all_sparse"`` converts ``t``, ``m`` and ``v`` (and checks
            the result layout); ``"dense_result"`` converts ``m`` and ``v``;
            otherwise only ``m`` is converted. The ``out=`` call always
            converts only ``m``.
    """
    dtype = t.dtype
    # bfloat16 tensors are converted to float32 before calling .numpy().
    numpy_dtype = torch.float if dtype == torch.bfloat16 else dtype

    if dtype.is_complex:
        default_alpha, default_beta = 0.9 + 0.3j, 0.5 + 0.6j
    else:
        default_alpha, default_beta = 1.2, 0.8
    if alpha is None:
        alpha = default_alpha
    if beta is None:
        beta = default_beta

    def convert_layout(mat):
        if layout == torch.sparse_csr:
            return mat.to_sparse_csr()
        if layout == torch.sparse_csc:
            return mat.to_sparse_csc()
        assert mat.layout == layout
        return mat

    if mode == "all_sparse":
        result = f(convert_layout(t), convert_layout(m), convert_layout(v), alpha=alpha, beta=beta)
        test_case.assertEqual(result.layout, layout)
        result = result.to_dense()
    elif mode == "dense_result":
        result = f(t, convert_layout(m), convert_layout(v), alpha=alpha, beta=beta)
    else:
        result = f(t, convert_layout(m), v, alpha=alpha, beta=beta)

    # Pre-fill with NaN so elements left unwritten by the out= call do not match.
    out = torch.full_like(result, float('nan'))
    if transpose_out:
        out = out.t().clone(memory_format=torch.contiguous_format).t()
    f(t, convert_layout(m), v, alpha=alpha, beta=beta, out=out)

    reference = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
    if beta != 0:
        reference += (beta * t).to(numpy_dtype).cpu().numpy()
    reference = torch.from_numpy(reference).to(dtype)

    test_case.assertEqual(result, out)
    test_case.assertEqual(result, reference)


class TestSparseCSRSampler(TestCase):
    """Tests for the CSR compressed-index sampling helper."""

    def test_make_crow_indices(self):
        """Validate ``_make_crow_indices`` for every small shape and nnz.

        For each ``n_rows x n_cols`` shape with both sizes in 1..9 and each
        ``nnz`` in ``0..n_rows * n_cols``, the compressed row indices must
        have ``n_rows + 1`` entries. Their consecutive differences (per-row
        counts) must lie in ``[0, n_cols]`` and sum to ``nnz``.
        """
        # Only the correctness of the crow_indices algorithm is under test,
        # so CPU with int32 indices is sufficient.
        cpu = torch.device('cpu')
        for n_rows, n_cols in itertools.product(range(1, 10), range(1, 10)):
            for nnz in range(n_rows * n_cols + 1):
                crow_indices = self._make_crow_indices(n_rows, n_cols, nnz, device=cpu, dtype=torch.int32)
                self.assertEqual(len(crow_indices), n_rows + 1)
                row_counts = crow_indices[1:] - crow_indices[:-1]
                self.assertEqual(row_counts.sum(), nnz)
                self.assertGreaterEqual(row_counts.min(), 0)
                self.assertLessEqual(row_counts.max(), n_cols)


def all_sparse_compressed_layouts(test_name='layout'):
    """Parametrize argument ``test_name`` over CSR, CSC, BSR and BSC layouts."""
    layout_cases = [
        subtest(torch.sparse_csr, name='SparseCSR'),
        subtest(torch.sparse_csc, name='SparseCSC'),
        subtest(torch.sparse_bsr, name='SparseBSR'),
        subtest(torch.sparse_bsc, name='SparseBSC'),
    ]
    return parametrize(test_name, layout_cases)


def sparse_compressed_nonblock_layouts(test_name='layout'):
    """Parametrize argument ``test_name`` over the non-block layouts CSR and CSC."""
    layout_cases = [
        subtest(torch.sparse_csr, name='SparseCSR'),
        subtest(torch.sparse_csc, name='SparseCSC'),
    ]
    return parametrize(test_name, layout_cases)


# Maps each sparse compressed layout to its (compressed_indices, plain_indices)
# accessors: row-compressed layouts (CSR, BSR) use crow/col indices, while
# column-compressed layouts (CSC, BSC) use ccol/row indices.
sparse_compressed_indices_methods = {
    layout: ((torch.Tensor.crow_indices, torch.Tensor.col_indices)
             if layout in (torch.sparse_csr, torch.sparse_bsr)
             else (torch.Tensor.ccol_indices, torch.Tensor.row_indices))
    for layout in (torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc)
}


def batched_nonbatched(test_name='batched'):
    """Parametrize argument ``test_name`` over ``True`` (Batched) and ``False`` (NonBatched)."""
    cases = [subtest(True, name="Batched"), subtest(False, name="NonBatched")]
    return parametrize(test_name, cases)


def hybrid_nonhybrid(test_name='hybrid'):
    """Parametrize argument ``test_name`` over ``True`` (Hybrid) and ``False`` (NonHybrid)."""
    cases = [subtest(True, name="Hybrid"), subtest(False, name="NonHybrid")]
    return parametrize(test_name, cases)


class TestSparseCompressed(TestCase):
    """Tests for generic features of sparse compressed (CSR, CSC, BSR, BSC) tensors."""

    def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_dtype=torch.int64):
        """Return ``genSparseCompressedTensor(...)`` for ``layout``.

        ``device`` falls back to this test's ``device_type`` when not given.
        """
        return self.genSparseCompressedTensor(
            size, nnz,
            device=self.device_type if device is None else device,
            dtype=dtype, index_dtype=index_dtype, layout=layout)

    @all_sparse_compressed_layouts()
    @onlyCPU
    def test_layout(self, layout):
        """Each sparse compressed layout is a ``torch.layout`` with the expected string form."""
        known_names = {'torch.sparse_csr', 'torch.sparse_csc', 'torch.sparse_bsr', 'torch.sparse_bsc'}
        self.assertIn(str(layout), known_names)
        self.assertEqual(type(layout), torch.layout)

    @parametrize('shape_and_device_inference', [subtest(False, name='_'), subtest(True, name='shape_and_device_inference')])
    @parametrize('use_factory_function', [subtest(False, name='_'), subtest(True, name='factory')])
    @parametrize('input_kind', [subtest('tensor', name='from_tensor'), subtest('list', name='from_list')])
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_compressed_constructor(self, layout, device, dtype,
                                           use_factory_function, shape_and_device_inference, input_kind):
        """Construct sparse compressed tensors from (compressed_indices, plain_indices, values).

        Covers the layout-specific factories (``torch.sparse_csr_tensor`` and
        friends) as well as ``torch.sparse_compressed_tensor``, with tensor or
        Python-list components, with explicit or inferred size/device, and
        with ``requires_grad`` on/off for floating point and complex dtypes.
        The constructed tensor's layout, shape, components, device and
        autograd flags are compared against the inputs.
        """
        from_list = input_kind == 'list'
        device_obj = torch.device(device)
        if from_list and shape_and_device_inference:
            if device_obj.type == 'cuda':
                # Without an explicit device, list inputs produce a sparse
                # compressed tensor on CPU, so the cuda variant adds nothing.
                self.skipTest("nothing to test")
            if dtype not in {torch.float32, torch.complex64, torch.int64, torch.bool}:
                self.skipTest("dtype not supported with list values")

        expected_devices = [device_obj]
        # With two or more GPUs, also construct on cuda:1 (only when the
        # device is passed explicitly).
        if (TEST_CUDA and device_obj.type == 'cuda'
                and torch.cuda.device_count() >= 2 and not shape_and_device_inference):
            expected_devices.append(torch.device('cuda:1'))

        factory_function = {
            torch.sparse_csr: torch.sparse_csr_tensor,
            torch.sparse_csc: torch.sparse_csc_tensor,
            torch.sparse_bsr: torch.sparse_bsr_tensor,
            torch.sparse_bsc: torch.sparse_bsc_tensor,
        }[layout]
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        # List inputs are only exercised with int64 indices.
        index_dtypes = [torch.int64] if from_list else [torch.int32, torch.int64]
        # requires_grad=True is only exercised for floating point and complex dtypes.
        can_require_grad = dtype.is_floating_point or dtype.is_complex
        requires_grad_lst = [False, True] if can_require_grad else [False]

        for index_dtype, expected_device in itertools.product(index_dtypes, expected_devices):
            simple_inputs = self.generate_simple_inputs(
                layout, device=expected_device, dtype=dtype, index_dtype=index_dtype,
                # skip zero-sized tensors for list inputs:
                enable_zero_sized=not from_list,
                output_tensor=False)
            for (compressed_indices, plain_indices, values), kwargs in simple_inputs:
                size = kwargs['size']
                if shape_and_device_inference and 0 in size:
                    # Shape inference is not tested for zero-sized inputs:
                    # (i) the shape determined from an empty list is
                    # ambiguous, and (ii) the plain dimension size, defined
                    # as max(plain_indices), is undefined when plain_indices
                    # has no values.
                    continue
                # Keep the tensor components for comparison before they are
                # possibly converted to Python lists below.
                compressed_indices_expect = compressed_indices
                plain_indices_expect = plain_indices
                values_expect = values
                if from_list:
                    compressed_indices, plain_indices, values = (
                        compressed_indices.tolist(), plain_indices.tolist(), values.tolist())

                for requires_grad in requires_grad_lst:
                    if use_factory_function and shape_and_device_inference:
                        sparse = factory_function(
                            compressed_indices, plain_indices, values, requires_grad=requires_grad)
                    elif use_factory_function:
                        sparse = factory_function(
                            compressed_indices, plain_indices, values, size,
                            dtype=dtype, device=expected_device, requires_grad=requires_grad)
                    elif shape_and_device_inference:
                        sparse = torch.sparse_compressed_tensor(
                            compressed_indices, plain_indices, values,
                            layout=layout, requires_grad=requires_grad)
                    else:
                        sparse = torch.sparse_compressed_tensor(
                            compressed_indices, plain_indices, values, size,
                            dtype=dtype, layout=layout, device=expected_device, requires_grad=requires_grad)

                    self.assertEqual(layout, sparse.layout)
                    self.assertEqual(size, sparse.shape)
                    self.assertEqual(compressed_indices_expect, compressed_indices_mth(sparse))
                    self.assertEqual(plain_indices_expect, plain_indices_mth(sparse))
                    self.assertEqual(values_expect, sparse.values())
                    self.assertEqual(sparse.device, sparse.values().device)
                    self.assertEqual(sparse.device, expected_device)
                    self.assertEqual(sparse.values().requires_grad, requires_grad)
                    self.assertEqual(sparse.requires_grad, requires_grad)
                    self.assertFalse(compressed_indices_mth(sparse).requires_grad)
                    self.assertFalse(plain_indices_mth(sparse).requires_grad)

    @skipMeta
    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_empty(self, layout, device, dtype):
        """``torch.empty`` with a CSR/CSC layout yields an nnz=0 tensor with the requested metadata.

        Checked over batch shapes (), (2,), (2, 3) and matrix sizes drawn
        from {5, 2, 0}: shape, dtype, device, layout, component shapes
        (nnz=0), component devices, and int64 index dtypes.
        """
        compressed_dim = {
            torch.sparse_csr: -2,
            torch.sparse_csc: -1,
        }[layout]
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        expected_device = torch.device(device)
        sizes = [5, 2, 0]
        for m, n, batch in itertools.product(sizes, sizes, [(), (2,), (2, 3)]):
            shape = (*batch, m, n)
            with torch.sparse.check_sparse_tensor_invariants(enable=False):
                # torch.empty may return invalid sparse compressed tensors
                result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
            self.assertEqual(result.shape, shape)
            self.assertEqual(result.dtype, dtype)
            self.assertEqual(result.device, expected_device)
            self.assertEqual(result.layout, layout)

            compressed = compressed_indices_mth(result)
            plain = plain_indices_mth(result)
            values = result.values()
            self.assertEqual(compressed.shape, (*batch, shape[compressed_dim] + 1,))
            self.assertEqual(plain.shape, (*batch, 0,))
            self.assertEqual(values.shape, (*batch, 0,))
            self.assertEqual(result._nnz(), 0)
            self.assertEqual(compressed.device, expected_device)
            self.assertEqual(plain.device, expected_device)
            self.assertEqual(values.device, expected_device)
            self.assertEqual(compressed.dtype, torch.int64)
            self.assertEqual(plain.dtype, torch.int64)
            self.assertEqual(values.dtype, dtype)

    @skipMeta
    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_errors(self, layout, device, dtype):
        """``torch.empty`` rejects a 1-D size for CSR/CSC layouts."""
        expected_message = ("torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"
                            ", but got size")
        with self.assertRaisesRegex(RuntimeError, expected_message):
            torch.empty((5,), dtype=dtype, device=device, layout=layout)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):
        """``aten._sparse_compressed_tensor_with_dims`` reproduces the metadata of reference inputs.

        For each input from ``generate_simple_inputs``, a tensor is created
        from its nnz, dense dim, shape, block size and index dtype. Its
        shape/dtype/device/nnz/layout and the shape/dtype/device of every
        component must match the reference.
        """

        def describe(s):
            # Metadata compared between the constructed and reference tensors;
            # key order fixes the order of the assertions below.
            compressed_mth, plain_mth = sparse_compressed_indices_methods[layout]
            compressed_indices, plain_indices = compressed_mth(s), plain_mth(s)
            values = s.values()
            return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
                        compressed_indices_shape=compressed_indices.shape,
                        compressed_indices_dtype=compressed_indices.dtype,
                        compressed_indices_device=compressed_indices.device,
                        plain_indices_shape=plain_indices.shape,
                        plain_indices_dtype=plain_indices.dtype,
                        plain_indices_device=plain_indices.device,
                        values_shape=values.shape,
                        values_dtype=values.dtype,
                        values_device=values.device)

        is_block_layout = layout in {torch.sparse_bsr, torch.sparse_bsc}
        for index_dtype in [torch.int32, torch.int64]:
            for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
                dense_dim = t.dense_dim()
                sparse_dim = t.sparse_dim()
                batch_dim = t.ndim - sparse_dim - dense_dim
                values_shape = t.values().shape
                nnz = values_shape[batch_dim]
                # For block layouts the block size sits right after the nnz dimension of values.
                blocksize = values_shape[batch_dim + 1: batch_dim + 1 + sparse_dim] if is_block_layout else ()

                e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
                                                                       dtype=dtype, layout=layout, device=device)

                actual_props, expected_props = describe(e), describe(t)
                for key, actual in actual_props.items():
                    expected = expected_props[key]
                    self.assertEqual(actual, expected,
                                     lambda msg: f'{msg} when comparing {key}, expected {expected}, got {actual}')

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_clone(self, layout, device, dtype):
        """``Tensor.clone`` of a sparse compressed tensor compares equal to its source."""
        inputs = self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=torch.int32)
        for source in inputs:
            self.assertEqual(source, source.clone())

    @all_sparse_compressed_layouts()
    def test_print(self, layout, device):
        """Compare the printed form of sparse compressed tensors with the expect file.

        For non-hybrid and hybrid variants of fixed local patterns, and for
        int32/int64 indices and float32/float64 values, the tensor, its
        compressed indices, plain indices and values are printed and the
        joined output is checked with ``assertExpected``.
        """
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
        base_ndim = 2
        printed = []
        for enable_hybrid in [False, True]:
            # using local patterns for test_print stability
            patterns = [
                # 2 x 3 batch of 3 x 2 tensors, trivial blocksize, non-hybrid/hybrid:
                ([[[[1, 2, 0],
                    [1, 0, 3]],
                   [[1, 2, 3],
                    [1, 0, 0]],
                   [[1, 0, 0],
                    [1, 2, 3]]],
                  [[[0, 2, 0],
                    [1, 2, 3]],
                   [[1, 0, 3],
                    [1, 2, 0]],
                   [[1, 2, 3],
                    [0, 2, 0]]]], [(2, 1)], [(), (4,)] if enable_hybrid else [()]),
                # tensor with non-trivial blocksize, non-hybrid/hybrid:
                ([[0, 1, 0, 2, 0, 2],
                  [0, 1, 0, 0, 2, 0],
                  [3, 3, 3, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0],
                  [0, 5, 0, 6, 6, 6],
                  [5, 0, 5, 6, 6, 6],
                  [0, 0, 0, 0, 8, 8],
                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 2)] if enable_hybrid else [()]),
            ]
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in [torch.float32, torch.float64]:
                    for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                            layout, device=device, dtype=dtype, index_dtype=index_dtype, enable_hybrid=enable_hybrid,
                            enable_non_contiguous_indices=False, enable_non_contiguous_values=False,
                            enable_zero_sized=False, output_tensor=False, patterns=patterns):
                        size = tuple(kwargs['size'])
                        batch_ndim = compressed_indices.dim() - 1
                        dense_ndim = values.dim() - batch_ndim - block_ndim - 1
                        if enable_hybrid and dense_ndim == 0:
                            # non-hybrid cases are covered by the enable_hybrid==False loop
                            continue
                        batchsize = size[:batch_ndim]
                        basesize = size[batch_ndim:batch_ndim + base_ndim]
                        densesize = size[batch_ndim + base_ndim:]
                        assert len(densesize) == dense_ndim
                        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{basesize}+{densesize} ##########")
                        x = torch.sparse_compressed_tensor(compressed_indices,
                                                           plain_indices,
                                                           values, size, dtype=dtype, layout=layout, device=device)
                        printed.append("# sparse tensor")
                        printed.append(str(x))
                        printed.append(f"# _{compressed_indices_mth.__name__}")
                        printed.append(str(compressed_indices_mth(x)))
                        printed.append(f"# _{plain_indices_mth.__name__}")
                        printed.append(str(plain_indices_mth(x)))
                        printed.append("# _values")
                        printed.append(str(x.values()))
                        printed.append('')
                    printed.append('')
        # Disable diff truncation so a mismatch shows the complete output, and
        # restore the previous setting however assertExpected exits.
        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        try:
            self.assertExpected('\n'.join(printed))
        finally:
            self.maxDiff = orig_maxDiff

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy(self, layout, device, dtype):
        """``Tensor.copy_`` between two compatible sparse compressed tensors makes them equal."""

        def run_test(shape, blocksize, nnz, index_dtype):
            # The parameter used to be named ``index_type`` while the body read
            # ``index_dtype``, so the argument was ignored and the enclosing
            # loop variable was used instead.
            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)

            a.copy_(b)

            self.assertEqual(a, b)

        ns = [(9, 3), (2, 1), (0, 0)]  # (size of a dimension, the corresponding block size)
        batch_shapes = [(), (2,), (2, 3)]
        # NOTE(review): zip() stops at the shorter iterable, so only the first
        # two shape combinations run (one per index dtype). Covering the full
        # product would likely also need the nnz=m*n choice and the (0, 0)
        # block size revisited for block layouts; confirm intent before widening.
        for ((m, bm), (n, bn), b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
            blocksize = (bm, bn) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
            run_test((*b, m, n), blocksize, 0, index_dtype)
            run_test((*b, m, n), blocksize, m * n, index_dtype)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy_errors(self, layout, device, dtype):
        """``Tensor.copy_`` between sparse compressed tensors rejects incompatible sources.

        Checked mismatches: a strided source, a different number of
        specified elements, different shapes, and (block layouts only)
        different block sizes.
        """
        if layout in {torch.sparse_bsr, torch.sparse_bsc}:
            blocksize, nnz, shape1 = (2, 3), 6, (2 * 6, 3 * 6)
        else:
            blocksize, nnz, shape1 = (), 1, (2, 3)
        shape2 = tuple(reversed(shape1))

        def make(shape, count, index_dtype, blocks=blocksize):
            return self.genSparseCompressedTensor(shape, count, dtype=dtype, layout=layout, device=device,
                                                  index_dtype=index_dtype, blocksize=blocks)

        for index_dtype in [torch.int32, torch.int64]:
            no_elements = make(shape1, 0, index_dtype)
            with self.assertRaisesRegex(RuntimeError,
                                        "copy of sparse compressed tensors having different layouts is not supported."):
                no_elements.copy_(torch.empty(no_elements.shape, dtype=dtype, device=device))

            filled = make(shape1, nnz, index_dtype)
            assert no_elements._nnz() != filled._nnz(), (no_elements._nnz(), filled._nnz())
            with self.assertRaisesRegex(RuntimeError,
                                        "only sparse compressed tensors with the same number of specified elements are supported."):
                no_elements.copy_(filled)

            reshaped = make(shape2, nnz, index_dtype)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "expected shapes of self and src to match along dimension"):
                filled.copy_(reshaped)

            if blocksize:
                other_blocks = make(shape1, nnz, index_dtype, blocks=tuple(reversed(blocksize)))
                with self.assertRaisesRegex(RuntimeError,
                                            "copy of sparse compressed tensors having different block sizes is not supported"):
                    filled.copy_(other_blocks)

    def _smallest_divisor(self, n):
524
        for i in range(2, int(n ** 0.5) + 1):
525
            if n % i == 0:
526
                return i
527
        return n
528

529
    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @all_sparse_compressed_layouts()
    @ops(_sparse_compressed_ops)
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_consistency(self, layout, device, dtype, op):
        """Check that an op gives the same result on strided and sparse compressed inputs.

        Every sample input with at least 2 dimensions is converted to
        ``layout``. Samples the sparse validator rejects are dropped. The
        densified sparse result is compared with the strided result;
        for masked reductions, elements outside the output mask are zeroed
        first.
        """
        if not op.supports_sparse_layout(layout):
            self.skipTest(f"{op.name} does not support input with {layout} layout")

        # FIXME: remove in followup once integer support is landed for segment_reduce
        if (layout == torch.sparse_csr and not dtype.is_floating_point
                and op.name in ('masked.mean', 'masked.amax', 'masked.amin')):
            self.skipTest(f"{op.name} does not support input with {layout} layout and {dtype} dtype")

        require_mask = isinstance(op, ReductionOpInfo) and 'masked.' in op.name
        is_block_layout = layout in {torch.sparse_bsr, torch.sparse_bsc}

        samples = []
        for sample in op.sample_inputs(device, dtype):
            input_ndim = sample.input.ndim
            if input_ndim < 2:
                continue
            dense_dim = input_ndim - 2
            blocksize = tuple(map(self._smallest_divisor, sample.input.shape[:2])) if is_block_layout else None

            def _to_sparse(x):
                if not isinstance(x, torch.Tensor):
                    return x
                if blocksize is None:
                    convertible = x.ndim == input_ndim
                else:
                    convertible = (x.ndim == input_ndim + 2
                                   and x.shape[-3] % blocksize[0] == 0
                                   and x.shape[-2] % blocksize[1] == 0)
                if not convertible:
                    return x
                return x.clone().to_sparse(layout=layout, blocksize=blocksize, dense_dim=dense_dim)

            sparse_sample = sample.transform(_to_sparse)
            # Some strided samples (with inf, nan elements) appear to share
            # storage, so we must clone:
            sample = sample.transform(lambda x: x.clone() if isinstance(x, torch.Tensor) else x)

            if validate_sample_input_sparse(op, sparse_sample, check_validate=False) is not sparse_sample:
                # The validator wrapped the sparse sample in an ErrorInput
                # instance, i.e. the sample is not supported; skip it.
                continue
            samples.append((sample, sparse_sample))

        # Fail early to prevent silent success with this test
        if not samples:
            raise ValueError("Expected at least one 2 or higher D tensor in samples.")

        # Ops with random results (hence non-comparable values) still get
        # their shape, dtype, etc. checked, but with tolerances loose
        # enough to accept any values.
        if op.name == 'randn_like':
            atol, rtol = 1e300, 1
        else:
            atol = rtol = None

        for sample, sparse_sample in samples:
            expected = op(sample.input, *sample.args, **sample.kwargs)
            assert torch.is_tensor(expected)
            output = op(sparse_sample.input, *sparse_sample.args, **sparse_sample.kwargs)
            assert torch.is_tensor(output)
            strided_output = output.to_dense()
            if require_mask and sample.kwargs.get('mask') is not None:
                output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
                expected.masked_fill_(~output_mask, 0)
            self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)

    @skipMeta
    @all_sparse_compressed_layouts()
    @all_sparse_compressed_layouts('layout2')
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_like(self, layout, layout2, device, dtype):
        """Check torch.empty_like on sparse compressed inputs.

        When ``layout2`` equals the input layout, the result must pass sparse
        compressed argument validation and keep the input's shape, dtype,
        device and layout.  Requesting a different sparse compressed layout
        must raise a RuntimeError.
        """
        # Generate the inputs on the parametrized device/dtype so that each
        # test instantiation actually exercises its own parametrization.
        for sparse in self.generate_simple_inputs(layout, device=device, dtype=dtype):
            if layout == layout2:
                result = torch.empty_like(sparse, layout=layout2)
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout]
                torch._validate_sparse_compressed_tensor_args(compressed_indices_mth(result),
                                                              plain_indices_mth(result),
                                                              result.values(),
                                                              result.shape,
                                                              result.layout)
                self.assertEqual(sparse.shape, result.shape)
                self.assertEqual(sparse.dtype, result.dtype)
                self.assertEqual(sparse.device, result.device)
                self.assertEqual(layout2, result.layout)
            else:
                self.assertRaisesRegex(
                    RuntimeError,
                    "empty_like with different sparse layout is not supported",
                    lambda: torch.empty_like(sparse, layout=layout2)
                )
620

621
    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_validate(self, layout, device, dtype):
        """Valid generated inputs must pass sparse compressed validation.

        Each input is validated as-is and with an extra zero-sized leading
        batch dimension.  An nnz == 0 compressed/plain index pair is also
        checked with ``_validate_compressed_sparse_indices``.
        """

        def with_empty_batch(t):
            # Same dtype/device as ``t``, prefixed by a zero-sized batch dim.
            return torch.empty((0, *t.shape), dtype=t.dtype, device=t.device)

        is_row_compressed = layout in {torch.sparse_csr, torch.sparse_bsr}
        for index_dtype in (torch.int32, torch.int64):
            generated = self.generate_simple_inputs(
                layout, device=device, dtype=dtype, index_dtype=index_dtype, output_tensor=False)
            for (c_idx, p_idx, vals), kwargs in generated:
                size = kwargs['size']
                torch._validate_sparse_compressed_tensor_args(c_idx, p_idx, vals, size, layout)

                # an empty batch must validate as well
                empty_batched = [with_empty_batch(t) for t in (c_idx, p_idx, vals)]
                torch._validate_sparse_compressed_tensor_args(*empty_batched, (0,) + size, layout)

            # nnz == 0 with a single compressed and a single plain dimension
            torch._validate_compressed_sparse_indices(is_row_compressed,
                                                      torch.tensor([0, 0], dtype=index_dtype),
                                                      torch.tensor([], dtype=index_dtype),
                                                      1, 1, 0)
645

646
    def _generate_invalid_input(self, layout, device):
        """Yield invalid sparse compressed constructor inputs for ``layout``.

        Each yielded item is a 6-tuple
        ``(label, compressed_indices, plain_indices, values, size, errmsg)``
        where ``errmsg`` is a regex for the expected RuntimeError message.
        The messages use generic placeholder names (``compressed_indices``,
        ``plain_indices``, ``compressed_indices_name``, ``plain_dim``,
        ``compressed_dim``); test_invalid_input rewrites them to the
        layout-specific names before matching.
        """
        from functools import partial

        def shape(shape, basedim=0):
            # Convert a row-compressed (rows, cols) pair located at ``basedim``
            # into the size expected for ``layout``: swapped for CSC/BSC and
            # scaled by ``blocksize`` for BSR/BSC.  With the fixed blocksize
            # (1, 1) the scaling does not change any value.
            blocksize = (1, 1)
            if layout is torch.sparse_csc:
                shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsc:
                shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsr:
                shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:]
            return shape

        def values(lst, device=device):
            # For block layouts each top-level element is wrapped in two extra
            # list levels, so a flat list of nnz items becomes shape (nnz, 1, 1).
            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                lst = [[[item]] for item in lst]
            return torch.tensor(lst, device=device)

        tensor = partial(torch.tensor, device=device)
        values = partial(values, device=device)

        # [::2] of the 6-element tensor gives [0, 2, 4] as a non-contiguous view.
        yield ('incontiguous compressed_indices',
               tensor([0, -1, 2, -1, 4, -1])[::2],
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected compressed_indices to be a contiguous tensor per batch')

        yield ('incontiguous plain_indices',
               tensor([0, 2, 4]),
               tensor([0, -1, 1, -1, 0, -1, 2, -1])[::2],
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected plain_indices to be a contiguous tensor per batch')

        yield ('0-D compressed_indices',
               tensor(0),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices must have dimensionality >= 1 but got 0')

        yield ('compressed/plain_indices mismatch of dimensionalities',
               tensor([[0, 2, 4]]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively')

        # Indices carry one batch dim but values do not; the expected values
        # dimensionality differs between non-block and block layouts.
        if layout in {torch.sparse_csr, torch.sparse_csc}:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1')
        else:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 2\) but got 3')

        # A 1-D size cannot hold the two sparse dimensions.
        yield ('invalid size',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               (2,),
               r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1')

        # Indices/values have batch size 1 while the size requests batch size 2.
        yield ('invalid batchsize',
               tensor([[0, 2, 4]]),
               tensor([[0, 1, 0, 2]]),
               values([[1, 2, 3, 4]]),
               shape((2, 2, 3), 1),
               r'all batch dimensions of compressed_indices \(=\[1\]\), plain_indices \(=\[1\]\), '
               r'and values \(=\[1\]\) must be equal to tensor batch dimensions \(=\[2\]\)')

        # Values of shape (4, 1, 2) imply blocksize (1, 2), which does not
        # divide a size of 3 in dimension 1.
        if layout is torch.sparse_bsr:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   shape((2, 3)),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        if layout is torch.sparse_bsc:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   shape((3, 2)),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        yield ('invalid compressed_indices shape',
               tensor([0, 2, 3, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices.shape\[-1\] must be equal to the number of compressed_indices_names \+ 1 \(=3\), but got 4')

        # NOTE(review): this case exercises the plain_indices length, but it
        # reuses the 'invalid compressed_indices shape' label.  test_invalid_input
        # selects skipped cases by label, so renaming it would change which
        # cases run there; keep the two in sync if it is ever renamed.
        yield ('invalid compressed_indices shape',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 1, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'plain_indices.shape\[-1\] must be equal to nnz \(=4\) as defined by values.shape\[0\], but got 5')

        # NOTE(review): "bot got" presumably mirrors the spelling of the runtime
        # error message; do not correct it here alone.
        yield ('compressed/plain_indices mismatch of dtype',
               tensor([0, 2, 4], dtype=torch.int32),
               tensor([0, 1, 0, 2], dtype=torch.int64),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices must have the same dtype, bot got Int and Long, respectively')

        yield ('invalid compressed/plain_indices dtype',
               tensor([0, 2, 4], dtype=torch.int16),
               tensor([0, 1, 0, 2], dtype=torch.int16),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices dtype must be Int or Long, but got Short')

        # CUDA kernel asserts are not recoverable, so we skip these for now
        if torch.device(device).type == 'cpu':
            yield ('invalid compressed_indices[0]',
                   tensor([1, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')

            yield ('invalid compressed_indices[0] when nnz == 0',
                   tensor([1, 0], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([1])[:0],
                   shape((1, 1)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')

            yield ('invalid compressed_indices[-1]',
                   tensor([0, 2, 5]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            yield ('invalid compressed_indices[-1] when nnz == 0',
                   tensor([0, 1], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([1])[:0],
                   shape((1, 1)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            # Row/column counts (consecutive differences) of 0 and 4 vs. plain_dim 3.
            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, 0, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            # Differences of 5 (too large) and -1 (negative).
            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, 5, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            yield ('invalid min(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, -1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            yield ('invalid max(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            # Plain indices of the first compressed slice are [1, 0], i.e. not sorted.
            yield ('non-coalesced',
                   tensor([0, 2, 4]),
                   tensor([1, 0, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`plain_indices\[..., compressed_indices\[..., i - 1\]:compressed_indices\[..., i\]\] '
                   'for all i = 1, ..., compressed_dim '
                   'are sorted and distinct along the last dimension values` is not satisfied.')

        # Mixed CPU/CUDA inputs; only generated for the CPU instantiation when CUDA is available.
        if TEST_CUDA and torch.device(device).type == 'cpu':
            yield ('indices and values mismatch of device',
                   torch.tensor([0, 2, 4]),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cpu\) must match device of values \(=cuda:0\)')
            yield ('compressed_indices and values mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')
            yield ('compressed/plain_indices mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')

        # Inputs spread over two CUDA devices; requires at least two GPUs.
        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2:
            yield ('indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:0'),
                   values([1, 2, 3, 4], device='cuda:1'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cuda:0\) must match device of values \(=cuda:1\)')
            # NOTE(review): despite the label, here plain_indices is the tensor on cuda:1.
            yield ('compressed_indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:1'),
                   values([1, 2, 3, 4], device='cuda:0'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!')
869

870
    @skipMeta
    @all_sparse_compressed_layouts()
    @parametrize('target', [subtest('validate_sparse_compressed_tensor_args'),
                            subtest('sparse_compressed_tensor'),
                            subtest('sparse_compressed_tensor_no_size')])
    def test_invalid_input(self, layout, device, target):
        """Every case from ``_generate_invalid_input`` must raise a matching RuntimeError.

        The generic placeholder names in each expected message are first
        rewritten to the names used by ``layout``.  Cases whose error depends
        on an explicitly given size are skipped for the size-less constructor.
        """
        # Dimension names substituted for the *_indices_name placeholders.
        placeholder_names = {
            torch.sparse_bsr: ('row block', 'column block'),
            torch.sparse_bsc: ('column block', 'row block'),
            torch.sparse_csr: ('row', 'column'),
            torch.sparse_csc: ('column', 'row'),
        }
        # Generic -> layout-specific index/dimension names, applied in this order.
        if layout in {torch.sparse_csr, torch.sparse_bsr}:
            renames = (('compressed_indices', 'crow_indices'),
                       ('plain_indices', 'col_indices'),
                       ('plain_dim', 'ncols'),
                       ('compressed_dim', 'nrows'))
        else:
            renames = (('compressed_indices', 'ccol_indices'),
                       ('plain_indices', 'row_indices'),
                       ('plain_dim', 'nrows'),
                       ('compressed_dim', 'ncols'))
        # Without an explicit size a valid size is estimated, so these cases
        # cannot fail the way they expect for the size-less constructor.
        size_dependent_labels = {'invalid size', 'invalid batchsize', 'invalid compressed_indices shape',
                                 'invalid max(plain_indices)', 'invalid blocksize'}

        cases = self._generate_invalid_input(layout, device)
        for label, compressed_indices, plain_indices, values, size, errmsg in cases:
            if layout in placeholder_names:
                compressed_name, plain_name = placeholder_names[layout]
                errmsg = errmsg.replace('compressed_indices_name', compressed_name)
                errmsg = errmsg.replace('plain_indices_name', plain_name)
            for generic_name, layout_name in renames:
                errmsg = errmsg.replace(generic_name, layout_name)

            if target == 'sparse_compressed_tensor_no_size' and label in size_dependent_labels:
                continue

            with self.assertRaisesRegex(RuntimeError, errmsg):
                if target == 'validate_sparse_compressed_tensor_args':
                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)
                elif target == 'sparse_compressed_tensor':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
                elif target == 'sparse_compressed_tensor_no_size':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=layout)
                else:
                    raise NotImplementedError(target)
911

912
    @skipMeta
    @onlyCPU
    @largeTensorTest("30GB", "cpu")
    def test_invalid_input_csr_large(self):
        """int32 CSR indices must reject 2**31 rows, columns or nnz; int64 indices must accept them."""
        n = 2 ** 31

        def rows_case(index_dtype):
            # n rows and a single column
            return (torch.arange(n + 1, dtype=index_dtype) // n,
                    torch.tensor([0], dtype=index_dtype),
                    torch.tensor([1]), (n, 1))

        def cols_case(index_dtype):
            # a single row and n columns
            return (torch.arange(2, dtype=index_dtype),
                    torch.tensor([0], dtype=index_dtype),
                    torch.tensor([1]), (1, n))

        def nnz_case(index_dtype, last):
            # two rows of n // 2 columns with n int8 values; ``last`` is the
            # final crow index
            half = n // 2
            return (torch.tensor([0, half, last], dtype=index_dtype),
                    torch.arange(half, dtype=index_dtype).repeat(2),
                    torch.ones(n, dtype=torch.int8), (2, half))

        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in row dimension'):
            torch.sparse_csr_tensor(*rows_case(torch.int32))
        torch.sparse_csr_tensor(*rows_case(torch.int64))

        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in column dimension'):
            torch.sparse_csr_tensor(*cols_case(torch.int32))
        torch.sparse_csr_tensor(*cols_case(torch.int64))

        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in nnz'):
            # nnz cannot be stored in int32 crow_indices, but the
            # `crow_indices[..., -1] == nnz` check happens after the overflow
            # validation, so `nnz - 1` is used as the last crow index to avoid
            # `value cannot be converted to type int32 without overflow`
            # while constructing crow_indices.
            torch.sparse_csr_tensor(*nnz_case(torch.int32, n - 1))
        torch.sparse_csr_tensor(*nnz_case(torch.int64, n))
946

947
    @skipMeta
    @onlyCPU
    @all_sparse_compressed_layouts()
    def test_dim(self, layout):
        """sparse_dim() must be 2 and dense_dim() must match the dims left over in values.

        The leftover count is values.dim() minus the batch dims, minus the
        block dims (2 for BSR/BSC), minus one more (presumably the nnz dim).
        """
        n_block_dims = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
        inputs = self.generate_simple_inputs(layout, output_tensor=False)
        for (compressed_indices, plain_indices, values), kwargs in inputs:
            n_batch_dims = compressed_indices.dim() - 1
            expected_dense_dims = values.dim() - n_batch_dims - n_block_dims - 1
            t = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, kwargs['size'],
                                               layout=layout)
            self.assertEqual(t.sparse_dim(), 2)
            self.assertEqual(t.dense_dim(), expected_dense_dims)
960

961

962
    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_to_dtype(self, layout, device, dtype):
        """Converting dtype and then densifying must equal densifying and then converting dtype."""
        target_dtypes = all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)
        # to_dense does not support hybrid inputs
        inputs = self.generate_simple_inputs(layout, dtype=dtype, device=device, enable_hybrid=False)
        for sparse in inputs:
            for target in target_dtypes:
                converted_first = sparse.to(target)
                densified_first = sparse.to_dense().to(target)
                self.assertEqual(converted_first.to_dense(), densified_first)
972

973
    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(torch.double)
    def test_pickle(self, layout, dtype, device):
        """A pickle dumps/loads round trip must reproduce each sparse compressed tensor."""
        import pickle

        for original in self.generate_simple_inputs(layout, device=device, dtype=dtype):
            restored = pickle.loads(pickle.dumps(original))
            self.assertEqual(original, restored)
984

985
    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select_copy(self, device, dtype, index_dtype, layout):
        """Compare torch.select_copy on sparse compressed tensors with strided tensors.

        For every dimension:
        - selecting from an empty dimension raises IndexError;
        - selecting one of the two sparse dims of a batched tensor raises RuntimeError;
        - otherwise the result matches the strided result, and the selected
          values are not a view of the input's values.
        """

        def is_view_of(base, other):
            # a shameless copy of TestViewOps.is_view_of: True when ``other`` is
            # a view whose ``_base`` is ``base`` on the same device (and, on
            # CPU/CUDA, sharing the same storage).
            if (
                not other._is_view() or
                other is base or
                other._base is not base or
                base.device != other.device
            ):
                return False
            if base.device.type in ('cpu', 'cuda'):
                if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                    return False
            return True

        if layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices_mth = torch.Tensor.crow_indices
        elif layout in {torch.sparse_csc, torch.sparse_bsc}:
            compressed_indices_mth = torch.Tensor.ccol_indices
        else:
            # Explicit failure instead of `assert 0`, which is stripped under -O.
            self.fail(f"unexpected layout {layout}")

        kwargs = dict(device=device, dtype=dtype, index_dtype=index_dtype)
        for sparse, dense in zip(self.generate_simple_inputs(layout, **kwargs),
                                 self.generate_simple_inputs(torch.strided, **kwargs)):
            # Number of leading batch dims; the two sparse dims directly follow them.
            n_batchdim = compressed_indices_mth(sparse).ndim - 1
            self.assertEqual(sparse, dense)
            for dim in range(sparse.ndim):
                if sparse.shape[dim] == 0:
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(sparse, dim, 0)
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(dense, dim, 0)
                elif n_batchdim and n_batchdim <= dim < n_batchdim + 2:
                    with self.assertRaisesRegex(
                            RuntimeError,
                            "selecting sparse dimensions is not supported for batched sparse compressed tensors"):
                        torch.select_copy(sparse, dim, 0)
                else:
                    # first, middle and last index (deduplicated for short dims)
                    for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}:
                        dense_select = torch.select_copy(dense, dim, index)
                        sparse_select = torch.select_copy(sparse, dim, index)
                        self.assertEqual(sparse_select, dense_select)
                        # select_copy must copy: the selection's values must not be
                        # a view of the input's values.  is_view_of takes the base
                        # first and the candidate view second.
                        self.assertFalse(is_view_of(sparse.values(), sparse_select.values()))
1031

1032

1033
def _npref_block_addmm_addmv(c, a, b, alpha, beta):
    """Reference result of ``alpha * (a @ b) + beta * c`` (the addmm/addmv formula)."""
    product = a @ b
    scaled_product = alpha * product
    return scaled_product + beta * c
1035

1036

1037
class TestSparseCSR(TestCase):
1038

1039
    def test_csr_stride(self):
        """stride() must raise for sparse CSR tensors, with and without a dim argument."""
        csr = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        for stride_args in ((), (-1,)):
            with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
                csr.stride(*stride_args)
1047

1048
    def test_csr_storage(self):
        """storage() must raise for sparse CSR tensors."""
        csr = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
        self.assertRaisesRegex(RuntimeError, "Cannot access storage of SparseCsrTensorImpl", csr.storage)
1053

1054
    def test_csr_is_contiguous(self):
        """is_contiguous() must raise for sparse CSR tensors."""
        csr = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
        self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous", csr.is_contiguous)
1059

1060
    @onlyCPU
    @largeTensorTest("20GB", "cpu")
    def test_csr_nnz(self):
        """A single-row CSR tensor must report the nnz it was built with.

        Tests the limits of the number of specified elements in CSR tensors,
        see gh-102520.
        """
        for nnz in (0, 2 ** 31):
            n_cols = max(nnz, 1)
            csr = torch.sparse_csr_tensor(torch.tensor([0, nnz], dtype=torch.int64),
                                          torch.arange(nnz, dtype=torch.int64),
                                          torch.ones(nnz, dtype=torch.int8),
                                          (1, n_cols))
            self.assertEqual(csr._nnz(), nnz)
1071

1072
    def test_csr_double_to_sparse_csr(self):
        """Calling to_sparse_csr() twice on a sparse CSR tensor must not fail."""
        csr = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
        converted_once = csr.to_sparse_csr()
        converted_once.to_sparse_csr()
1075

1076
    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select(self, device, dtype, index_dtype, layout):
        """Check select/indexing on a batched (2, 3, 6, 10) sparse compressed tensor.

        Covers selecting a batch dim, selecting rows/columns after the batch
        dims are indexed away, element indexing, rejected item assignment,
        and rejected selection of sparse dims while batch dims remain.
        """
        # layout -> (compressed indices getter, plain indices getter, constructor)
        per_layout = {
            torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices, torch.sparse_csr_tensor),
            torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices, torch.sparse_bsr_tensor),
            torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices, torch.sparse_csc_tensor),
            torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices, torch.sparse_bsc_tensor),
        }
        get_compressed, get_plain, make_sparse = per_layout[layout]

        is_blocked = layout in {torch.sparse_bsr, torch.sparse_bsc}
        sparse = self.genSparseCompressedTensor(
            (2, 3, 6, 10), 6, device=device, layout=layout, dtype=dtype, index_dtype=index_dtype,
            blocksize=(2, 2) if is_blocked else ())
        compressed = get_compressed(sparse)
        plain = get_plain(sparse)
        vals = sparse.values()

        # Selecting along a batch dim equals selecting the same batch entry of
        # the compressed indices, plain indices and values.
        actual = sparse.select(1, 2)
        expected = make_sparse(compressed.select(1, 2).contiguous(),
                               plain.select(1, 2).contiguous(),
                               vals.select(1, 2).contiguous(),
                               size=(2, 6, 10),
                               dtype=dtype,
                               device=device)
        self.assertEqual(expected, actual)

        # Rows/columns can only be selected once the batch dims are indexed away.
        unbatched = sparse[0, 0]
        for dim, index in ((0, 0), (1, 1)):
            from_sparse = unbatched.select(dim, index)
            from_dense = unbatched.to_dense().select(dim, index)
            self.assertEqual(from_dense, from_sparse)

        self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0])
        # assigning to sparse through indexing is disabled
        with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
            sparse[0, 0, 0, 0] = 99.0

        # Selecting a sparse dim while the batch dims are still present is rejected.
        msg = "selecting sparse dimensions is not supported for batched sparse compressed tensors."
        for sparse_dim in (-2, -1):
            with self.assertRaisesRegex(RuntimeError, msg):
                sparse.select(sparse_dim, 0)
1139

1140
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize(self, device, dtype):
        """Check resize_ on sparse CSR tensors with and without batch dimensions.

        For every index dtype and batch shape, starting from a (2, 3) matrix
        with nnz == 6:
        - resizing to (4, 5) keeps nnz;
        - resizing to (1, 5) trims nnz to 5;
        - dropping the batch dims keeps nnz == 5.
        numel() must always equal the product of the shape.
        """

        def numel(tensor):
            # Reference element count: the product of all sizes.
            r = 1
            for s in tensor.shape:
                r *= s
            return r

        batch_shapes = [(), (2,), (2, 3)]
        # Cartesian product so that every batch shape is combined with every
        # index dtype; zip() stopped after two pairs and never reached (2, 3).
        for index_dtype, b in itertools.product([torch.int32, torch.int64], batch_shapes):
            shape = (*b, 2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 4, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to larger shape doesn't add specified elements
            self.assertEqual(a._nnz(), nnz)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 1, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to smaller shape trims specified elements
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))

            # trim batched dimensions
            a.resize_(new_shape[-2], new_shape[-1])
            self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))
1178

1179
    @skipMeta
    @dtypes(torch.float, torch.bool)
    @all_sparse_compressed_layouts()
    def test_resize_as_sparse_compressed(self, device, dtype, layout):
        """Test ``Tensor.resize_as_sparse_`` for sparse compressed layouts.

        For each index dtype, a clone of ``b`` is resized to match ``a``. The
        result must take ``a``'s shape, nnz, indices shapes and values shape.
        It must keep ``b``'s layout, device and (index/value) dtypes, and it
        must pass ``torch._validate_sparse_compressed_tensor_args``.

        The test also checks the errors raised for strided operands and for
        mismatching layouts.
        """

        def _check_resize_b_as_a(b, a):
            # Resize a copy so that `b` stays available as the "before" reference.
            br = b.clone()
            br.resize_as_sparse_(a)

            # shape is inherited from a
            self.assertEqual(a.shape, br.shape)
            # other metadata is not affected
            self.assertEqual(b.layout, br.layout)
            self.assertEqual(b.device, br.device)
            self.assertEqual(b.dtype, br.dtype)

            def _get_compressed_plain_inds(t):
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[t.layout]
                return compressed_indices_mth(t), plain_indices_mth(t)

            br_compressed_indices, br_plain_indices = _get_compressed_plain_inds(br)
            br_values = br.values()

            b_compressed_indices, b_plain_indices = _get_compressed_plain_inds(b)
            a_compressed_indices, a_plain_indices = _get_compressed_plain_inds(a)
            self.assertEqual(a_plain_indices.shape, br_plain_indices.shape)
            self.assertEqual(a_compressed_indices.shape, br_compressed_indices.shape)
            # We don't check the content of br_plain_indices and br_compressed_indices
            # because it is not well-defined (the content depends on the original
            # shape of `b` that `resize_as` ought to discard) nor needed (the
            # subsequent operation likely updates the indices and values of `b` anyway).
            # the device/dtype of indices should always be unaffected
            self.assertEqual(b_plain_indices.dtype, br_plain_indices.dtype)
            self.assertEqual(b_plain_indices.device, br_plain_indices.device)
            self.assertEqual(b_compressed_indices.dtype, br_compressed_indices.dtype)
            self.assertEqual(b_compressed_indices.device, br_compressed_indices.device)
            # values are generated empty, shape is updated
            self.assertEqual(a.values().shape, br_values.shape)
            # the device/dtype of values should always be unaffected
            b_values = b.values()
            self.assertEqual(b_values.dtype, br_values.dtype)
            self.assertEqual(b_values.device, br_values.device)
            # nnz will be picked up from a via new shape of values
            self.assertEqual(a._nnz(), br._nnz())

            # post resize the invariants of the layout are respected
            torch._validate_sparse_compressed_tensor_args(br_compressed_indices, br_plain_indices, br_values, br.shape,
                                                          br.layout)

        block_sparse = layout in (torch.sparse_bsr, torch.sparse_bsc)
        shape = (2, 1, 6, 4)
        nnz = 4
        blocksize = (2, 1) if block_sparse else ()
        for index_dtype in [torch.int32, torch.int64]:
            a = self.genSparseCompressedTensor(shape,
                                               layout=layout,
                                               device=device,
                                               index_dtype=index_dtype,
                                               dtype=dtype,
                                               nnz=nnz,
                                               blocksize=blocksize)

            # same size, resize should not trigger
            b = self.genSparseCompressedTensor(shape,
                                               layout=layout,
                                               device=device,
                                               index_dtype=index_dtype,
                                               dtype=dtype,
                                               nnz=nnz,
                                               blocksize=blocksize)

            # This test will not always trigger a resize, if the layouts are the same nothing should happen to b.
            # The invariants of the function as checked should still hold
            _check_resize_b_as_a(b, a)

            # same ndim, but bigger, more nnz, different dtype, different blocksize if blocked
            b = self.genSparseCompressedTensor(tuple(s * 2 for s in shape),
                                               layout=layout,
                                               device=device,
                                               dtype=torch.chalf,
                                               index_dtype=torch.int64 if index_dtype == torch.int32 else torch.int32,
                                               nnz=nnz * 2,
                                               blocksize=tuple(2 * bi for bi in blocksize))
            _check_resize_b_as_a(b, a)

            # different device, only check on cuda pass as we know we are testing in an environment
            # that has multiple devices

            # TODO: .cpu() does not seem to work correctly for sparse. Causes a call to `copy_` which
            # complains about incompatible nnz between src and self?
            if torch.device(device).type == 'cuda' and not block_sparse:
                a_cpu = self.genSparseCompressedTensor(shape,
                                                       layout=layout,
                                                       device='cpu',
                                                       index_dtype=index_dtype,
                                                       dtype=dtype,
                                                       nnz=nnz,
                                                       blocksize=blocksize)
                # NOTE(review): `a_cpu` is generated but the check below still uses
                # the same-device `a`, so the cross-device case is not exercised.
                # Confirm whether `_check_resize_b_as_a(b, a_cpu)` was intended.
                _check_resize_b_as_a(b, a)

            # error on a strided
            a_strided = a.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: src  expected sparse compressed tensor layout'):
                b.resize_as_sparse_(a_strided)

            # error on b strided
            b_strided = b.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: self  expected sparse compressed tensor layout'):
                b_strided.resize_as_sparse_(a)

            # error if layout does not match, transpose induces layout flip
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.transpose(-2, -1).resize_as_sparse_(a)
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.resize_as_sparse_(a.transpose(-2, -1))
1298

1299
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize_errors(self, device, dtype):
        """Unsupported ``resize_`` requests on a 2-D CSR tensor must raise."""
        shape, nnz = (2, 3), 6
        for index_dtype in (torch.int32, torch.int64):
            csr = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            # a 1-D target shape cannot describe a CSR matrix
            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
                csr.resize_((4,))

            # shrinking the number of columns is not implemented
            with self.assertRaisesRegex(
                RuntimeError,
                "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
            ):
                csr.resize_((2, 2))
1318

1319
    @skipIfTorchDynamo()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_from_dense(self, device, dtype):
        """``to_sparse_csr`` on small dense matrices yields the expected CSR components."""
        # (dense rows, expected crow_indices, expected col_indices, expected values)
        cases = [
            ([[4, 5, 0], [0, 0, 0], [1, 0, 0]], [0, 2, 2, 3], [0, 1, 0], [4, 5, 1]),
            ([[0, 0, 0], [0, 0, 1], [1, 0, 0]], [0, 0, 1, 2], [2, 0], [1, 1]),
            ([[2, 2, 2], [2, 2, 2], [2, 2, 2]], [0, 3, 6, 9], [0, 1, 2] * 3, [2] * 9),
        ]
        for rows, crow, col, vals in cases:
            sparse = torch.tensor(rows, dtype=dtype, device=device).to_sparse_csr()
            self.assertEqual(torch.tensor(crow, dtype=torch.int64), sparse.crow_indices())
            self.assertEqual(torch.tensor(col, dtype=torch.int64), sparse.col_indices())
            self.assertEqual(torch.tensor(vals, dtype=dtype), sparse.values())
1339

1340
    def _test_sparse_compressed_to_dense(self, device, dtype, layout):
        """Shared body of the CSR/CSC ``to_dense`` tests.

        ``layout`` is expected to be ``torch.sparse_csr`` or ``torch.sparse_csc``.
        The matching ``Tensor.to_sparse_*`` method and ``torch.sparse_*_tensor``
        constructor are looked up from the last three characters of
        ``str(layout)``.

        The test first round-trips dense -> compressed -> dense for small 2-D
        shapes (including empty ones). It then checks a hand-built batched
        compressed tensor against its known dense value.
        """
        compressed_format_str = str(layout)[-3:]

        def to_compressed(t):
            return getattr(t, f"to_sparse_{compressed_format_str}")()

        def compressed_constructor(*args, **kwargs):
            constructor = getattr(torch, f"sparse_{compressed_format_str}_tensor")
            return constructor(*args, **kwargs)

        def get_dense_shape(shape, batch_ndim):
            # For CSC the two sparse dimensions are swapped so that the dense data
            # can be built row-major and transposed afterwards (see `transpose`).
            # Reversing the sliced pair with [::-1] also works for batch_ndim == 0,
            # where a slice stop of `batch_ndim - 1` (== -1) would wrap around and
            # produce an empty slice.
            sparse_dims = shape[batch_ndim:batch_ndim + 2]
            if layout is torch.sparse_csc:
                sparse_dims = sparse_dims[::-1]
            return shape[:batch_ndim] + sparse_dims + shape[batch_ndim + 2:]

        def transpose(t, batch_ndim):
            if layout is torch.sparse_csc:
                return t.transpose(batch_ndim, batch_ndim + 1)
            return t

        mn = [5, 2, 0]
        for (m, n) in itertools.product(mn, mn):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            sparse = to_compressed(dense)
            self.assertEqual(sparse.to_dense(), dense)

        # The same 2x3 compressed pattern repeated across a (2, 3) batch.
        batch_shape = (2, 3)
        compressed_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        plain_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
        sparse = compressed_constructor(compressed_indices, plain_indices, values, dtype=dtype, device=device)
        dense_shape = get_dense_shape(sparse.shape, len(batch_shape))
        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(dense_shape)
        self.assertEqual(sparse.to_dense(), transpose(dense, len(batch_shape)))
1377

1378
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_to_dense(self, device, dtype):
        """Check ``to_dense`` for the CSR layout via the shared compressed-layout helper."""
        self._test_sparse_compressed_to_dense(device, dtype, layout=torch.sparse_csr)
1381

1382
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csc_to_dense(self, device, dtype):
        """Check ``to_dense`` for the CSC layout via the shared compressed-layout helper."""
        self._test_sparse_compressed_to_dense(device, dtype, layout=torch.sparse_csc)
1385

1386
    @skipMeta
    @skipCPUIfNoMklSparse
    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        """COO -> CSR conversion preserves the dense value and ``matmul`` results."""
        # The low-level index converter only accepts 1-D input.
        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(
                torch.randint(100, (5, 5), device=device),
                size=100)

        # Random COO input.
        sparse_coo, _, _ = self.genSparseTensor((5, 5), 2, 10, coalesced, device, dtype)
        sparse_csr = sparse_coo.to_sparse_csr()
        self.assertTrue(sparse_csr.is_sparse_csr)
        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())

        rhs = torch.randn((5, 1), dtype=dtype, device=device)
        self.assertEqual(sparse_coo.matmul(rhs), sparse_csr.matmul(rhs))

        # Fixed COO input whose entries are not in row order.
        rhs = torch.randn((100, 1), dtype=dtype, device=device)
        coo_indices = torch.tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ], dtype=torch.int32)
        coo_values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
        coo = torch.sparse_coo_tensor(coo_indices, coo_values, torch.Size([100, 100]), dtype=dtype, device=device)
        csr = coo.to_sparse_csr()

        self.assertEqual(coo.matmul(rhs), csr.matmul(rhs))

        # After conversion the entries appear sorted by row index.
        expected_col_indices = torch.tensor([
            31, 92, 65, 50, 34, 62, 22, 56, 74, 89
        ], dtype=torch.int64, device=device)
        self.assertEqual(csr.col_indices(), expected_col_indices)

        expected_values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
        self.assertEqual(csr.values(), expected_values)
1429

1430
    @parametrize("blocksize", [2, 4])
    @dtypes((torch.double, torch.int32), (torch.double, torch.int64))
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @skipMeta
    def test_csr_to_block_csr(self, device, dtypes, blocksize):
        """``to_sparse_bsr`` on a CSR tensor agrees with SciPy's ``tobsr``."""
        dtype, index_dtype = dtypes
        block = (blocksize, blocksize)
        for m, k in [(24, 24), (12, 24)]:
            nnz = random.randint(0, m * k)
            csr = self.genSparseCSRTensor((m * blocksize, k * blocksize), nnz, dtype=dtype,
                                          device=device, index_dtype=index_dtype)
            scipy_csr = sp.csr_matrix(
                (csr.values().cpu(), csr.col_indices().cpu(), csr.crow_indices().cpu()),
                shape=tuple(csr.size()),
            )

            bsr = csr.to_sparse_bsr(block)
            self.assertEqual(bsr.values().dim(), 3)
            self.assertTrue(bsr.layout == torch.sparse_bsr)

            scipy_bsr = scipy_csr.tobsr(blocksize=block)
            # SciPy does not guarantee sorted indices; sort them so they are comparable.
            scipy_bsr.sort_indices()
            self.assertEqual(bsr.values().cpu(), scipy_bsr.data)
            self.assertEqual(bsr.col_indices().cpu(), torch.tensor(scipy_bsr.indices).to(index_dtype))
            self.assertEqual(bsr.crow_indices().cpu(), torch.tensor(scipy_bsr.indptr).to(index_dtype))
1450

1451
    @dtypes(torch.double)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_csr_to_block_csr_errors(self, device, dtype):
        """``to_sparse_bsr`` rejects a blocksize that does not divide the sparse dimensions."""
        # NOTE(review): this test does not call SciPy itself; the skip guard may be unnecessary.
        for index_dtype in [torch.int32, torch.int64]:
            nnz = 15
            t = self.genSparseCSRTensor((16, 16), nnz, dtype=dtype,
                                        device=device, index_dtype=index_dtype)

            # 16 is not divisible by 5, so the conversion must fail.
            with self.assertRaisesRegex(RuntimeError,
                                        r"tensor sparse size \(.*,.*\) must be divisible by given blocksize \(.*,.*\)"):
                t.to_sparse_bsr((5, 5))
1462

1463
    # TODO: Support auto generation of device check for sparse tensors
1464
    # See: https://github.com/pytorch/pytorch/issues/59058
1465
    @onlyCUDA
    @dtypes(torch.double)
    def test_matmul_device_mismatch(self, device, dtype):
        """``torch.addmm`` with a sparse ``mat1`` must reject operands on different devices."""
        on_cpu = torch.rand((10, 10))
        on_cuda = on_cpu.cuda()
        for inp, dense_mat1, mat2 in itertools.product((on_cpu, on_cuda), repeat=3):
            # NOTE(review): `Tensor.to_sparse()` without a layout argument returns a
            # COO tensor, so despite the variable name this exercises the COO addmm
            # path -- confirm whether `to_sparse_csr()` was intended.
            csr = dense_mat1.to_sparse()
            if not (inp.device == csr.device == mat2.device):
                with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                    torch.addmm(inp, csr, mat2)
            else:
                torch.addmm(inp, csr, mat2)
1477

1478
    @skipCPUIfNoMklSparse
    @skipCUDAIfNoSparseGeneric
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater else [],
                  *[torch.bfloat16] if SM80OrLater else []))
    def test_csr_matvec(self, device, dtype):
        """CSR @ vector matches the dense product and rejects a mismatched vector length."""
        if TEST_WITH_ROCM and dtype in (torch.half, torch.bfloat16):
            self.skipTest("ROCm doesn't work with half dtypes correctly.")

        side = 100
        for index_dtype in (torch.int32, torch.int64):
            csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype)
            vec = torch.randn(side, dtype=dtype, device=device)
            self.assertEqual(csr.matmul(vec), csr.to_dense().matmul(vec))

            # A vector 10 entries longer than the matrix width must be rejected.
            too_long = torch.randn(side + 10, dtype=dtype, device=device)
            with self.assertRaisesRegex(RuntimeError, "size mismatch, got"):
                csr.matmul(too_long)
1503

1504
    @onlyCUDA
    # hmm, the test passes ok on CUDA when Rocm is not available:
    @skipCUDAIfRocmVersionLessThan((5, 2))
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_baddbmm(self, device, dtype):
        """``torch.baddbmm`` with a batched CSR ``batch1`` matches per-batch ``addmm``.

        The batched operand reuses the indices and values of a single 2-D CSR
        matrix ``a``, with a batch dimension prepended to its size. The expected
        result is therefore ``addmm(c[i], a, b[i])`` for each batch index ``i``.
        """

        # TODO: disable the invariant checks within torch.baddbmm that
        # constructs unconventional csr tensors leading to
        # RuntimeError: tensor dimensionality must be sum of batch,
        #     base, and dense dimensionalities (=0 + 2 + 0) but got 3
        # when invariant checking is enabled. When done, undecorate run_test.
        @torch.sparse.check_sparse_tensor_invariants(enable=False)
        def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
            alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
            beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
            # NOTE(review): below, `a` is always 2-D while `b` and `c` are 3-D, so
            # `a.shape == b.shape` is never true and `op_b`/`op_out` have no effect.
            # Confirm whether a different condition was intended.
            b = b.mH if (op_b and a.shape == b.shape) else b

            actual = torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta)

            out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
            torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta, out=out)

            expected = [torch.addmm(c[i], a, b[i], alpha=alpha, beta=beta) for i in range(c.shape[0])]
            expected = torch.stack(expected, 0)

            self.assertEqual(actual, out)
            self.assertEqual(actual, expected)

        for index_dtype in [torch.int32, torch.int64]:
            # NOTE(review): zip() stops at its shortest argument, so only the first two
            # (m, n, k) triples of the product are tested, paired with
            # (batch_size=1, noncontiguous=True) and (3, False). Confirm intent.
            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
                nnz = random.randint(0, m * k)
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

                # a_batched is a regular CSR tensor but with a batch dimension in the shape
                a_batched = torch.sparse_csr_tensor(
                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_b, op_out in itertools.product([True, False], repeat=2):
                    run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
1545

1546
    @onlyCUDA
    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
    @skipCUDAIfNoSparseGeneric
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_bmm(self, device, dtype):
        """``torch.bmm`` with a batched CSR first operand matches per-batch ``mm``.

        Every batch entry of ``a_batched`` shares the indices and values of the
        2-D CSR matrix ``a``. The expected result is therefore
        ``mm(a, b[i])`` for each batch index ``i``.
        """
        def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
            # NOTE(review): below, `a` is always 2-D while `b` is 3-D, so
            # `a.shape == b.shape` is never true and `op_b`/`op_out` have no effect.
            # Confirm whether a different condition was intended.
            b = b.mH if (op_b and a.shape == b.shape) else b

            actual = torch.bmm(a_batched, b)

            out = torch.empty_like(actual.mH if op_out and a.shape == b.shape else actual)
            torch.bmm(a_batched, b, out=out)

            expected = [torch.mm(a, b[i]) for i in range(b.shape[0])]
            expected = torch.stack(expected, 0)

            self.assertEqual(actual, out)
            self.assertEqual(actual, expected)

        for index_dtype in [torch.int32, torch.int64]:
            # NOTE(review): zip() stops at its shortest argument, so only the first two
            # (m, n, k) triples of the product are tested. Confirm intent.
            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
                nnz = random.randint(0, m * k)
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)

                # a_batched is a regular CSR tensor but with a batch
                # dimension in the shape. It is unorthodox in PyTorch
                # to represent a batch sparse tensor in this way,
                # hence checking the tensor invariants is locally
                # turned off.
                a_batched = torch.sparse_csr_tensor(
                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)

                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_b, op_out in itertools.product([True, False], repeat=2):
                    run_test(a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
1581

1582
    def run_test_block_addmm_addmv(self,
                                   addmv_addmm,
                                   c,
                                   a,
                                   b,
                                   op_b=False,
                                   op_out=False,
                                   *,
                                   dtype=None,
                                   device=None,
                                   ref=_npref_block_addmm_addmv):
        """Compare ``addmv_addmm(c, a, b, alpha=..., beta=...)`` against ``ref``.

        ``alpha`` and ``beta`` are random scalars, complex when ``dtype`` is complex.
        The op is called both normally and with ``out=``. The two results must
        agree with each other and with ``ref(c, a, b, alpha, beta)``.

        When ``op_b`` is set and ``a`` and ``b`` have equal shapes, ``b`` is
        replaced by its conjugate transpose. When ``op_out`` is set and the
        shapes still match, the ``out`` buffer is allocated from ``c.mH``.
        ``device`` is accepted for call-site symmetry but is not used here.
        """
        def random_scalar():
            # Real part is drawn before the imaginary part.
            if dtype.is_complex:
                return complex(random.random(), random.random())
            return random.random()

        alpha = random_scalar()
        beta = random_scalar()
        if op_b and a.shape == b.shape:
            b = b.mH

        actual = addmv_addmm(c, a, b, alpha=alpha, beta=beta)

        # The shape test is re-evaluated on the (possibly transposed) `b`.
        out_template = c.mH if (op_out and a.shape == b.shape) else c
        out = torch.empty_like(out_template)
        addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
        expected = ref(c, a, b, alpha, beta)

        self.assertEqual(actual, out)
        self.assertEqual(actual, expected, lambda msg: f"{msg}\na={a}\nc={c}\nb={b}\nalpha={alpha} beta={beta}")
1605

1606
    # TODO: block_size 1 is broken
1607
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @skipCPUIfNoMklSparse
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @skipIfTorchDynamo("raises 'sparse matrix length is ambiguous; use getnnz()'")
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater else [],
                  *[torch.bfloat16] if SM80OrLater else []))
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-5, torch.complex128: 1e-5,
                        torch.float16: 1e-3, torch.bfloat16: 1e-3})
    def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous):
        """``torch.addmm`` with a BSR ``mat1`` matches a reference implementation.

        Each case is run directly. It is also run through
        ``make_transposed_addmm_op``, which calls the op as
        ``f(c^T, b^T, a^T)``; the actual and reference results of that variant
        are both in transposed form.

        Half/bfloat16 use a float32 dense reference. The other dtypes use a
        SciPy/NumPy-based reference.
        """

        def make_transposed_addmm_op(f):

            def tt(t):
                if isinstance(t, torch.Tensor):
                    return t.transpose(-2, -1)
                else:
                    # assume numpy/scipy spmatrix
                    return t.transpose()

            @functools.wraps(f)
            def wrapper(c, a, b, alpha=None, beta=None, out=None):
                if out is not None:
                    # the ref takes no out kwarg
                    assert isinstance(out, torch.Tensor)
                    # transpose inplace to propagate out to checking context
                    out.transpose_(-2, -1)
                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta, out=out)
                else:
                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta)

            return wrapper

        def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None):
            # Convert torch operands to SciPy/NumPy and delegate to the
            # module-level reference `_npref_block_addmm_addmv`.

            def prep_input(t):

                def to_sp_block_compressed(t):
                    # SciPy only has BSR, so a BSC tensor is handled as the
                    # transpose of a BSR matrix built from its transposed view.

                    if t.layout is torch.sparse_bsc:
                        tt = t.transpose(-1, -2)
                    else:
                        tt = t

                    t_sp_bsr = sp.bsr_matrix(
                        (
                            tt.values().cpu().numpy(),
                            tt.col_indices().cpu().numpy(),
                            tt.crow_indices().cpu().numpy(),
                        ),
                        shape=tt.shape,
                    )

                    if t.layout is torch.sparse_bsc:
                        return t_sp_bsr.transpose()
                    else:
                        return t_sp_bsr

                if t.layout is not torch.strided:
                    return to_sp_block_compressed(t)
                else:
                    return t.cpu().resolve_conj().numpy()

            res = _npref_block_addmm_addmv(
                *(prep_input(t) for t in (c, a, b)),
                alpha,
                beta
            )

            if out is not None:
                # NOTE(review): assumes `res` is accepted by `Tensor.copy_`
                # (depends on what `_npref_block_addmm_addmv` returns) -- confirm.
                out.copy_(res)
                return out
            else:
                return res

        def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
            # Compute the product in float32 and cast back to the input dtype.
            res = alpha * (a.to_dense().to(torch.float32) @ b.to_dense().to(torch.float32)).to(a.dtype) + beta * c
            if out is not None:
                out.copy_(res)
                return out
            else:
                return res

        if dtype in (torch.half, torch.bfloat16):
            ref = ref_half_bfloat16
        else:
            ref = ref_sp_numpy

        for (m, n, k) in itertools.product([2, 5], repeat=3):
            nnz = random.randint(0, m * k)
            # Only the sparsity pattern of this CSR tensor is reused; each
            # specified element is replaced by a block_size x block_size block.
            a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
            a_data = a_data.mT if noncontiguous else a_data
            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                        a_data, (m * block_size, k * block_size), check_invariants=False)
            b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
            c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
            for op_b, op_out in itertools.product([True, False], repeat=2):
                self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref)
                self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm),
                                                c,
                                                a,
                                                b,
                                                op_b,
                                                op_out,
                                                dtype=dtype,
                                                device=device,
                                                ref=make_transposed_addmm_op(ref))
1719

1720
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
        """``torch.addmv`` with a BSR matrix matches a dense reference.

        With ``noncontiguous=False`` the BSR operand comes from ``to_sparse_bsr``.
        Otherwise it is assembled from transposed (column-major) value blocks,
        and the vectors ``b``/``c`` are non-contiguous as well.
        """
        # TODO: Explicitly disable block size 1 support
        # if (TEST_WITH_ROCM or not TEST_CUSPARSE_GENERIC) and block_size == 1:
        #     return
        def ref_block_addmv(c, a, b, alpha, beta):
            return _npref_block_addmm_addmv(c, a.to_dense(), b, alpha, beta)

        for (m, k) in itertools.product([2, 5], repeat=2):
            nnz = random.randint(0, m * k)
            if noncontiguous:
                # Reuse the sparsity pattern of an (m, k) CSR tensor and attach
                # block_size x block_size blocks to each specified element.
                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
                a_data = a_data.mT  # Test column-major blocks
                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                            a_data, (m * block_size, k * block_size), check_invariants=False)
            else:
                a = self.genSparseCSRTensor((m * block_size, k * block_size), nnz,
                                            dtype=dtype, device=device, index_dtype=index_dtype)
                a = a.to_sparse_bsr((block_size, block_size))
            b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
            c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
            self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device, ref=ref_block_addmv)
1747

1748
    @parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)], name_fn=lambda shape: f"shape_{shape[0]}x{shape[1]}")
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @onlyCPU
    def test_addmv(self, device, dtype, matrix_shape):
        """CPU ``addmv`` with a CSR matrix agrees with the dense computation.

        ``avec`` is always float64. For every other ``dtype`` this also covers an
        ``input`` whose dtype differs from that of the matrix and vector.
        """
        dense = torch.randn(matrix_shape, dtype=dtype, device=device)
        # Zero out entries with a negative real part to introduce sparsity.
        dense[dense.real < 0] = 0
        csr = dense.to_sparse_csr()
        n_rows, n_cols = dense.size(0), dense.size(1)
        mvec = torch.randn((n_cols,), dtype=dtype, device=device)
        avec = torch.randn((n_rows,), dtype=torch.float64, device=device)
        expected = torch.addmv(avec, dense, mvec)
        self.assertEqual(expected, torch.addmv(avec, csr, mvec))
1760

1761
    @parametrize("block_size", [2, 3])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @parametrize("noncontiguous", [True, False])
    @skipCPUIfNoMklSparse
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_block_triangular_solve(self, device, dtype, index_dtype, block_size, noncontiguous):
        """Check ``torch.triangular_solve`` with a BSR coefficient matrix.

        The reference solution comes from a dense solve whose coefficient matrix
        is rebuilt from the BSR components via ``scipy.sparse.bsr_matrix``.

        Both the functional form and the ``out=`` form are checked. Every
        combination of the ``upper``/``unitriangular``/``transpose``/``op_out``
        flags is covered.
        """
        def run_test(a, b, upper, transpose, unitriangular, op_out):
            if unitriangular and self.device_type == 'cpu':
                # TODO: When unitriangular=True results are not correct on CPU
                return

            if not upper and self.device_type == 'cpu':
                # TODO: When upper=False some generated inputs might crash on CPU
                return

            actual = torch.triangular_solve(b, a, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            self.assertTrue(actual_A_clone.numel() == 0)
            if a._nnz() == 0:
                self.assertTrue(actual_X.isnan().all())
                return

            # TODO: replace with torch method when implemented to_dense() on block sparse tensor
            a_bsr = sp.bsr_matrix(
                (
                    a.values().cpu().numpy(),
                    a.col_indices().cpu().numpy(),
                    a.crow_indices().cpu().numpy(),
                ),
                shape=a.shape,
            )
            expected_X, _ = torch.triangular_solve(
                b,
                torch.tensor(a_bsr.todense(), device=device),
                transpose=transpose,
                upper=upper,
                unitriangular=unitriangular)

            if expected_X.isnan().any():
                # TODO: zeros on the diagonal are not handled for CPU path
                # there's no way to query this info from MKL
                if self.device_type == 'cuda' and not TEST_WITH_ROCM:
                    self.assertTrue(actual_X.isnan().any() or actual_X.isinf().any())
                return

            self.assertEqual(actual_X, expected_X)

            out = torch.empty_like(b.mH if op_out and a.shape == b.shape else b)
            torch.triangular_solve(
                b, a,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, actual_X)
            self.assertEqual(out, expected_X)

        for (m, k) in itertools.product([2, 3], [1, 3]):
            nnz = random.randint(0, m * m)
            if not noncontiguous:
                a = self.genSparseCSRTensor((m * block_size, m * block_size), nnz,
                                            dtype=dtype, device=device, index_dtype=index_dtype)
                a = a.to_sparse_bsr((block_size, block_size))
            else:
                # Take the block sparsity pattern from an (m, m) CSR tensor and
                # attach transposed (column-major) blocks as the values.
                a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device).mT
                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
                                            a_data, (m * block_size, m * block_size), check_invariants=False)
            b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)

            for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4):
                # Bind the flags by name: run_test takes ``transpose`` before
                # ``unitriangular``, the reverse of the unpacking order above.
                run_test(a, b, upper=upper, transpose=transpose, unitriangular=unitriangular, op_out=op_out)
1834

1835
    @skipCPUIfNoMklSparse
    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
    @dtypes(torch.double)
    def test_mm(self, device, dtype):
        """Check mm/addmm/matmul with CSR and CSC operands against dense references.

        Covered cases:
        * sparse @ dense and dense @ sparse, each through ``addmm`` and ``mm``;
        * sparse @ sparse for every CSR/CSC pairing;
        * products involving empty dimensions;
        * a product whose sparse operands have disjoint supports.
        """
        def check_addmm(bias, lhs, rhs, alpha, beta):
            # TODO: addmm doesn't support strided result for sparse inputs.
            # res = beta * t  + alpha * (x @ y)
            # First with explicit beta/alpha, then with the defaults.
            for scale_kwargs in (dict(beta=beta, alpha=alpha), {}):
                actual = torch.addmm(bias, lhs, rhs, **scale_kwargs)
                reference = torch.addmm(bias, lhs.to_dense(), rhs.to_dense(), **scale_kwargs)
                self.assertEqual(actual, reference)

        def check_mm(lhs, rhs):
            actual = torch.mm(lhs, rhs)
            reference = torch.mm(lhs.to_dense(), rhs.to_dense())
            # Expected layout: strided if either operand is strided, CSR otherwise.
            any_strided = torch.strided in (lhs.layout, rhs.layout)
            self.assertEqual(actual.layout, torch.strided if any_strided else torch.sparse_csr)
            self.assertEqual(actual.to_dense(), reference)

        def check_both(bias, lhs, rhs, alpha, beta):
            check_addmm(bias, lhs, rhs, alpha, beta)
            check_mm(lhs, rhs)

        def check_shape(di, dj, dk, nnz0=None, nnz1=None):
            compressed_generators = (self.genSparseCSRTensor, self.genSparseCSCTensor)

            def randn(*size):
                return torch.randn(*size, dtype=dtype, device=device)

            for index_dtype in (torch.int32, torch.int64):
                alpha = random.random()
                beta = random.random()

                if nnz0 is None:
                    nnz0 = random.randint(di * dk // 2, di * dk)
                # Compressed (CSR, then CSC) left operand, strided right operand.
                for generate in compressed_generators:
                    bias = randn(di, dj)
                    lhs = generate((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
                    rhs = randn(dk, dj)
                    check_both(bias, lhs, rhs, alpha, beta)

                if nnz1 is None:
                    nnz1 = random.randint(dk * dj // 2, dk * dj)
                # Strided left operand, compressed (CSR, then CSC) right operand.
                for generate in compressed_generators:
                    bias = randn(di, dj)
                    lhs = randn(di, dk)
                    rhs = generate((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
                    check_both(bias, lhs, rhs, alpha, beta)

                # Test mm({CSR, CSC}, {CSR, CSC})
                for generate_lhs, generate_rhs in itertools.product(compressed_generators, repeat=2):
                    lhs = generate_lhs((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
                    rhs = generate_rhs((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
                    check_mm(lhs, rhs)

        def check_empty_inputs(lhs_layout, rhs_layout):
            xd = torch.rand(10, 0, device=device, dtype=dtype)
            yd = xd.transpose(-2, -1)
            zd = torch.rand(0, 0, device=device, dtype=dtype)
            dense_operands = (xd, yd, zd)
            lhs_sparse = [t.to_sparse(layout=lhs_layout) for t in dense_operands]
            rhs_sparse = [t.to_sparse(layout=rhs_layout) for t in dense_operands]
            # Index pairs into (x, y, z): 10x0 @ 0x10, 10x0 @ 0x0, 0x0 @ 0x10, 0x0 @ 0x0.
            for i, j in ((0, 1), (0, 2), (2, 1), (2, 2)):
                res_sparse = lhs_sparse[i] @ rhs_sparse[j]
                res_dense = dense_operands[i] @ dense_operands[j]
                self.assertEqual(res_sparse.to_dense(), res_dense)

        def check_orthogonal_inputs(lhs_layout, rhs_layout):
            ones = torch.ones(2, 2, device=device, dtype=dtype)
            zeros = torch.zeros(2, 2, device=device, dtype=dtype)
            # lhs = [ones | zeros], rhs = [zeros ; ones], so lhs @ rhs is all zeros.
            lhs = torch.cat((ones, zeros), -1).to_sparse(layout=lhs_layout)
            rhs = torch.cat((zeros, ones), -2).to_sparse(layout=rhs_layout)
            product = lhs @ rhs
            self.assertEqual(product, torch.zeros(*product.shape, device=device, dtype=dtype, layout=product.layout))

        for lhs_layout, rhs_layout in itertools.product([torch.sparse_csr, torch.sparse_csc], repeat=2):
            check_empty_inputs(lhs_layout, rhs_layout)
            check_orthogonal_inputs(lhs_layout, rhs_layout)

        for i, j, k in itertools.product([2, 4], [2, 4, 7], [2, 3, 7]):
            check_shape(i, j, k)
        check_shape(4, 4, 4, 0, 0)
1933

1934
    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_mm(self, device, dtype):
        """Compare ``torch.sparse.mm(CSR, dense)`` against dense ``torch.mm``."""
        def check(d1, d2, d3, nnz, transposed, index_dtype):
            # A transposed dense operand exercises a non-row-major strided layout.
            if transposed:
                dense = torch.randn(d3, d2, dtype=dtype, device=device).t_()
            else:
                dense = torch.randn(d2, d3, dtype=dtype, device=device)
            sparse = self.genSparseCSRTensor((d1, d2), nnz, device=device, dtype=dtype, index_dtype=index_dtype)
            reference = torch.mm(sparse.to_dense(), dense)
            self.assertEqual(torch.sparse.mm(sparse, dense), reference)

        for index_dtype, transposed in itertools.product([torch.int32, torch.int64], [False, True]):
            check(7, 8, 9, 20, transposed, index_dtype)
1953

1954
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_and_complex_types_and(
                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_sparse_addmm(self, device, dtype):
        """Compare ``torch.sparse.addmm(dense, CSR, dense)`` against dense ``torch.addmm``.

        The input tensor is either a full (n, p) tensor or a 0-dim tensor that
        must broadcast. ``alpha``/``beta`` are either random or fixed pairs.
        """
        def check(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
            alpha, beta = (random.random(), random.random()) if alpha_beta is None else alpha_beta
            bias = make_tensor(() if broadcast else [n, p], dtype=dtype, device=device)
            rhs = make_tensor([m, p], dtype=dtype, device=device)
            sparse = self.genSparseCSRTensor([n, m], nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            dense = sparse.to_dense()
            actual = torch.sparse.addmm(bias, sparse, rhs, beta=beta, alpha=alpha)
            expected = torch.addmm(bias, dense, rhs, beta=beta, alpha=alpha)
            self.assertEqual(actual, expected)

        for index_dtype in [torch.int32, torch.int64]:
            for alpha_beta in (None, (1, 0), (1, 1)):
                for broadcast in (False, True):
                    check(7, 8, 9, 20, broadcast, index_dtype, alpha_beta)
1984

1985
    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                      *[torch.bfloat16] if SM80OrLater else [],
                                      *[torch.half] if SM53OrLater else [],
                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
    @sparse_compressed_nonblock_layouts()
    @skipCUDAIf(
        not _check_cusparse_spgemm_available(),
        "cuSparse Generic API SpGEMM is not available"
    )
    def test_addmm_all_sparse_csr(self, device, dtype, layout):
        """Run the shared addmm checker in ``all_sparse`` mode over several input variants.

        Variants: plain random inputs, 0-strided (expanded) inputs,
        ``beta=0`` with a NaN input, and every transposed-storage combination.
        """
        def rand(*shape):
            # Draw in the default dtype, then cast to ``dtype``.
            return torch.randn(*shape, device=device).to(dtype)

        def maybe_transpose(cond, m):
            # Same values; column-major storage when ``cond`` is set.
            return m.t().clone(memory_format=torch.contiguous_format).t() if cond else m

        def check(M, m1, m2, **kwargs):
            _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse", **kwargs)

        check(rand(10, 25), rand(10, 50), rand(50, 25))

        # Test 0-strided
        check(rand(10, 1).expand(10, 25), rand(10, 1).expand(10, 50), rand(50, 25))

        # Test beta=0, M=nan
        nan_input = torch.full((10, 25), float('nan'), device=device).to(dtype)
        check(nan_input, rand(10, 50), rand(50, 25), beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            check(maybe_transpose(t1, rand(10, 25)),
                  maybe_transpose(t2, rand(10, 50)),
                  maybe_transpose(t3, rand(50, 25)),
                  transpose_out=t4)
2027

2028
    @onlyCPU
    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @sparse_compressed_nonblock_layouts()
    def test_addmm_dense_result(self, device, dtype, layout):
        """Run the shared addmm checker in ``dense_result`` mode over several input variants.

        Variants: plain random inputs, 0-strided (expanded) inputs,
        ``beta=0`` with a NaN input, and every transposed-storage combination.
        """
        def rand(*shape):
            # Draw in the default dtype, then cast to ``dtype``.
            return torch.randn(*shape, device=device).to(dtype)

        def maybe_transpose(cond, m):
            # Same values; column-major storage when ``cond`` is set.
            return m.t().clone(memory_format=torch.contiguous_format).t() if cond else m

        def check(M, m1, m2, **kwargs):
            _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result", **kwargs)

        check(rand(10, 25), rand(10, 50), rand(50, 25))

        # Test 0-strided
        check(rand(10, 1).expand(10, 25), rand(10, 1).expand(10, 50), rand(50, 25))

        # Test beta=0, M=nan
        nan_input = torch.full((10, 25), float('nan'), device=device).to(dtype)
        check(nan_input, rand(10, 50), rand(50, 25), beta=0)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
            check(maybe_transpose(t1, rand(10, 25)),
                  maybe_transpose(t2, rand(10, 50)),
                  maybe_transpose(t3, rand(50, 25)),
                  transpose_out=t4)
2061

2062
    @parametrize("k", [0, 1, 8])
    @parametrize("n", [0, 1, 10])
    @parametrize("m", [0, 1, 25])
    @skipCPUIfNoMklSparse
    @dtypes(*floating_and_complex_types())
    @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                      *[torch.bfloat16] if SM80OrLater else [],
                                      *[torch.half] if SM53OrLater else [],
                                      *[torch.complex128]
                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
                                      else []))
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
        """Check all-CSR addmm for various (possibly empty) sizes, plus the shape-mismatch error."""
        # Only the cases with at least one zero-sized dimension run on ROCm.
        if TEST_WITH_ROCM and 0 not in (k, n, m):
            self.skipTest("Skipped on ROCm")

        def rand(*shape):
            return torch.randn(*shape, device=device).to(dtype)

        _test_addmm_addmv(self, torch.addmm, rand(n, m), rand(n, k), rand(k, m),
                          layout=torch.sparse_csr, mode="all_sparse")

        # The inner dimensions disagree (k + 1 vs k). Both addmm and mm must
        # report the two offending shapes.
        M, m1, m2 = (rand(*shape).to_sparse_csr() for shape in ((n, m), (n, k + 1), (k, m)))
        shape_pattern = f"{n}x{k + 1}.*{k}x{m}"
        self.assertRaisesRegex(RuntimeError, shape_pattern, lambda: torch.addmm(M, m1, m2))
        self.assertRaisesRegex(RuntimeError, shape_pattern, lambda: torch.mm(m1, m2))
2088

2089
    @skipCPUIfNoMklSparse
    @dtypes(torch.float)
    def test_addmm_errors(self, device, dtype):
        """Check that CSR addmm raises the same error message as dense addmm.

        Each case is first run densely. If that raises, the CSR variant must
        raise a RuntimeError with the identical message.
        """
        import re

        def as_mat1(a, is_sparse):
            # Only the ``mat1`` argument is switched to CSR layout.
            return a.to_sparse_csr() if is_sparse else a

        def incompatible_shapes(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            return torch.addmm(a, as_mat1(a, is_sparse), a)

        def mat2_not_matrix(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            return torch.addmm(a, as_mat1(a, is_sparse), a.unsqueeze(0))

        def input_too_many_dims(*, is_sparse):
            # the first input needs to be 1D or 2D
            a = make_tensor((3, 3), dtype=dtype, device=device)
            return torch.addmm(a.unsqueeze(0), as_mat1(a, is_sparse), a)

        for case in (incompatible_shapes, mat2_not_matrix, input_too_many_dims):
            try:
                case(is_sparse=False)
            except RuntimeError as dense_error:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(dense_error))):
                    case(is_sparse=True)
2128

2129
    @skipCPUIfNoMklSparse
    @dtypes(torch.float)
    def test_mm_errors(self, device, dtype):
        """Check that CSR mm raises the same error message as dense mm.

        Each case is first run densely. If that raises, the CSR variant must
        raise a RuntimeError with the identical message.
        """
        import re

        def as_lhs(a, is_sparse):
            # Only the left operand is switched to CSR layout.
            return a.to_sparse_csr() if is_sparse else a

        def incompatible_shapes(*, is_sparse):
            # shapes must be compatible for matrix multiplication
            a = make_tensor((2, 3), dtype=dtype, device=device)
            return torch.mm(as_lhs(a, is_sparse), a)

        def mat2_not_matrix(*, is_sparse):
            # mat2 must be a matrix
            a = make_tensor((2, 3), dtype=dtype, device=device)
            return torch.mm(as_lhs(a, is_sparse), a.unsqueeze(0))

        for case in (incompatible_shapes, mat2_not_matrix):
            try:
                case(is_sparse=False)
            except RuntimeError as dense_error:
                with self.assertRaisesRegex(RuntimeError, re.escape(str(dense_error))):
                    case(is_sparse=True)
2159

2160
    @sparse_compressed_nonblock_layouts()
    @dtypes(torch.float, torch.double)
    def test_add(self, device, layout, dtype):
        """Check ``torch.add`` between a compressed sparse tensor and a strided tensor.

        Both argument orders are checked with a random ``alpha``, against a
        contiguous and a non-contiguous strided operand.
        """
        def check_shape(nnz, shape):
            # sparse.to_dense() uses torch.add internally so if torch.add is wrong,
            # the dense tensor will be wrong but this test would still pass
            # there's a separate test that checks for the correctness of the .to_dense() call
            x = self.genSparseCompressedTensor(shape, nnz,
                                               dtype=dtype,
                                               device=device,
                                               index_dtype=torch.int32,
                                               layout=layout,
                                               blocksize=())

            def check_against(y):
                alpha = random.random()
                expected = y + alpha * x.to_dense()
                self.assertEqual(torch.add(y, x, alpha=alpha), expected)
                self.assertEqual(torch.add(x, y, alpha=alpha), expected)

            check_against(torch.randn(*shape, dtype=dtype, device=device))

            # Non contiguous dense tensor: allocate with the first and last
            # dimensions swapped, then transpose them back in place.
            # NOTE(review): this operand is always float64, independent of ``dtype``.
            swapped = list(shape)
            swapped[0], swapped[-1] = shape[-1], shape[0]
            y = torch.randn(*swapped, dtype=torch.double, device=device)
            check_against(y.transpose_(0, len(swapped) - 1))

        ns = [2, 5]
        for b, m, n in itertools.product([(), (2,), (2, 3)], ns, ns):
            shape = (*b, m, n)
            for nnz in (0, m * n // 2, m * n):
                check_shape(nnz, shape)
2204

2205
    @dtypes(torch.float, torch.double)
    def test_mul(self, device, dtype):
        """Check ``torch.mul`` with CSR operands: forward values and gradients.

        CSR * CSR must produce CSR gradients matching the dense ones. Mixing a
        strided leaf with a CSR leaf is expected to fail in backward.
        """
        # TODO: This whole test should be migrated to OpInfos
        def gen_csr(shape, nnz):
            return self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)

        def check(fn, nnz, shape):
            # Forward comparison
            x = gen_csr(shape, nnz)
            y = gen_csr(shape, nnz)
            expected = fn(y.to_dense(), x.to_dense())
            # TODO: While result of mul(dense, csr) is csr, it is not fully compressed.
            # That means it may contain materialized zeros, since the dense argument
            # is converted according to the sparsity pattern of csr. In the future
            # we might require the result to be fully compressed.
            for actual in (fn(y, x), fn(y.to_dense(), x), fn(y, x.to_dense())):
                self.assertEqual(actual, expected)

            # Grad comparison
            x, y, z = (gen_csr(shape, nnz) for _ in range(3))

            # csr * csr -> csr with csr, csr gradients
            x_leaf = x.clone().requires_grad_()
            y_leaf = y.clone().requires_grad_()
            fn(y_leaf, x_leaf).backward(z)
            x_dense_leaf = x.to_dense().requires_grad_()
            y_dense_leaf = y.to_dense().requires_grad_()
            fn(y_dense_leaf, x_dense_leaf).backward(z.to_dense())
            for sparse_leaf, dense_leaf in ((x_leaf, x_dense_leaf), (y_leaf, y_dense_leaf)):
                self.assertEqual(sparse_leaf.grad.layout, torch.sparse_csr)
                self.assertEqual(sparse_leaf.grad.to_dense(), dense_leaf.grad)

            # TODO: Currently strided Tensors cannot have csr gradients
            mixed_cases = (
                # dense * csr -> csr with csr, dense gradients
                (x.clone().requires_grad_(), y.to_dense().clone().requires_grad_(),
                 "Function MulBackward0 returned an invalid gradient at index 0 - expected layout Strided but got SparseCsr"),
                # csr * dense -> csr with dense, csr gradients
                (x.to_dense().clone().requires_grad_(), y.clone().requires_grad_(),
                 "Function MulBackward0 returned an invalid gradient at index 1 - expected layout Strided but got SparseCsr"),
            )
            for x_mixed, y_mixed, err_msg in mixed_cases:
                with self.assertRaisesRegex(RuntimeError, err_msg):
                    fn(y_mixed, x_mixed).backward(z)

        for nnz, shape in ((100, [100, 100]), (0, [100, 100]), (100, [100, 1]), (100, [1, 100])):
            check(torch.mul, nnz, shape)
2266

2267
    # TODO: enable hybrid once to_dense supports it
    @parametrize('enable_hybrid', [False])
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_mul_scalar(self, layout, device, dtype, enable_hybrid):
        """Check multiplying compressed sparse tensors by a 0-dim tensor and by a Python scalar.

        Out-of-place results are compared with the reversed-operand product and
        with the dense result. The in-place ``mul_`` is checked only when the
        promoted dtype equals the input dtype.
        """
        scalar_dtypes = all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half)
        inputs = self.generate_simple_inputs(
            layout, device=device, dtype=dtype, index_dtype=torch.int32, enable_hybrid=enable_hybrid)
        for sparse in inputs:
            for scalar_dtype in scalar_dtypes:
                # ComplexHalf is experimental
                if dtype is torch.half and scalar_dtype.is_complex:
                    continue

                scalar_t = torch.tensor(2, dtype=scalar_dtype)
                for scalar in (scalar_t, scalar_t.item()):
                    res_out = sparse.mul(scalar)
                    self.assertEqual(res_out, scalar * sparse)

                    # BUG: dispatcher ignores mul.Scalar(Tensor, Scalar)
                    # This issue is circumvented in the mul(Tensor, Tensor) kernel.
                    self.assertEqual(res_out, sparse.to_dense().mul(scalar))

                    if dtype != torch.result_type(sparse, scalar):
                        continue
                    res_in_dense = sparse.to_dense().mul_(scalar)
                    res_in = sparse.clone().mul_(scalar)
                    self.assertEqual(res_in, res_in_dense)
                    self.assertEqual(res_out, res_in)
2294

2295
    @skipCPUIfNoMklSparse
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sparse_add(self, device, dtype):
        """Check ``torch.add`` on CSR operands, with and without ``out=``, against dense.

        Every ordered pair of three random CSR tensors is added. The ``out``
        tensor, when given, is itself one of the three. Result and ``out``
        index dtypes must equal ``index_dtype``.
        """
        def run_test(m, n, index_dtype):
            def assert_index_dtype(t):
                self.assertEqual(t.crow_indices().dtype, index_dtype)
                self.assertEqual(t.col_indices().dtype, index_dtype)

            alpha = random.random()
            nnzs = [random.randint(0, m * n) for _ in range(3)]
            if TEST_WITH_ROCM:
                # ROCm fails when nnz = 0
                nnzs = [max(1, nnz) for nnz in nnzs]

            sparse_args = [self.genSparseCSRTensor([m, n], nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                           for nnz in nnzs]
            dense_args = [t.to_dense() for t in sparse_args]
            positions = list(range(len(sparse_args)))

            for idx1, idx2, idx3 in itertools.product(positions, positions, positions + [None]):
                # ``out`` is one of the argument tensors, so it is overwritten
                # and later iterations reuse the updated tensors.
                if idx3 is None:
                    s3 = d3 = None
                else:
                    s3, d3 = sparse_args[idx3], dense_args[idx3]
                expected = torch.add(dense_args[idx1], dense_args[idx2], alpha=alpha, out=d3)
                actual = torch.add(sparse_args[idx1], sparse_args[idx2], alpha=alpha, out=s3)
                assert_index_dtype(actual)
                self.assertEqual(actual, expected)
                self.assertEqual(s3, d3)
                if s3 is not None:
                    assert_index_dtype(s3)

        for index_dtype, m, n in itertools.product([torch.int32, torch.int64], [3, 5], [3, 5]):
            run_test(m, n, index_dtype)
2338

2339
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sparse_add_errors(self, device, dtype):
        """Check that adding CSR tensors of mismatched shapes raises a shape error.

        Runs once for each supported index dtype.
        """
        def run_test(index_dtype):
            # Use this function's own argument instead of implicitly capturing
            # the enclosing loop variable of the same name.
            a = self.genSparseCSRTensor((2, 2), 3, dtype=dtype, device=device, index_dtype=index_dtype)
            b = self.genSparseCSRTensor((2, 1), 2, dtype=dtype, device=device, index_dtype=index_dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected input tensors to have the same shape"):
                torch.add(a, b)

        for index_dtype in [torch.int32, torch.int64]:
            run_test(index_dtype)
2349

2350
    @skipCPUIfNoMklSparse
    @skipCUDAIf(
        not _check_cusparse_triangular_solve_available(),
        "cuSparse Generic API SpSV is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sparse_triangular_solve(self, device, dtype):
        """Compare torch.triangular_solve with a CSR coefficient against the dense result.

        Every combination of ``upper``/``unitriangular``/``transpose`` is tested, plus an
        all-zero coefficient, empty shapes, and ``out=`` tensors with C-contiguous,
        F-contiguous and non-contiguous strides.
        """

        def run_test(n, k, upper, unitriangular, transpose, zero):
            if not unitriangular:
                triangle_function = torch.triu if upper else torch.tril
            else:
                # Make sure diagonal elements are not materialized.
                # This is to exercise `unitriangular=True` not relying on
                # explicit presence of these indices.
                # triu(1) / tril(-1) keep only the strict upper / lower triangle;
                # triu(-1) would keep the main diagonal and the first sub-diagonal.
                if upper:
                    def remove_diagonal(t):
                        return t.triu(1)

                else:
                    def remove_diagonal(t):
                        return t.tril(-1)

                triangle_function = remove_diagonal

            make_A = torch.zeros if zero else make_tensor
            A = make_A((n, n), dtype=dtype, device=device)
            A = triangle_function(A)
            A_sparse = A.to_sparse_csr()
            B = make_tensor((n, k), dtype=dtype, device=device)

            expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose)
            expected_X = expected.solution

            actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose)
            actual_X = actual.solution
            actual_A_clone = actual.cloned_coefficient
            # The sparse path returns an empty cloned_coefficient.
            self.assertTrue(actual_A_clone.numel() == 0)
            # A coefficient with no stored elements yields an all-NaN solution.
            if A_sparse._nnz() == 0:
                self.assertTrue(actual_X.isnan().all())
                return
            self.assertEqual(actual_X, expected_X)

            # test out with C contiguous strides
            out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)

            # test out with F contiguous strides
            out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            self.assertEqual(out.stride(), (1, n))

            # test out with discontiguous strides
            out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2]
            if n > 0 and k > 0:
                self.assertFalse(out.is_contiguous())
                self.assertFalse(out.t().is_contiguous())
            before_stride = out.stride()
            torch.triangular_solve(
                B, A_sparse,
                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
            )
            self.assertEqual(out, expected_X)
            # out= must be written in place, keeping its original strides.
            self.assertEqual(out.stride(), before_stride)

        ks = [0, 1, 3]
        ns = [5, 3, 0]
        for (k, n), (upper, unitriangular, transpose, zero) in itertools.product(itertools.product(ks, ns),
                                                                                 itertools.product([True, False], repeat=4)):
            run_test(n, k, upper, unitriangular, transpose, zero)
2430

2431
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sampled_addmm(self, device, dtype):
        """Compare torch.sparse.sampled_addmm (functional and out=) with a dense reference.

        The reference is ``alpha * (a @ b) * spy(c) + beta * c``. Here ``spy(c)`` is the
        dense form of a CSR tensor with c's indices and all-ones values.
        """
        def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
            # Random scalars when not supplied; complex dtypes get complex scalars.
            if dtype.is_complex:
                alpha = random.random() + 0.3j if alpha is None else alpha
                beta = random.random() + 0.6j if beta is None else beta
            else:
                alpha = random.random() if alpha is None else alpha
                beta = random.random() if beta is None else beta

            # mH is only applied when a and b share a shape, so the product stays well-formed.
            if op_a and a.shape == b.shape:
                a = a.mH
            if op_b and a.shape == b.shape:
                b = b.mH

            actual = torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta)

            # out= tensor with actual's sparsity pattern and uninitialized values.
            out = torch.sparse_csr_tensor(
                *map(torch.clone, (actual.crow_indices(), actual.col_indices())),
                torch.empty_like(actual.values()),
                size=actual.shape
            )
            torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta, out=out)

            spy_c = torch.sparse_csr_tensor(c.crow_indices(), c.col_indices(), torch.ones_like(c.values()), size=c.shape)
            expected = alpha * (a @ b) * spy_c.to_dense() + beta * c.to_dense()
            self.assertEqual(actual.to_dense(), out.to_dense())
            self.assertEqual(actual.to_dense(), expected)

        mnk = list(itertools.product([2, 5], repeat=3))

        # Add a test case for size 0 a and b tensors
        mnk = mnk + [(5, 5, 0)]

        batch_shapes = [(), (2,), (2, 3)]
        tf = [True, False]
        for index_dtype in [torch.int32, torch.int64]:
            # ``batch_shape`` keeps the batch dimensions distinct from the mat2 tensor ``b`` below.
            for (m, n, k), batch_shape, noncontiguous, bcast_c in itertools.product(mnk, batch_shapes, tf, tf):
                # Broadcasting c requires a non-empty batch to broadcast over.
                if bcast_c and len(batch_shape) == 0:
                    continue
                nnz = random.randint(0, m * n)
                c_batch = () if bcast_c else batch_shape
                c = self.genSparseCSRTensor((*c_batch, m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                a = make_tensor((*batch_shape, m, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
                b = make_tensor((*batch_shape, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
                for op_a, op_b in itertools.product([True, False], repeat=2):
                    run_test(c, a, b, op_a, op_b)
2484

2485
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sampled_addmm_autograd(self, device, dtype):
        """Compare sparse.sampled_addmm gradients with an equivalent dense computation.

        The samples are zipped with [True, False], so at most two samples run. The first
        back-propagates a dense covector and the second a sparse one.
        """
        from torch.testing._internal.common_methods_invocations import sample_inputs_sparse_sampled_addmm

        samples = list(sample_inputs_sparse_sampled_addmm(None, device, dtype, requires_grad=True))

        for sample, use_dense_covector in zip(samples, [True, False]):
            c, a, b = sample.input, sample.args[0], sample.args[1]

            # Sparse forward and backward.
            output = torch.sparse.sampled_addmm(c, a, b, **sample.kwargs)
            covector = torch.randn_like(output)
            if use_dense_covector:
                covector = covector.to_dense()
            output.backward(covector)

            # Dense reference built from detached, densified copies of the operands.
            c1, a1, b1 = [t.detach().to_dense().requires_grad_(True) for t in (c, a, b)]
            alpha, beta = sample.kwargs['alpha'], sample.kwargs['beta']
            dense_output = alpha * (a1 @ b1) * torch.ones_like(c).to_dense() + beta * c1
            self.assertEqual(output, dense_output)
            dense_output.backward(covector.to_dense())

            for sparse_leaf, dense_leaf in ((c, c1), (a, a1), (b, b1)):
                self.assertEqual(sparse_leaf.grad, dense_leaf.grad)
2514

2515
    @onlyCUDA
    # Works on ROCm; the CUDA failure tracked in the linked issue is still open.
    @skipCUDAIf(not TEST_WITH_ROCM, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177")
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_sampled_addmm_zero_sized(self, device, dtype):
        """sampled_addmm accepts zero-sized operands and returns a result shaped like ``c``."""
        for m, n, k in itertools.product([0, 5], repeat=3):
            c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
            mat1 = make_tensor((m, k), dtype=dtype, device=device)
            mat2 = make_tensor((k, n), dtype=dtype, device=device)
            result = torch.sparse.sampled_addmm(c, mat1, mat2)
            self.assertEqual(result.shape, c.shape)
2535

2536
    @onlyCUDA
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
    def test_sampled_addmm_errors(self, device, dtype):
        """Check the RuntimeError messages sampled_addmm raises for invalid operands."""
        # Operands are created in a fixed order so the random draws stay the same.
        rect = make_tensor((2, 3), dtype=dtype, device=device)
        rect_csr = rect.to_sparse_csr()
        square = make_tensor((2, 2), dtype=dtype, device=device)
        big_csr = make_tensor((3, 3), dtype=dtype, device=device).to_sparse_csr()
        wide_csr = make_tensor((2, 3), dtype=dtype, device=device).to_sparse_csr()
        mat = make_tensor((2, 2), dtype=dtype, device=device)
        mat_csr = mat.to_sparse_csr()

        cases = [
            # (2, 3) @ (2, 3) is not a valid matrix product
            (r"cannot be multiplied", (rect_csr, rect, rect)),
            # 1-D operands are rejected
            (r"Expected mat1 to be a matrix", (rect_csr, rect[..., 0, :], rect)),
            (r"Expected mat2 to be a matrix", (rect_csr, rect, rect[..., 0, :])),
            # self must have shape (mat1.shape[-2], mat2.shape[-1])
            (r"self.shape\[-2\] must match mat1.shape\[-2\]", (big_csr, square, square)),
            (r"self.shape\[-1\] must match mat2.shape\[-1\]", (wide_csr, square, square)),
            # mat1 and mat2 must be strided (dense)
            (r"Expected mat1 to have strided layout", (mat_csr, mat_csr, mat_csr)),
            (r"Expected mat2 to have strided layout", (mat_csr, mat, mat_csr)),
        ]
        for pattern, args in cases:
            with self.assertRaisesRegex(RuntimeError, pattern):
                torch.sparse.sampled_addmm(*args)
2578

2579
    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    @precisionOverride({torch.bfloat16: 0.01})
    def test_sparse_mm_reduce_sum(self, device, dtype):
        """torch.sparse.mm(csr, mat, 'sum') must match torch.mm on the dense input, incl. grads."""
        def run_test(m, n, k, nnz, train):
            csr = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)
            csr_dense = csr.to_dense()
            rhs = torch.randn(k, n, dtype=dtype)
            rhs_ref = rhs.clone()

            if train:
                for leaf in (csr, rhs, csr_dense, rhs_ref):
                    leaf.requires_grad_()

            expected = torch.mm(csr_dense, rhs_ref)
            result = torch.sparse.mm(csr, rhs, 'sum')
            self.assertEqual(result, expected)

            if not train:
                return
            expected.sum().backward()
            result.sum().backward()
            self.assertEqual(csr.grad.to_dense(), csr_dense.grad)
            self.assertEqual(rhs.grad, rhs_ref.grad)

        run_test(4, 5, 4, 10, False)
        run_test(4, 4, 4, 16, True)
2615

2616
    @skipIfTorchDynamo()
    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
    def test_sparse_mm_reduce(self, device, dtype):
        """Compare torch.sparse.mm reductions with a scatter_reduce reference.

        Covers 'sum', 'mean', 'amax' and 'amin' and both index dtypes. Backward is
        compared only for float32/float64.
        """
        def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
            csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            mat = torch.randn(n, k, dtype=dtype)
            ref_mat = mat.clone()
            ref_values = csr.values().clone()

            coo_indices = torch._convert_indices_from_csr_to_coo(
                csr.crow_indices(),
                csr.col_indices(),
                out_int32=index_dtype == torch.int32)
            row, col = coo_indices[0], coo_indices[1]

            def reference(row, col, val, mat):
                # Gather the rows of ``mat`` selected by ``col`` and scale them by the values.
                # Then reduce them into output row ``row``; untouched rows keep the zero fill.
                out = torch.zeros([m, k], dtype=dtype)
                weighted = mat.index_select(0, col).mul(val.view(-1, 1))
                # scatter_reduce expects an int64 index.
                index = row.view(-1, 1).expand_as(weighted).to(dtype=torch.int64)
                return out.scatter_reduce_(0, index, weighted, reduce=reduce_type, include_self=False)

            if train:
                csr.requires_grad_()
                mat.requires_grad_()
                ref_values.requires_grad_()
                ref_mat.requires_grad_()

            expected = reference(row, col, ref_values, ref_mat)
            actual = torch.sparse.mm(csr, mat, reduce_type)
            self.assertEqual(actual, expected)

            # Backward is not compared for low-precision dtypes.
            if not train or dtype in (torch.bfloat16, torch.float16):
                return
            expected.sum().backward()
            actual.sum().backward()
            self.assertEqual(csr.grad.values(), ref_values.grad)
            self.assertEqual(mat.grad, ref_mat.grad)

        # (m, n, k, nnz): nnz=1 < m=3 guarantees empty rows. The kernel blocks with 4x the
        # vector length, so k=33 covers K larger than that block.
        shapes = [(3, 4, 11, 1), (3, 4, 11, 6), (3, 4, 11, 12), (4, 7, 33, 13)]
        for train, index_dtype, reduce_type in itertools.product(
                [False, True], [torch.int32, torch.int64], ["sum", "mean", "amax", "amin"]):
            for m, n, k, nnz in shapes:
                run_test(m, n, k, nnz, reduce_type, index_dtype, train)
2675

2676
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_coo_csr_conversion(self, device, dtype):
        """Converting dense -> COO -> CSR must preserve the dense values."""
        extents = [5, 2, 0]
        for size in itertools.product(extents, extents):
            dense = make_tensor(size, dtype=dtype, device=device)
            csr_sparse = dense.to_sparse().to_sparse_csr()
            self.assertEqual(csr_sparse.to_dense(), dense)
2686

2687
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_csr_coo_conversion(self, device, dtype):
        """Converting dense -> CSR -> COO must preserve the dense values."""
        extents = [5, 2, 0]
        for size in itertools.product(extents, extents):
            dense = make_tensor(size, dtype=dtype, device=device)
            coo_sparse = dense.to_sparse_csr().to_sparse()
            self.assertEqual(coo_sparse.to_dense(), dense)
2697

2698
    # PyTorch has no rule for filling the implicit zeros in outputs of operations on
    #   Sparse CSR tensors, so only operators with a 0 -> 0 correspondence are supported:
    #   sin(0) = 0 and tan(0) = 0 qualify, but cos(0) = 1 does not.
    # Only unary operators are covered here.
    @ops(sparse_csr_unary_ufuncs)
    def test_zero_to_zero_correspondence_unary(self, device, dtype, op):
        zero = torch.zeros((1, 2), dtype=dtype, device=device)
        # CSR tensor with one row and a single stored (explicit) zero at column 1.
        tensor_explicit_zeros = torch.sparse_csr_tensor([0, 1], [1], [0], dtype=dtype, device=device)

        output_zero = op(zero)
        expected_zero = zero.to(output_zero.dtype)

        output_explicit_zeros = op(tensor_explicit_zeros).to_dense()
        expected_explicit_zeros = tensor_explicit_zeros.to_dense().to(output_explicit_zeros.dtype)

        message = (f"This operator ({op.name}) should not be supported for "
                   "Sparse CSR as it breaks 0->0 correspondence.")
        pairs = ((output_zero, expected_zero), (output_explicit_zeros, expected_explicit_zeros))
        for output, expected in pairs:
            self.assertEqual(output, expected, message)

        # The op must keep the number of stored elements unchanged.
        for inp in (zero.to_sparse_csr(), tensor_explicit_zeros):
            self.assertEqual(op(inp).values().numel(), inp.values().numel(),
                             f"{op.name} fails to preserve sparsity pattern.")
2724

2725
    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_out(self, device, dtype, op):
        """The out= form of a CSR unary ufunc must match its functional result."""
        samples = op.sample_inputs(device, dtype)

        if not op.supports_out:
            self.skipTest("Skipped! Out not supported")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                # f-string so the offending dimension actually appears in the message.
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            # The out tensor is generated from the input's size, nnz and index dtype,
            # using the dtype of the functional result.
            out = self.genSparseCSRTensor(sample.input.size(), sample.input._nnz(),
                                          device=sample.input.device, dtype=expect.dtype,
                                          index_dtype=sample.input.crow_indices().dtype)
            op(sample.input, *sample.args, **sample.kwargs, out=out)

            self.assertEqual(out, expect)
2748

2749
    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_inplace(self, device, dtype, op):
        """The in-place form of a CSR unary ufunc must return its input, updated to match
        the functional result, or raise when the result cannot be stored in place."""
        samples = op.sample_inputs(device, dtype)

        if op.inplace_variant is None:
            self.skipTest("Skipped! Inplace variant not supported!")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                # f-string so the offending dimension actually appears in the message.
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            # A result dtype that cannot be cast back to ``dtype`` cannot be written in place.
            if not torch.can_cast(expect.dtype, dtype):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            # In-place abs on complex CSR input is expected to raise.
            if sample.input.is_complex() and op.name == "abs":
                with self.assertRaisesRegex(RuntimeError, "not supported"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)

            self.assertIs(actual, sample.input)
            self.assertEqual(actual, expect)
2780

2781
    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @ops(sparse_csr_unary_ufuncs, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble])
    def test_autograd_sparse_csr_unary(self, device, dtype, op):
        """Backward of a CSR unary op must give a CSR gradient equal to the dense gradient."""
        if op.name not in UNARY_EWISE_CSR_ALLOW_AUTOGRAD:
            self.skipTest(f"Skipped! Unary op {op.name} not supported with CSR input and autograd")

        samples = list(op.sample_inputs(device, dtype))

        # Fail early to prevent silent success with this test
        if all(s.input.ndim != 2 for s in samples):
            raise ValueError("Expected at least one 2D tensor in samples.")

        # Inputs with fewer than 2 dims cannot be converted to a sparse compressed layout.
        for sample in (s for s in samples if s.input.ndim >= 2):
            sparse_input = sample.input.to_sparse_csr().requires_grad_(True)

            def fn(input):
                result = op.gradcheck_wrapper(op.get_op(), input, *sample.args, **sample.kwargs)
                process = sample.output_process_fn_grad
                return result if process is None else process(result)

            # Sparse forward and backward.
            sparse_output = fn(sparse_input)
            covector = torch.randn_like(sparse_output)
            sparse_output.backward(covector)
            self.assertTrue(torch.is_tensor(sparse_input.grad))
            self.assertTrue(sparse_input.grad.is_sparse_csr)

            # Dense reference on the same values.
            dense_input = sparse_input.detach().to_dense().requires_grad_(True)
            fn(dense_input).backward(covector.to_dense())
            self.assertEqual(sparse_input.grad, dense_input.grad)
2819

2820
    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float64)
    def test_autograd_dense_output_addmm(self, device, dtype):
        """Check autograd of torch.addmm and torch.sparse.addmm with a CSR ``mat1``.

        Gradients w.r.t. the dense operands are checked with gradcheck. Gradients w.r.t.
        the sparse operand are compared against the same computation on a dense copy.
        """
        import warnings
        from torch.testing._internal.common_methods_invocations import sample_inputs_addmm

        samples = list(sample_inputs_addmm(None, device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        if all(s.args[0].ndim != 2 for s in samples):
            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")

        for sample in samples:
            a = sample.args[0].relu().to_sparse_csr()
            if sample.args[0].shape == sample.args[1].shape:
                warnings.warn("Broken for square matrices, see https://github.com/pytorch/pytorch/issues/116565")
                continue

            # This path tests the autograd path wrt dense inputs
            for addmm in (torch.addmm, torch.sparse.addmm):

                # ``a`` is read from the enclosing scope when called, so later rebinding is visible.
                def dense_args_fn(c, b):
                    output = addmm(c, a, b, **sample.kwargs)
                    process = sample.output_process_fn_grad
                    return output if process is None else process(output)

                self.assertTrue(torch.autograd.gradcheck(dense_args_fn, [sample.input, sample.args[1]], fast_mode=True))

                # Same check with noncontiguous dense operands.
                c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
                b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
                self.assertTrue(torch.autograd.gradcheck(dense_args_fn, [c, b], fast_mode=True))

                # Autograd w.r.t. the sparse operand. gradcheck does not support sparse CSR
                # yet, so the result is compared against the dense path instead.
                for reverse in (True, False):
                    c, b = sample.input, sample.args[1]
                    if reverse and a.shape != b.shape:
                        continue

                    def sparse_arg_fn(a):
                        inputs = (c, b, a) if reverse else (c, a, b)
                        output = addmm(*inputs, **sample.kwargs)
                        process = sample.output_process_fn_grad
                        return output if process is None else process(output)

                    # Sparse forward and backward.
                    a = a.detach().requires_grad_(True)
                    output = sparse_arg_fn(a)
                    covector = torch.randn_like(output)
                    output.backward(covector)
                    self.assertTrue(torch.is_tensor(a.grad))
                    is_sparse_addmm = addmm == torch.sparse.addmm
                    if is_sparse_addmm:
                        self.assertTrue(a.grad.is_sparse_csr)
                    else:
                        self.assertTrue(a.grad.layout == torch.strided)

                    # Dense reference; torch.sparse.addmm grads are masked to a's pattern.
                    dense_a = a.detach().to_dense().requires_grad_(True)
                    dense_output = sparse_arg_fn(dense_a)
                    self.assertEqual(output, dense_output)
                    dense_output.backward(covector.to_dense())

                    expected_grad = dense_a.grad.sparse_mask(a) if is_sparse_addmm else dense_a.grad
                    self.assertEqual(a.grad, expected_grad)
2894

2895
    @skipCPUIfNoMklSparse
    @dtypes(torch.float64)
    def test_autograd_dense_output_addmv(self, device, dtype):
        """gradcheck torch.addmv w.r.t. its dense operands while ``mat`` is sparse CSR."""
        from torch.testing._internal.common_methods_invocations import sample_inputs_addmv

        samples = list(sample_inputs_addmv(None, device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        if all(s.args[0].ndim != 2 for s in samples):
            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")

        for sample in samples:
            # TODO: Remove detach once we have autograd support for CSR input
            mat_csr = sample.args[0].to_sparse_csr().detach()

            def fn(c, b):
                output = torch.addmv(c, mat_csr, b, **sample.kwargs)
                process = sample.output_process_fn_grad
                return output if process is None else process(output)

            self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True))

            # Repeat with noncontiguous dense operands of the same shapes.
            noncontig_c = make_tensor(sample.input.shape, device=device, dtype=dtype,
                                      noncontiguous=True, requires_grad=True)
            noncontig_b = make_tensor(sample.args[1].shape, device=device, dtype=dtype,
                                      noncontiguous=True, requires_grad=True)
            self.assertTrue(torch.autograd.gradcheck(fn, [noncontig_c, noncontig_b], fast_mode=True))
2923

2924
    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @ops(binary_ops_with_dense_output, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, ])
    def test_autograd_dense_output(self, device, dtype, op):
        """gradcheck a binary op w.r.t. its dense arguments while the first input is sparse CSR."""
        if self.device_type == 'cpu' and op.name == "mv" and no_mkl_sparse:
            self.skipTest("MKL Sparse is not available")

        samples = list(op.sample_inputs(device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        if all(s.input.ndim != 2 for s in samples):
            raise ValueError("Expected at least one 2D tensor in samples.")

        # Assumed signature: op(sparse_input, *dense_args) -> dense output.
        for sample in samples:
            # TODO: Remove detach once we have autograd support for CSR input
            sparse_input = sample.input.to_sparse_csr().detach()

            def fn(*args):
                output = op.gradcheck_wrapper(op.get_op(), sparse_input, *args, **sample.kwargs)
                process = sample.output_process_fn_grad
                return output if process is None else process(output)

            self.assertTrue(torch.autograd.gradcheck(fn, sample.args, fast_mode=True))

            # Repeat with noncontiguous dense arguments of the same shapes.
            noncontig_args = [
                make_tensor(arg.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
                for arg in sample.args
            ]
            self.assertTrue(torch.autograd.gradcheck(fn, noncontig_args, fast_mode=True))
2953

2954
    @dtypes(*all_types_and_complex())
    def test_direct_coo_csr_conversion(self, device, dtype):
        """A COO tensor converted to CSR and back to COO must equal the original."""
        extents = [5, 2, 0]
        for size in itertools.product(extents, extents):
            coo_sparse = make_tensor(size, dtype=dtype, device=device).to_sparse_coo()
            round_trip = coo_sparse.to_sparse_csr().to_sparse_coo()
            self.assertEqual(round_trip, coo_sparse)
2962

2963
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sum(self, device, dtype):
        """A full sum of a CSR tensor must equal the sum of its stored values.

        For floating dtypes, the gradient of ``sum()`` is also compared against
        ``torch.ones(shape)``.
        """
        def run_test(shape, nnz, index_dtype):
            # The parameter name must match the name used in the body; otherwise the body
            # silently late-binds to the enclosing loop variable instead of the argument.
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            self.assertEqual(a.sum(), a.values().sum())
            if dtype in floating_types():
                a.requires_grad_(True)
                a.sum().backward()
                self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device))

        for shape, index_dtype in itertools.product(
                [(10, 5), (10, 10)],
                [torch.int32, torch.int64]):
            # empty, partially filled, and fully populated
            run_test(shape, 0, index_dtype)
            run_test(shape, max(shape), index_dtype)
            run_test(shape, shape[0] * shape[1], index_dtype)
2979

2980
    @skipIfTorchDynamo()
2981
    @skipMeta
2982
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
2983
    @all_sparse_compressed_layouts()
2984
    def test_transpose(self, device, dtype, layout):
        """Exercise ``transpose`` on sparse compressed tensors.

        For non-hybrid inputs:

        * transposing a dimension with itself keeps the layout and contents;
        * transposing the two sparse dimensions flips the layout
          (CSR <-> CSC, BSR <-> BSC), transposing two batch dimensions keeps it;
        * every successful transpose (and its inverse) must be a view of the
          input whose indices pass ``torch._validate_sparse_compressed_tensor_args``;
        * transposing dimensions of different kinds (Batch/Sparse/Dense) raises.

        For hybrid inputs (with dense dimensions) every transpose attempted
        here is expected to raise.
        """

        def _check_transpose_view(subject, transpose):
            self.assertTrue(transpose.values()._is_view())
            self.assertTrue(transpose._is_view())
            self.assertTrue(transpose._base is subject)

        def _check_layout_invariants(transpose):
            self.assertEqual(transpose.device, torch.device(device))
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[transpose.layout]
            compressed_indices, plain_indices = compressed_indices_mth(transpose), plain_indices_mth(transpose)
            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, transpose.values(),
                                                          transpose.shape, transpose.layout)

        def check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout):
            transpose = subject.transpose(dim0, dim1)
            # correct layout
            self.assertEqual(transpose.layout, expected_layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense.transpose(dim0, dim1))

            round_trip = transpose.transpose(dim0, dim1)
            self.assertEqual(round_trip.layout, subject.layout)
            # transposing back must still return a view of the original subject
            _check_transpose_view(subject, round_trip)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(round_trip)
            self.assertEqual(round_trip.to_dense(), subject_dense)

        def check_same_dim_transpose(subject, subject_dense, dim):
            transpose = subject.transpose(dim, dim)
            # correct layout
            self.assertEqual(transpose.layout, subject.layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense)

        def check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1):
            mismatch_name = f"{dim0}\\({name0}\\) and {dim1}\\({name1}\\)"
            err = r"transpose\(\): can only transpose dimensions of the same type \(Batch, Sparse, Dense\), got " + mismatch_name

            with self.assertRaisesRegex(RuntimeError, err):
                subject.transpose(dim0, dim1)

        def run_test(shape, nnz, index_type, n_dense, blocksize=()):
            subject = self.genSparseCompressedTensor(shape,
                                                     nnz,
                                                     layout=layout,
                                                     device=device,
                                                     index_dtype=index_type,
                                                     blocksize=blocksize,
                                                     dense_dims=n_dense,
                                                     dtype=dtype)

            # Dimension indices grouped by kind; a kind that is absent (or has
            # fewer than two dimensions) contributes None.
            sparse0 = len(shape) - n_dense - 1
            sparse1 = sparse0 - 1

            dense0 = sparse0 + 1 if n_dense > 0 else None
            dense1 = dense0 + 1 if n_dense > 1 else None

            n_batch = len(shape) - n_dense - 2
            batch0 = sparse1 - 1 if n_batch > 0 else None
            batch1 = 0 if n_batch > 1 else None

            kind_names = ("Batch", "Sparse", "Dense")
            dims_by_kind = ((batch0, batch1), (sparse0, sparse1), (dense0, dense1))
            named0 = [(name, d[0]) for name, d in zip(kind_names, dims_by_kind)]
            named1 = [(name, d[1]) for name, d in zip(kind_names, dims_by_kind)]

            flipped_layout = {
                torch.sparse_csr: torch.sparse_csc,
                torch.sparse_csc: torch.sparse_csr,
                torch.sparse_bsr: torch.sparse_bsc,
                torch.sparse_bsc: torch.sparse_bsr
            }[layout]
            if n_dense > 0:
                # expect all transposes to throw
                hybrid_msg = r"transpose\(\): hybrid sparse compressed tensors with dense dimensions are not supported"
                for (_, dim0), (_, dim1) in itertools.product(named0, named1):
                    if (dim0 is not None) and (dim1 is not None):
                        with self.assertRaisesRegex(RuntimeError, hybrid_msg):
                            subject.transpose(dim0, dim1)
            else:
                subject_dense = subject.to_dense()
                # A same-dimension transpose only depends on dim0, so check each
                # existing dimension once instead of once per partner dimension.
                for _, dim0 in named0:
                    if dim0 is not None:
                        check_same_dim_transpose(subject, subject_dense, dim0)

                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
                    if dim0 is None or dim1 is None:
                        continue
                    if name0 == name1:
                        expected_layout = flipped_layout if name0 == "Sparse" else layout
                        check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout)
                    else:
                        check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1)

        # batch/sparse, sparse/dense only and full hybrid cases
        shape_ndense = list(itertools.product([(2, 4, 6, 2), (10, 6, 4, 2), (2, 4, 4, 2, 6)], [0, 1, 2]))
        # sparse only cases
        shape_ndense += [[(4, 8), 0], [(2, 2), 0], [(8, 4), 0]]
        for (shape, n_dense), index_dtype in itertools.product(shape_ndense, [torch.int32, torch.int64]):
            n_batch = len(shape) - n_dense - 2
            sparse_shape = shape[n_batch: n_batch + 2]
            if layout in (torch.sparse_bsr, torch.sparse_bsc):
                # for blocked layouts all combinations of 2 and 1 should be valid blocksizes
                run_test(shape, 0, index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense, blocksize=(2, 2))
                # repeat the realistic sparsity case with varied block sizes
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 1))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 1))
            else:
                run_test(shape, 0, index_dtype, n_dense)
                run_test(shape, max(sparse_shape), index_dtype, n_dense)
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense)
3107

3108
    # TODO: This is a stopgap for a rigorous extension of our autograd tests
3109
    # to test the functionality of detach
3110
    @skipMeta
3111
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3112
    def test_exercise_detach(self, device, dtype):
        """``detach()`` on a CSR tensor must yield a tensor equal to its source.

        A 3x3 matrix with 4 specified elements is generated for both the
        int32 and int64 index dtypes.
        """
        for index_dtype in (torch.int32, torch.int64):
            original = self.genSparseCSRTensor((3, 3), 4, dtype=dtype, device=device, index_dtype=index_dtype)
            self.assertEqual(original, original.detach())
3119

3120
    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
3121
        if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
3122
            tensor = tensor.to_dense()
3123
        else:
3124
            raise NotImplementedError(repr(tensor))
3125
        if layout is torch.sparse_csr:
3126
            return sp.csr_matrix(tensor.cpu().numpy())
3127
        if layout is torch.sparse_csc:
3128
            return sp.csc_matrix(tensor.cpu().numpy())
3129
        if layout is torch.sparse_bsr:
3130
            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
3131
        if layout is torch.sparse_bsc:
3132
            # SciPy doesn't have native BSC support - but our tests don't need the full
3133
            # functionality so fake it by using a transposed BSR matrix.
3134
            class FakeBscMatrix:
3135
                def __init__(self, matrix):
3136
                    self._matrix = matrix
3137
                    self.shape = tuple(reversed(matrix.shape))
3138
                    self.indptr = matrix.indptr
3139
                    self.indices = matrix.indices
3140
                    self.data = [x.transpose() for x in matrix.data]
3141

3142
                @staticmethod
3143
                def from_matrix(matrix, blocksize):
3144
                    blocksize = tuple(reversed(blocksize))
3145
                    matrix = matrix.transpose()
3146
                    return FakeBscMatrix(sp.bsr_matrix(matrix, blocksize=blocksize))
3147

3148
                def sorted_indices(self):
3149
                    sub = self._matrix.sorted_indices()
3150
                    return FakeBscMatrix(sub)
3151

3152
            return FakeBscMatrix.from_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
3153
        raise NotImplementedError(repr(tensor))
3154

3155
    @skipMeta
3156
    @all_sparse_compressed_layouts('to_layout')
3157
    @all_sparse_compressed_layouts('from_layout')
3158
    def test_compressed_layout_conversions_coverage(self, device, from_layout, to_layout):
        """Smoke test covered conversions between sparse compressed layouts and
        verify that an exception is thrown for unsupported conversions.

        TODO: This test covers a subset of
        TestSparseAny.test_to_sparse tests and can be
        eliminated. Keeping the test until the new
        `Tensor.to_sparse(*, layout, blocksize)` has landed.
        """

        allowed_pairwise_layouts_sets = {
            frozenset({torch.sparse_csc}),
            frozenset({torch.sparse_csr}),
            frozenset({torch.sparse_csc, torch.sparse_csr}),
            frozenset({torch.sparse_csc, torch.sparse_bsc}),
            frozenset({torch.sparse_csc, torch.sparse_bsr}),
            frozenset({torch.sparse_csr, torch.sparse_bsc}),
            frozenset({torch.sparse_csr, torch.sparse_bsr}),
            frozenset({torch.sparse_bsc}),
            frozenset({torch.sparse_bsr}),
            frozenset({torch.sparse_bsc, torch.sparse_bsr}),
        }
        block_layouts = (torch.sparse_bsr, torch.sparse_bsc)

        # blocked -> non-blocked conversions are not yet supported at all
        never_supported = {
            (torch.sparse_bsr, torch.sparse_csr),
            (torch.sparse_bsr, torch.sparse_csc),
            (torch.sparse_bsc, torch.sparse_csr),
            (torch.sparse_bsc, torch.sparse_csc),
        }
        # non-blocked -> blocked conversions only work for non-batched inputs
        unbatched_only = {
            (torch.sparse_csr, torch.sparse_bsr),
            (torch.sparse_csr, torch.sparse_bsc),
            (torch.sparse_csc, torch.sparse_bsr),
            (torch.sparse_csc, torch.sparse_bsc),
        }

        def _to_from_layout(layout_a, layout_b, a):
            direction = (layout_a, layout_b)
            expect_error = (
                {layout_a, layout_b} not in allowed_pairwise_layouts_sets
                or direction in never_supported
                or (direction in unbatched_only and a.dim() > 2)
            )

            blocksize_a = (1, 1) if layout_a in block_layouts else None
            blocksize_b = (1, 1) if layout_b in block_layouts else None
            b = a.to_sparse(layout=layout_a, blocksize=blocksize_a)

            if expect_error:
                with self.assertRaises(RuntimeError):
                    b.to_sparse(layout=layout_b, blocksize=blocksize_b)
                return

            c = b.to_sparse(layout=layout_b, blocksize=blocksize_b)
            self.assertEqual(a.to_dense(), c.to_dense())

            # change of blocksize upon conversion is not yet supported.
            if b.layout not in block_layouts:
                return
            for block_layout in block_layouts:
                with self.assertRaisesRegex(RuntimeError,
                                            "conversion from.*to.*with blocksize changed from.*to.*is not supported"):
                    b.to_sparse(layout=block_layout, blocksize=(3, 3))

        sparse_dims = (6, 12)
        for batch_dim in [(), (2,), (2, 2), (2, 2, 2)]:
            source = make_tensor(batch_dim + sparse_dims, dtype=torch.float, device=device)
            _to_from_layout(from_layout, to_layout, source)
3238

3239
    @skipMeta
3240
    @all_sparse_compressed_layouts()
3241
    @batched_nonbatched()
3242
    @hybrid_nonhybrid()
3243
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
3244
    def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout):
        """This test tests conversion from dense to/from the sparse compressed
        layouts (CSR, CSC, BSR and BSC), optionally batched and/or hybrid.

        Non-hybrid matrices are compared against SciPy's implementation
        (BSC via the BSR layout of the transposed matrix); hybrid matrices
        are checked element-wise against the dense source.  Every result is
        also converted back to dense and compared with the input.

        Here we test only those conversion combinations that SciPy
        supports to ensure that PyTorch conversions are on the same
        page as SciPy.  Independent from SciPy, all conversion
        combinations are tested in TestSparseAny.test_to_sparse.
        """

        blocked_layouts = (torch.sparse_bsr, torch.sparse_bsc)

        # helpers

        def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs):
            # scipy has no bsc layout, so we check against the bsr layout of the transposed dense
            if layout == torch.sparse_bsc:
                sp_matrix = self._construct_sp_matrix(dense.t(), layout=torch.sparse_bsr, blocksize=blocksize[::-1])
            else:
                sp_matrix = self._construct_sp_matrix(dense, layout=layout, blocksize=blocksize)

            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]

            self.assertEqual(layout, pt_matrix.layout)
            if layout == torch.sparse_bsc:
                self.assertEqual(sp_matrix.shape[::-1], pt_matrix.shape)
            else:
                self.assertEqual(sp_matrix.shape, pt_matrix.shape)

            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            if layout == torch.sparse_bsc:
                # we must transpose the blocks before comparing
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values().transpose(-2, -1))
            else:
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

        def _check_hybrid_matrix(pt_matrix, dense, blocksize, **kwargs):
            # Calculate COO indices for sparse matrix.
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
            compressed_indices = compressed_indices_mth(pt_matrix)
            plain_indices = plain_indices_mth(pt_matrix)
            coo_indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
            # column-compressed layouts store (column, row) pairs, so swap them
            row_indices, col_indices = {
                torch.sparse_csr: (coo_indices[0], coo_indices[1]),
                torch.sparse_csc: (coo_indices[1], coo_indices[0]),
                torch.sparse_bsr: (coo_indices[0], coo_indices[1]),
                torch.sparse_bsc: (coo_indices[1], coo_indices[0]),
            }[pt_matrix.layout]

            # If sparse matrix layout blocked, rearrange dense matrix
            # so that the shape past first two dimensions match the
            # shape of sparse matrix values.
            dense_to_check = dense
            if blocksize:
                dense_to_check_shape = (dense.shape[0] // blocksize[0],
                                        blocksize[0],
                                        dense.shape[1] // blocksize[1],
                                        blocksize[1]) + dense.shape[2:]
                dense_to_check = dense_to_check.reshape(dense_to_check_shape).transpose(1, 2)

            # Verify that non-zero values of the sparse matrix are
            # equal to corresponding values of the dense matrix.
            self.assertEqual(pt_matrix.values(), dense_to_check[row_indices, col_indices])

            # Verify that the remaining elements of the dense matrix
            # are 0, i.e. that dense and sparse matrix are fully
            # equal.
            mask = torch.ones_like(dense_to_check, dtype=torch.bool)
            mask[row_indices, col_indices] = False
            self.assertTrue(torch.all(torch.masked_select(dense_to_check, mask) == 0))

        def _check_batched(pt_tensor, dense, check_batch=None, batch_shape=(), blocksize=(), **kwargs):
            self.assertEqual(layout, pt_tensor.layout)
            self.assertEqual(pt_tensor.shape, dense.shape)
            for batch_index in np.ndindex(batch_shape):
                pt_matrix = pt_tensor[batch_index]
                dense_matrix = dense[batch_index]
                dense_dim = pt_matrix.dim() - 2
                dense_matrix_pt = dense_matrix.to_sparse(layout=layout,
                                                         blocksize=blocksize or None,
                                                         dense_dim=dense_dim)
                # sanity check, selecting batch of to_<layout> and dense[batch].to_<layout> should give the same result
                self.assertEqual(pt_matrix, dense_matrix_pt)
                check_batch(pt_matrix, dense_matrix, blocksize, **kwargs)

        def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
            shape = batch_shape + sparse_shape + hybrid_shape
            n_hybrid_dim = len(hybrid_shape)
            # generate a dense tensor
            dense = make_tensor(shape, dtype=torch.float, device=device)

            # introduce some sparsity, mask is sparse shape, element applies to entire dense sub-tensor (hybrid) and is
            # applied to each batch
            mask = make_tensor(sparse_shape, dtype=torch.bool, device=device)
            # manually expand to match hybrid shape
            if hybrid:
                mask = mask.view(sparse_shape + (1,) * n_hybrid_dim)
                mask = mask.expand(sparse_shape + hybrid_shape)

            # mask will broadcast over the batch dims if present

            return dense * mask

        # note: order is important here, the hybrid-ness decides the inner content check which is used to build the
        # batched checker (if needed)
        check_content = _check_against_scipy_matrix
        if hybrid:
            check_content = _check_hybrid_matrix
        if batched:
            check_content = functools.partial(_check_batched, check_batch=check_content)

        sparse_sizes = [(6, 10), (0, 10), (6, 0), (0, 0)]
        blocksizes = [(2, 2), (1, 1), (1, 2)] if layout in blocked_layouts else [()]
        batch_sizes = [(3,), (1, 3), (2, 1, 3)] if batched else [()]
        hybrid_sizes = [(4, ), (2, 2)] if hybrid else [()]

        # general cases, always run
        for sparse_shape, blocksize, batch_shape, hybrid_shape in itertools.product(
                sparse_sizes, blocksizes, batch_sizes, hybrid_sizes):
            dense = _generate_subject(sparse_shape, batch_shape, hybrid_shape)
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)
            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

        # special cases for batched tensors
        if batched:
            # batched sparse tensors need only have the same number of non-zeros in each batch, not necessarily the
            # same sparsity pattern in each batch
            sparse_shape = sparse_sizes[0]
            hybrid_shape = hybrid_sizes[0]
            batch_shape = batch_sizes[0]
            shape = batch_shape + sparse_shape + hybrid_shape
            dense = make_tensor(shape, dtype=torch.float, device=device)
            blocksize = blocksizes[0]
            # number of elements/blocks in each batch (total not nnz)
            batch_mask_shape = sparse_shape
            if layout in blocked_layouts:
                # if we are blocked the mask is generated for the block valued elements
                batch_mask_shape = sparse_shape[0] // blocksize[0], sparse_shape[1] // blocksize[1]

            # random bool vector w/ length equal to max possible nnz for the sparse_shape
            mask_source = make_tensor(batch_mask_shape, dtype=torch.bool, device=device).flatten()
            n_batch = functools.reduce(operator.mul, batch_shape, 1)

            # stack random permutations of the source for each batch
            mask = torch.stack([mask_source[torch.randperm(mask_source.numel())]
                               for _ in range(n_batch)], dim=0).reshape(batch_shape + batch_mask_shape)
            if layout in blocked_layouts:
                # for blocked we need to do a bit of extra work to expand the mask from blocked-space to element-space
                mask_shape = mask.shape
                mask = mask.view(mask_shape + (1, 1))
                mask = mask.expand(mask_shape + blocksize)
                mask = mask.transpose(-3, -2)
                mask = mask.flatten(-4, -3).flatten(-2, -1)
            mask_shape = mask.shape
            mask = mask.view(mask_shape + (1,) * len(hybrid_shape))
            mask = mask.expand(mask_shape + hybrid_shape)
            dense = dense * mask
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

            # if batches have different nnz we expect the conversion to throw
            mask_0 = mask[0]
            mask_1 = mask[0].clone().fill_(True)
            mask_2 = mask[0].clone().fill_(False)
            mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
            dense = make_tensor(shape, dtype=torch.float, device=device)
            dense = dense * mask
            msg = "Expect the same number of specified elements per batch."
            with self.assertRaisesRegex(RuntimeError, msg):
                dense.to_sparse(layout=layout, blocksize=blocksize or None)

            # Should throw if there is a zero in the batch size
            dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
            layout_code = str(layout).split("_")[-1]
            msg = f"to_sparse_{layout_code}: Expected product of batch dimensions to be non-zero."
            with self.assertRaisesRegex(RuntimeError, msg):
                dense.to_sparse(layout=layout, blocksize=blocksize or None)
3432

3433
    @skipMeta
3434
    @all_sparse_compressed_layouts()
3435
    @coalescedonoff
3436
    @dtypes(torch.double)
3437
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
3438
    def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
        """
        This test tests conversion from COO and from CSC to the sparse
        compressed layouts (CSR, CSC, BSR and BSC) by comparing to SciPy's
        implementation.

        Here we test only those conversion combinations that SciPy
        supports to ensure that PyTorch conversions are on the same
        page as SciPy.  Independent from SciPy, all conversion
        combinations are tested in TestSparseAny.test_to_sparse.
        """

        blocksize_kw = {}
        if layout in (torch.sparse_bsc, torch.sparse_bsr):
            blocksize_kw['blocksize'] = (2, 2)
            # block modes don't support 0 width/height
            shapes = [(6, 10)]
        elif layout in (torch.sparse_csc, torch.sparse_csr):
            shapes = [(0, 10), (6, 0), (6, 10), (0, 0)]
        else:
            raise NotImplementedError("unhandled layout")

        if layout in (torch.sparse_bsc, torch.sparse_csc):
            compressed_indices_mth = torch.Tensor.ccol_indices
            plain_indices_mth = torch.Tensor.row_indices
        elif layout in (torch.sparse_bsr, torch.sparse_csr):
            compressed_indices_mth = torch.Tensor.crow_indices
            plain_indices_mth = torch.Tensor.col_indices
        else:
            raise NotImplementedError("unhandled layout")

        def _check_against_scipy(source):
            # Convert `source` to the target layout and compare every
            # component (layout, shape, indices, values) with SciPy.
            sp_matrix = self._construct_sp_matrix(source, layout)
            pt_matrix = source.to_sparse(layout=layout, **blocksize_kw)

            self.assertEqual(layout, pt_matrix.layout)
            self.assertEqual(sp_matrix.shape, pt_matrix.shape)
            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

        for shape in shapes:
            sparse_dim = 2
            nnz = shape[0] * shape[1] // 2
            sparse, _, _ = self.genSparseTensor(shape, sparse_dim, nnz, coalesced, device, dtype)
            # COO -> target layout
            _check_against_scipy(sparse)
            # CSC -> target layout
            _check_against_scipy(sparse.to_sparse_csc())
3490

3491
    @unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss")
3492
    @dtypes(*floating_types())
3493
    def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
        """Compare ``torch.linalg.solve`` on a sparse CSR matrix with the dense result.

        A small CUDA ``torch.sparse.spsolve`` call is made first.  If it fails
        because PyTorch was built without cuDSS, the test is skipped; any other
        failure of that probe propagates.
        """
        # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py

        try:
            spd = torch.rand(4, 3)
            A = spd.T @ spd
            b = torch.rand(3).cuda()
            A = A.to_sparse_csr().cuda()
            torch.sparse.spsolve(A, b)
        except RuntimeError as e:
            if "Calling linear solver with sparse tensors requires compiling " in str(e):
                self.skipTest("PyTorch was not built with cuDSS support")
            # Any other failure of the probe is a real error; do not hide it.
            raise

        samples = sample_inputs_linalg_solve(None, device, dtype)

        for sample in samples:
            if sample.input.ndim != 2:
                continue

            out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
            if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
                with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"):
                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
                # NOTE(review): `break` ends the whole loop at the first
                # multi-column right-hand side, so later samples are never
                # exercised; `continue` may have been intended -- confirm.
                break
            if not sample.args[0].numel():
                with self.assertRaisesRegex(RuntimeError,
                                            "Expected non-empty other tensor, but found empty tensor"):
                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
                break

            expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
            sample.input = sample.input.to_sparse_csr()
            # a single-column right-hand side is passed to the sparse solver as a vector
            if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1:
                expect = expect.squeeze(-1)
                sample.args = (sample.args[0].squeeze(-1), )
            out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
            self.assertEqual(expect, out)
3530

3531

3532
def skipIfNoTriton(cls):
3533
    from torch.utils._triton import has_triton
3534

3535
    # no-op if triton is present
3536
    if has_triton():
3537
        return cls
3538
    else:
3539

3540
        @functools.wraps(cls, updated=())
3541
        class skipped_cls(cls):
3542
            def setUp(self):
3543
                self.skipTest("Triton is not available.")
3544

3545
        return skipped_cls
3546

3547
@skipIfNoTriton
3548
class TestSparseCompressedTritonKernels(TestCase):
3549

3550
    def _to_block_triangular_inplace(self, d, row_block, col_block):
        """
        Make ``d`` (upper or lower) block-triangular in-place and return it.

        ``d`` is treated as a grid of ``row_block x col_block`` tiles. If the
        grid has more block-rows than block-columns, the lower block triangle
        is kept (``tril_``); otherwise the upper one (``triu_``).
        It is assumed that ``d.shape[-2]`` is divisible by ``row_block`` and
        ``d.shape[-1]`` is divisible by ``col_block``.
        """
        from torch.sparse._triton_ops import tile_to_blocksize

        nrows, ncols = d.shape[-2:]
        # Relies on tile_to_blocksize returning a *view* of `d`, so that the
        # in-place tril_/triu_ below write through to `d`. The two moveaxis
        # calls move dims -4 and -3 of the tiled view to the end; presumably
        # those are the block-grid dims (TODO confirm against tile_to_blocksize).
        grid = tile_to_blocksize(d, (row_block, col_block))
        grid = grid.moveaxis(-4, -1).moveaxis(-4, -1)

        has_more_block_rows = nrows // row_block > ncols // col_block
        (grid.tril_ if has_more_block_rows else grid.triu_)()
        return d
3568

3569
    @onlyCUDA
    @skipIfRocm(msg="test is too slow on ROCm stack")
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_softmax(self, device, dtype):
        """Compare Triton ``bsr_softmax`` with ``torch.sparse.softmax``.

        Uses block-triangular inputs over several batch shapes, matrix sizes
        (including empty) and block sizes. It then checks a single very long
        row against dense ``softmax``.
        """
        from functools import partial
        from torch.sparse._triton_ops import bsr_softmax

        tensor = partial(make_tensor, device=device, dtype=dtype, low=1.0, high=3.0)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        size = [6, 12, 0]
        block_size = [2, 3]

        # General correctness
        for row_block, col_block, b, m, n in itertools.product(block_size, block_size, batches, size, size):
            input = tensor(b + (m, n))
            input.diagonal(dim1=-2, dim2=-1).fill_(m * n)
            input = self._to_block_triangular_inplace(input, row_block, col_block)

            bsr = input.to_sparse_bsr((row_block, col_block))
            coo = input.to_sparse().to(torch.float)

            res_tri = bsr_softmax(bsr)
            res_coo = torch.sparse.softmax(coo, -1)
            self.assertEqual(res_tri, res_coo.to(input.dtype))

        # Test long rows which exceed Triton's max numel limit set to 2 ** 17.
        # Use the largest batch shape explicitly rather than relying on the
        # loop variable `b` leaking out of the loop above (its final value
        # there is batches[-1], so behavior is unchanged).
        input = tensor(batches[-1] + (1, 150000))
        bsr = input.to_sparse_bsr(1)
        self.assertEqual(input.softmax(-1), bsr_softmax(bsr))
3602

3603
    @parametrize("block_size", [16, 32, 64])
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
                     "Skipped for deploy and internal with remote GPUs")
    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
        """Check Triton ``bsr_dense_mm`` against dense matmul / ``nn.functional.linear``.

        The kernel is registered through ``torch._TritonLibrary.registerOp``
        under ``"SparseCsrCUDA"``. For 2-D non-float32 inputs with
        non-trivial sizes, the test asserts that ``linear`` invoked it.
        Results are also compared across several ``max_grid`` settings.
        """
        from functools import partial
        from torch.sparse._triton_ops import bsr_dense_mm

        def kernel_impl(*args, **kwargs):
            return bsr_dense_mm(*args, skip_checks=True, **kwargs)

        kernel = torch._TritonLibrary.registerOp(
            "_triton_bsr_dense_mm_out",
            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
            kernel_impl,
            "SparseCsrCUDA"
        )

        # kernel != kernel_impl means dispatch was already registered.
        # This is exactly what we need!
        self.assertIsNot(kernel, kernel_impl)

        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        size = [128, 256, 0]

        # Whether to make inputs orthogonal so that the product is zero
        make_orthogonal = [True, False]

        # NOTE(review): `index_dtype` is parametrized but never used below.
        # The BSR indices come straight from `to_sparse_bsr`, so the int32
        # and int64 variants currently run identical checks. Confirm whether
        # the indices should be converted to `index_dtype`.

        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
            bsr = tensor(bs + (m, k))
            # NOTE: do not get confused, it will be transposed
            dense = tensor(bd + (n, k))

            if is_ortho:
                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)

            bsr = bsr.to_sparse_bsr(block_size)

            if bsr.dim() == 2 and dtype != torch.float:
                # Test against linear to check dispatch
                # which takes place for torch.half and torch.bfloat16.
                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
                res_tri_out = torch.empty_like(res_dense)
                res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)

                # Check dispatch worked with non-trivial outputs
                if m > 0 and n > 0 and k > 0:
                    self.assertTrue(kernel.kernel_invoked)
                    kernel.kernel_invoked = False
            else:
                # Otherwise check correctness against bmm
                # since nn.linear does not support bsr.dim() > 2.
                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
                res_tri_out = torch.empty_like(res_dense)
                res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)

            self.assertIs(res_tri, res_tri_out)
            self.assertEqual(res_tri, res_dense)

            # Recompute the reference in bsr @ dense^T orientation: the linear
            # branch above produced dense @ bsr^T instead.
            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
            # check whether bsr_dense_mm handles different grid sizes
            # None means max possible grid size which is CUDA-dependent.
            grid_size = (None, 2, 4)
            for grid in itertools.product(grid_size, repeat=3):
                res_tri = bsr_dense_mm(bsr, dense.transpose(-2, -1), max_grid=grid)
                self.assertEqual(res_tri, res_dense)
3682

3683
    @onlyCUDA
    @dtypes(torch.half)
    @unittest.skipIf((IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
                     "Skipped for deploy and internal with remote GPUs")
    def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
        """Check the ``ValueError`` messages produced by ``bsr_dense_mm`` argument validation."""
        from torch.sparse._triton_ops import bsr_dense_mm

        def rand(*shape):
            # Uniform random strided tensor with the test dtype on the test device.
            return torch.rand(*shape, dtype=dtype, device=device)

        dense = rand(32, 32)
        bsr = dense.to_sparse_bsr(16)

        # Layout, device and dtype checks.
        with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"):
            bsr_dense_mm(bsr.to_sparse_bsc(16), dense)
        with self.assertRaisesRegex(ValueError, "on the same GPU device"):
            bsr_dense_mm(bsr, dense.cpu())
        if torch.cuda.device_count() > 1:
            with self.assertRaisesRegex(ValueError, "on the same GPU device"):
                bsr_dense_mm(bsr.to("cuda:0"), dense.to("cuda:1"))
        with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"):
            bsr_dense_mm(bsr, dense.to(torch.float))
        with self.assertRaisesRegex(ValueError, r"and one of \(half, bfloat16, float32\)"):
            bsr_dense_mm(bsr.to(torch.double), dense.to(torch.double))

        # Shape checks.
        with self.assertRaisesRegex(ValueError, "all inputs involved in the matrix product are expected to be at least 2D"):
            bsr_dense_mm(bsr, rand(1))
        with self.assertRaisesRegex(ValueError,
                                    "sizes involved in the matrix product are not compatible for matrix multiplication"):
            bsr_dense_mm(bsr, rand(1, 1))
        with self.assertRaisesRegex(ValueError,
                                    r"dense.size\(-1\) == 15 should be divisible by 16"):
            bsr_dense_mm(bsr, rand(32, 15))

        # Blocksize checks: 15 is below 16 and 30 is not a power of 2.
        for blocksize in (15, 30):
            square = rand(2 * blocksize, 2 * blocksize)
            square_bsr = square.to_sparse_bsr(blocksize)
            with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"):
                bsr_dense_mm(square_bsr, square)

        # `out` checks on a batched product.
        batched_dense = rand(2, 32, 32)
        batched_bsr = batched_dense.to_sparse_bsr(16)
        with self.assertRaisesRegex(ValueError, r"`out` argument has wrong shape"):
            out = rand(2, 30, 30)
            bsr_dense_mm(batched_bsr, batched_dense, out=out)
        with self.assertRaisesRegex(ValueError, r"only row-major/col-major `out`"):
            out = rand(32, 32, 2).transpose(0, -1)
            bsr_dense_mm(batched_bsr, batched_dense, out=out)
3727

3728
    @parametrize("block_size", [16, 32, 64])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    @precisionOverride({torch.float16: 1e-3})
    def test_triton_scaled_dot_product_attention(self, device, dtype, block_size):
        """Compare BSR-masked Triton ``_scaled_dot_product_attention`` with the strided reference.

        The reference is ``torch.nn.functional.scaled_dot_product_attention``
        with the equivalent boolean mask. The BSR mask is tried both as bool
        and as ``dtype``.
        """
        from functools import partial
        from torch.sparse._triton_ops import _scaled_dot_product_attention

        # Random inputs are drawn from [0.3, 1.2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)

        def broadcast_batch(*tensors):
            # Broadcast the leading (batch) dims of all tensors to a common
            # shape, keeping each tensor's own last two dims.
            common = torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))
            return [torch.broadcast_to(t, common + t.shape[-2:]) for t in tensors]

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        size = [128, 256, 0]

        for bam, bq, bk, bv, m, n, k in itertools.product(batches, batches, batches, batches, size, size, size):
            query = tensor(bq + (m, k))
            key = tensor(bk + (n, k))
            value = tensor(bv + (n, k))

            # A block lower/upper triangular mask makes the BSR and strided
            # variants directly comparable.
            attn_mask = self._to_block_triangular_inplace(
                torch.ones(bam + (m, n), device=device, dtype=torch.bool), block_size, block_size)
            attn_mask_bsr = attn_mask.to_sparse_bsr(block_size)

            # Only a boolean mask is directly compatible with the strided
            # version without pre-/post-processing, so the reference uses it.
            for scale in (None, 1. / 16):
                # With an empty query feature dim, pass an explicit scale of 1
                # instead of relying on the default scale.
                effective_scale = 1 if scale is None and query.size(-1) == 0 else scale
                expected = torch.nn.functional.scaled_dot_product_attention(
                    *broadcast_batch(query, key, value, attn_mask), scale=effective_scale
                )

                for mask_dtype in (torch.bool, dtype):
                    res = _scaled_dot_product_attention(
                        query, key, value, attn_mask_bsr.to(mask_dtype), scale=effective_scale)
                    self.assertEqual(res, expected)
3773

3774

3775
    @parametrize("block_size", [16, 32, 64])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_sampled_addmm(self, device, dtype, block_size):
        """Compare Triton ``sampled_addmm`` (BSR) with ``torch.sparse.sampled_addmm`` (CSR).

        Also checks the ``out=`` path using an output whose values have
        non-standard strides. Finally, checks that results do not depend
        on ``max_grid``.
        """
        from functools import partial
        from torch.sparse._triton_ops import sampled_addmm, broadcast_batch_dims_bsr

        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        size = [128, 256, 0]

        delta_k = (-3,)
        for bi, bm1, bm2, m, n, k, dk in itertools.product(batches, batches, batches, size, size, size, delta_k):
            # Test not powers of 2 ks as well.
            k = max(0, k + dk)
            # Non-trivial sparsity pattern.
            # Plus with tril inputs the result is also tril,
            # so we can compare BSR and CSR implementations.
            input = tensor(bi + (m, n)).tril_()
            bsr = input.to_sparse_bsr(block_size)
            mat1 = tensor(bm1 + (m, k)).tril_()
            mat2 = tensor(bm2 + (k, n)).tril_()

            # Broadcasted batch shape of the three operands. It is loop-invariant
            # for the scalar/out loop below, which uses it as the expected
            # batch shape of every result.
            batch_dim = torch.broadcast_shapes(input.shape[:-2], mat1.shape[:-2], mat2.shape[:-2])

            csr = input.broadcast_to(batch_dim + input.shape[-2:]).to_sparse_csr().to(torch.float)
            mat1csr = mat1.broadcast_to(batch_dim + mat1.shape[-2:]).to(torch.float)
            mat2csr = mat2.broadcast_to(batch_dim + mat2.shape[-2:]).to(torch.float)

            input_broadcasted_clone = broadcast_batch_dims_bsr(
                "test_triton_sampled_addmm",
                bsr, mat1, mat2
            ).clone()
            input_broadcasted_clone = torch.sparse_compressed_tensor(
                input_broadcasted_clone.crow_indices(),
                input_broadcasted_clone.col_indices(),
                # For testing `out=` let's make values have "weird" strides,
                # so that if the kernel modifies values for its needs, the
                # result is copied into out.values.
                input_broadcasted_clone.values().transpose(-3, -2).contiguous().transpose(-3, -2),
                layout=input_broadcasted_clone.layout,
                size=input_broadcasted_clone.shape
            )

            scalars = (0.0, 2.0)
            for alpha, beta, out in itertools.product(scalars, scalars, (None, input_broadcasted_clone)):
                res_tri = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, out=out)
                if out is not None:
                    self.assertIs(res_tri, out)

                self.assertEqual(res_tri.shape, batch_dim + (m, n))

                res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta).to(input.dtype)
                self.assertEqual(res_tri.to_dense(), res_csr.to_dense())

                # Check different grid sizes to make sure that input slicing works
                # if this input is larger than the grid.
                grid_size = (3, None)
                for grid in itertools.product(grid_size, repeat=2):
                    res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid)
                    self.assertEqual(res_tri, res_tri_grid)
3844

3845
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_scatter_mm(self, device, dtype):
        """Exercise ``scatter_mm`` with ``'scatter_mm'`` and ``'bsr_strided_mm'`` indices data.

        Expected values are built from explicit dense products of two random
        blocks, or of a hand-assembled 2x2 block matrix.
        """
        from torch.sparse._triton_ops import scatter_mm
        from functools import partial

        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

        def int32_tensor(values):
            return torch.tensor(values, dtype=torch.int32, device=device)

        sizes = [8, 16]
        for m, k, n in itertools.product(sizes, repeat=3):
            blocks = torch.stack([tensor(m, k) for _ in range(2)])
            others = torch.stack([tensor(k, n) for _ in range(2)])
            b0, b1 = blocks
            o0, o1 = others

            # Judging by `expected`, output i sums the products of the
            # (block, other) index pairs in pairs[offsets[i]:offsets[i + 1]].
            expected = torch.stack([b0 @ o0 + b1 @ o0, b0 @ o1, b1 @ o1])
            indices_data = (
                'scatter_mm',
                int32_tensor([0, 2, 3, 4]),
                int32_tensor([[0, 0], [1, 0], [0, 1], [1, 1]]))
            self.assertEqual(scatter_mm(blocks, others, indices_data=indices_data), expected)

            indices_data = (
                'bsr_strided_mm',
                int32_tensor([0, 2, 4, 5, 6]),
                int32_tensor([0, n, 2 * n * m, 2 * n * m + n]),
                int32_tensor([1, 0, 1, 0, 1, 1]),
                int32_tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n]),
                dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1)
            )
            # Dense block matrix [[b1, b0], [0, b1]] that the 'bsr_strided_mm'
            # indices above are expected to encode.
            block_matrix = torch.cat([
                torch.cat([b1, b0], dim=1),
                torch.cat([torch.zeros_like(b0), b1], dim=1)], dim=0)

            for batch_shape in [(), (2,), (3, 4)]:
                other = tensor(*batch_shape, 2 * k, 2 * n)
                self.assertEqual(scatter_mm(blocks, other, indices_data=indices_data), block_matrix @ other)
3889

3890
    @parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64])
    @onlyCUDA
    @dtypes(torch.half, torch.bfloat16, torch.float)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_scatter_mm(self, device, dtype, blocksize):
        """Compare ``bsr_scatter_mm`` with ``bsr.to_dense() @ dense`` for each indices format.

        For the ``bsr_strided_mm*`` formats, SPLIT_N is halved from N down
        to 1. A Triton ``OutOfResources`` error ends that sweep early,
        provided at least one SPLIT_N value already succeeded.
        """
        import triton
        from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
        from functools import partial

        # Empty the `_bsr_scatter_mm_indices_data` cache when the test
        # finishes, including when it fails part-way through.
        self.addCleanup(torch.sparse._triton_ops._bsr_scatter_mm_indices_data.cache_clear)

        if isinstance(blocksize, str):
            blocksize = tuple(map(int, blocksize.split('x')))
        else:
            blocksize = (blocksize,) * 2
        # Note that each value in a non-zero block is in range blocksize * [low^2, high^2).
        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)

        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
        batches = [(), (2,), (2, 2)]
        sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]]
        sizes_K = [blocksize[1], 2 * blocksize[1]]

        for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches[:1], sizes, sizes_K, sizes, (False, True)):
            bsr_dense = tensor(bs + (M, K))
            if has_zero_row_block:
                if M > blocksize[0]:
                    # Zero the first block-row. Index the row dimension
                    # explicitly: a bare `[:blocksize[0]]` would slice the
                    # leading batch dimension if `bs` were non-empty.
                    bsr_dense[..., :blocksize[0], :].zero_()
                else:
                    continue
            bsr = bsr_dense.to_sparse_bsr(blocksize)
            dense = tensor(bd + (K, N))
            expected = bsr.to_dense() @ dense

            for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'):
                if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}:
                    SPLIT_N_list = [N]
                    while SPLIT_N_list[-1] > 1:
                        SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2))
                else:
                    SPLIT_N_list = [1]
                for SPLIT_N in SPLIT_N_list:
                    indices_data = bsr_scatter_mm_indices_data(
                        bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N)
                    try:
                        result = bsr_scatter_mm(bsr, dense, indices_data=indices_data)
                    except triton.compiler.OutOfResources:
                        # Ensure that there was at least one successful test.
                        # (A bare `assert` would be stripped under `python -O`.)
                        self.assertLess(SPLIT_N, SPLIT_N_list[0])
                        break

                    self.assertEqual(result, expected)
3941

3942
    def test_TensorAsKey(self, device):
        """Check equality, dict-key and liveness behavior of ``TensorAsKey``.

        The assertions below establish that:
        - keys built from ``t`` and from the view ``t[:]`` compare equal,
          while keys of views with a different offset/shape or stride do not;
        - ``key.obj`` returns the wrapped tensor while it is alive and
          ``None`` after it has been deleted;
        - a dead key still works for dict lookups of existing entries.
        """
        from torch.sparse._triton_ops import TensorAsKey
        # Make assertEqual also compare dtype, device and layout exactly.
        assertEqualOptions = dict(exact_dtype=True, exact_device=True, exact_layout=True)

        t = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device=device)
        key = TensorAsKey(t)
        self.assertTrue(key == TensorAsKey(t))
        self.assertTrue(key.obj is t)

        # A full-slice view gives an equal key whose .obj equals t.
        t2 = t[:]
        key2 = TensorAsKey(t2)
        self.assertTrue(key == key2)
        self.assertEqual(key2.obj, t, **assertEqualOptions)
        # deleting object leads to dead key
        del t2
        self.assertTrue(key2.obj is None)
        self.assertTrue(key.obj is t)

        # key with different storage offset and shape:
        self.assertFalse(key == TensorAsKey(t[1:]))

        # key with different strides:
        self.assertFalse(key == TensorAsKey(t[::2]))

        # when object dies, make sure that key represents a dead
        # object as well:
        del t
        self.assertTrue(key.obj is None)

        # Storing a tensor as a dict key:
        d = {}
        t3 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key3 = TensorAsKey(t3)
        d[key3] = 123
        self.assertTrue(d.get(key3) == 123)
        t3_ = t3[:]
        self.assertTrue(d.get(TensorAsKey(t3_)) == 123)
        # A clone has different storage, so its key does not match.
        self.assertTrue(d.get(TensorAsKey(t3.clone())) is None)

        # An equal key updates the existing entry, which key3 then sees.
        d[TensorAsKey(t3_)] = 567
        self.assertTrue(d.get(key3) == 567)

        # t3_ is a view of t3, so key3 stays alive after `del t3`
        # (presumably because the view keeps its base t3 alive). The key
        # becomes dead (its .obj property returns None) only once t3_ is
        # deleted too. The dead key is still usable for dict lookup.
        del t3
        self.assertTrue(key3.obj is not None)
        self.assertTrue(d.get(key3) == 567)
        del t3_
        self.assertTrue(key3.obj is None)
        self.assertTrue(d.get(key3) == 567)

        # Storing a tensor as a dict key and value:
        d = {}
        t4 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key4 = TensorAsKey(t4)
        d[key4] = (t4, 123)
        self.assertEqual(d.get(key4), (t4, 123), **assertEqualOptions)
        # when object is deleted, the key represents an alive object
        # because the object is referenced by the dict item value:
        del t4
        self.assertTrue(key4.obj is not None)
        # This also means that the life time of the tensor is same as
        # the life time of the corresponding dict item:
        del d[key4]
        self.assertTrue(key4.obj is None)

        # Storing a tensor as a dict key and value wrapped with TensorAsKey:
        d = {}
        t5 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key5 = TensorAsKey(t5)
        d[key5] = (key5, 567)
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
        self.assertTrue(key5.obj is not None)
        # when object is deleted, it will be dead because the wrapped value
        # does not keep the tensor alive (the dict value holds only key5):
        del t5
        self.assertTrue(key5.obj is None)
        # but key is still valid:
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
4022

4023
    @suppress_warnings
    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
    @parametrize("blocksize", [16, '16x32', 32])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_kernel(self, op, device, dtype, blocksize):
        """Check Triton BSR @ dense kernels against a strided reference.

        ``op`` selects one of ``bsr_dense_addmm``, ``bsr_dense_mm``,
        ``_int_bsr_dense_addmm``, or ``nn.functional.linear`` with a BSR
        weight (``bsr_dense_linear``). Each op is run on contiguous inputs
        and, where the op supports it, on non-contiguous dense operands and
        BSR values.
        """
        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
        from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
                                                   optimize_bsr_dense_addmm, dump)

        def bsr_dense_linear(input, weights, bias=None):
            return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)

        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
                         _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]

        # `op=op` binds the op name as given by the parametrization, before
        # '_int_bsr_dense_addmm' is rewritten to 'bsr_dense_addmm' below.
        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
            assert mat1.layout is torch.strided
            assert mat2.layout is torch.strided
            if dtype is torch.int8:
                if op == '_int_bsr_dense_addmm':
                    return beta * input + alpha * torch._int_mm(mat1, mat2)
                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
            return beta * input + alpha * (mat1 @ mat2)

        if op == '_int_bsr_dense_addmm':
            # _int_bsr_dense_addmm is same as bsr_dense_addmm except
            # with int8 inputs, _int_bsr_dense_addmm returns int32
            # result. This is covered by operation and reference
            # definitions above and all other definitions below are
            # identical between _int_bsr_dense_addmm and
            # bsr_dense_addmm.
            op = 'bsr_dense_addmm'

        def nc_copy(t, axes=(-1,)):
            """Return a copy of input.

            The returned copy will be a non-contiguous tensor.
            """
            if t.layout is torch.strided:
                shape = list(t.shape)
                for a in axes:
                    shape[a] *= 2
                r = torch.empty(shape, dtype=t.dtype, device=t.device)
                s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
                s.copy_(t)
                return s
            elif t.layout is torch.sparse_bsr:
                compressed_indices = t.crow_indices()
                plain_indices = t.col_indices()
                return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
                                                      t.shape, layout=t.layout)
            else:
                raise NotImplementedError(t.layout)

        if isinstance(blocksize, str):
            BM, BK = tuple(map(int, blocksize.split('x')))
        else:
            BM, BK = (blocksize,) * 2

        if op in {"bsr_dense_linear"} and BM != BK:
            # todo: eliminate this skip
            self.skipTest(f"{op} does not support non-square blocks")

        if op in {"bsr_dense_linear"} and dtype is torch.int8:
            # todo: eliminate this skip
            self.skipTest(f"{op} does not support int8")

        if dtype is torch.int8 and min(BM, BK) < 32:
            self.skipTest("triton kernel does not support int8 blocks smaller than 32")

        beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
        alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
        sparsity_lst = [0, 0.5, 1]
        blocks_per_row_lst = [1, 2]
        blocks_per_col_lst = [1, 2]
        result_cols_lst = [16, 32, 64]
        for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
                beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
            M = BM * blocks_per_row
            K = BK * blocks_per_col
            mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
            bsr = mat1.to_sparse_bsr((BM, BK))
            mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
            input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)

            if 0 and op == "bsr_dense_addmm":
                # Find optimal kernel parameters, the speed-up is
                # about 10x for running this test.
                #
                # Enable this if-block when the test method is
                # updated, run the test, and finally, disable the
                # if-block.
                key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
                meta = get_meta(op, key, version=(0, dtype, 0.5))
                if meta is None:
                    optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
                    meta = get_meta(op, key, version=(0, dtype, 0.5))
                    assert meta is not None
                    dump()  # this will update torch/sparse/_triton_ops_meta.py

            expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
            kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
                          bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]

            args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
                        bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
            result = operation(*args, **kwargs)
            self.assertEqual(result, expected)

            # Test non-contiguous input tensors:
            nc_mat2 = nc_copy(mat2)
            nc_input = nc_copy(input)
            nc_bsr = nc_copy(bsr)

            args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
                        bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
            result = operation(*args, **kwargs)
            self.assertEqual(result, expected)

            # todo: add bsr_dense_linear to the set below (currently,
            # nn.linear has unnecessarily restrictive arguments
            # checks).
            if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
                args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
                            bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
                result = operation(*args, **kwargs)
                self.assertEqual(result, expected)

            if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
                args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
                            bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
                kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
                              bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
                result = operation(*args, **kwargs)
                self.assertEqual(result, expected)
4163

4164
    @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_tune(self, op, device, dtype):
        """Check that tuning ``op`` stores kernel parameters that can be looked up and used.

        Before tuning, no exact entry exists for the problem key. After running
        the tuner with ``store=True``, ``get_meta(..., exact=True)`` must return
        the tuned parameters. Running ``op`` with those parameters must
        reproduce the result computed with the default parameters.
        """
        from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
        from torch.sparse._triton_ops_meta import (
            create_blocked_tensor,
            tune_bsr_dense_addmm,
            tune__int_bsr_dense_addmm,
            get_meta,
        )

        operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
                     _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]

        if op == '_int_bsr_dense_addmm':
            M, K, N = 32, 32, 32
            blocksize = (32, 32)
        else:
            M, K, N = 16, 16, 32
            blocksize = (16, 16)
        sparsity = 1.0
        bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
        # Recompute the sparsity actually realized by the generated tensor;
        # it is part of the meta lookup version used in get_current_meta below.
        sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
        # The addmm input has the shape of the (M, N) result, as in the
        # bsr_dense_addmm tests above. It previously used (K, N), which only
        # worked because M == K here.
        input = make_tensor(M, N, dtype=dtype, device=device)
        dense = make_tensor(K, N, dtype=dtype, device=device)

        if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
            args = (input, bsr, dense)

            def get_current_meta():
                version = (0, dtype, sparsity)
                # Key flags: beta == 0 -> False, beta == 1 -> True, alpha == 1 -> True
                # (the ops are called with their default beta/alpha).
                meta_key = (M, K, N, *blocksize, False, True, True)
                return get_meta(op, meta_key, version=version, exact=True)
        else:
            raise NotImplementedError(op)

        self.assertEqual(get_current_meta(), None)

        meta = tuner(*args, store=True, verbose=False)
        self.assertEqual(get_current_meta(), meta)

        expected = operation(*args)
        result = operation(*args, meta=meta)
        self.assertEqual(result, expected)
4208

4209
    @onlyCUDA
    @skipIfRocm
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_dense_addmm_meta(self, device):
        """Exercise lookup, warn-once behavior and fallbacks of ``bsr_dense_addmm_meta``.

        All stores and lookups use the version tag
        ``"test_triton_bsr_dense_addmm_meta"`` with float32, 16x16 blocks,
        beta=0 and alpha=1. Warnings are captured from stderr and counted
        with FileCheck.
        """
        from torch.sparse._triton_ops import bsr_dense_addmm_meta
        from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta

        version_tag = "test_triton_bsr_dense_addmm_meta"
        dtype = torch.float32
        Ms = Ks = 16
        beta, alpha = 0.0, 1.0

        def get_meta(M, K, N, sparsity=None):
            return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
                                        _version=version_tag)

        def update_meta(M, K, N, value, sparsity=0.5):
            key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
            version = (version_tag, dtype, sparsity)
            update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(), version, key, value)

        def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
            # Look up parameters and assert exactly `warn_count` non-optimal warnings for this shape.
            captured = io.StringIO()
            with redirect_stderr(captured):
                found = get_meta(M, K, N, sparsity=sparsity)
            FileCheck().check_count(
                str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
                count=warn_count, exactly=True
            ).run(captured.getvalue())
            return found

        # Requesting the same missing parameters repeatedly must warn only once per shape.
        captured = io.StringIO()
        with redirect_stderr(captured):
            for n_cols in (16, 32):
                for _ in range(5):
                    get_meta(16, 16, n_cols)

        stderr_text = captured.getvalue()
        for expected_warning in (
            "UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16",
            "UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32",
        ):
            FileCheck().check_count(str=expected_warning, count=1, exactly=True).run(stderr_text)

        # No tuned entry: non-tuned parameters are returned with exactly one warning.
        default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
        self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)

        # A stored entry is returned without any warning.
        update_meta(32, 32, 48, (2, 8, 5, 6))
        expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)

        # Nothing is stored for sparsity 0.6, but the entry stored for the
        # default sparsity 0.5 is returned silently.
        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)

        # Missing N whose N // SPLIT_N ratio matches a stored entry
        # (72 // 12 == 48 // 8): the stored entry is reused with SPLIT_N scaled.
        self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
                         dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
        # Missing N with no consistent ratio: non-tuned parameters plus one warning.
        self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
                         dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))
4278

4279

4280
# Generate the device-specific variants of each generic test class
# (e.g., TestSparseCSRCPU and TestSparseCSRCUDA) in this module's namespace.
for _generic_test_class in (TestSparseCSR, TestSparseCompressed, TestSparseCompressedTritonKernels):
    instantiate_device_type_tests(_generic_test_class, globals())
# Drop the loop name so no extra reference to a generic class is left in the module.
del _generic_test_class

if __name__ == "__main__":
    run_tests()
4287

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.