# Owner(s): ["module: linear algebra"]

import torch
import numpy as np

import unittest
import itertools
import warnings
import math
from math import inf, nan, isnan
import re
import random
from random import randrange
from itertools import product
from functools import reduce, partial

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
     dtypesIfMPS, largeTensorTest)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
import operator

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

if TEST_SCIPY:
    import scipy

def blaslt_supported_device():
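    # Whether this machine is expected to support BLASLt paths: True on any non-HIP CUDA
    # device, or on ROCm when the device's gcnArchName contains gfx90a or gfx94.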
    if torch.cuda.is_available():
        if torch.version.hip:
            for arch in ['gfx90a', 'gfx94']:
                if arch in torch.cuda.get_device_properties(0).gcnArchName:
                    return True
        else:
            return True
    return False

def set_tunableop_defaults():
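    # Reset TunableOp global state between tests: disable it, re-enable tuning, and
    # restore the default results filename and tuning limits. No-op without CUDA.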
    if not torch.cuda.is_available():
        # TunableOp not supported on CPU at this time.
        return

    # disable TunableOp and restore to default values
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)  # reset back to default filename for next unit test
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)


class TestLinalg(TestCase):
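    # setUp turns TF32 matmul off so float32 results are computed in full precision for
    # comparisons against NumPy; tearDown turns it back on.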
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    exact_dtype = True

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
    @tf32_on_and_off(5e-3)
    @bf32_on_and_off(5e-3)
    def test_inner(self, device, dtype):
        def check(a_sizes_, b_sizes_):
            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
                a = torch.randn(a_sizes, dtype=dtype, device=device)
                b = torch.randn(b_sizes, dtype=dtype, device=device)
                res = torch.inner(a, b)
                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
                out = torch.zeros_like(res)
                torch.inner(a, b, out=out)
                self.assertEqual(res, out)

        check([], [])                       # scalar x scalar
        check([], [0])                      # scalar x empty
        check([], [3])                      # scalar x 1D
        check([], [2, 3, 4])                # scalar x 3D

        check([0], [0])                     # empty x empty
        check([0], [2, 0])                  # empty x 2D

        check([2], [2])                     # 1D x 1D
        check([2], [3, 1, 2])               # 1D x 3D
        check([2], [3, 0, 2])               # 1D x 3D empty

        check([1, 2], [3, 2])               # 2D x 2D
        check([1, 2], [3, 4, 2])            # 2D x 3D
        check([2, 1, 3, 2], [1, 3, 2, 2])   # 4D x 4D

        # Test error message
        with self.assertRaisesRegex(RuntimeError,
                                    r"inner\(\) the last dimension must match on both "
                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))

    # Tests torch.outer, and its alias, torch.ger, vs. NumPy
123
    @precisionOverride({torch.bfloat16: 1e-1})
124
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
125
    def test_outer(self, device, dtype):
126
        def run_test_case(a, b):
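            # NumPy has no bfloat16, so the reference is computed in double precision for
            # that dtype and the comparison is done with exact_dtype=False.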
127
            if dtype == torch.bfloat16:
128
                a_np = a.to(torch.double).cpu().numpy()
129
                b_np = b.to(torch.double).cpu().numpy()
130
                exact_dtype = False
131
            else:
132
                a_np = a.cpu().numpy()
133
                b_np = b.cpu().numpy()
134
                exact_dtype = True
135
            expected = np.outer(a_np, b_np)
136

137
            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
138
            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)
139

140
            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
141
            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)
142

143
            # test out variant
144
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
145
            torch.outer(a, b, out=out)
146
            self.assertEqual(out, expected, exact_dtype=False)
147

148
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
149
            torch.ger(a, b, out=out)
150
            self.assertEqual(out, expected, exact_dtype=False)
151

152
        a = torch.randn(50).to(device=device, dtype=dtype)
153
        b = torch.randn(50).to(device=device, dtype=dtype)
154
        run_test_case(a, b)
155

156
        # test 0 strided tensor
157
        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
158
        run_test_case(zero_strided, b)
159
        run_test_case(a, zero_strided)
160

161
    def test_matrix_rank_removed_error(self, device):
162
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
163
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
164
            torch.matrix_rank(a)
165

166
    def test_solve_removed_error(self, device):
167
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
168
        b = make_tensor(5, 1, device=device, dtype=torch.float32)
169
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
170
            torch.solve(b, a)
171
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
172
            b.solve(a)
173

174
    def test_eig_removed_error(self, device):
175
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
176
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
177
            torch.eig(a)
178
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
179
            a.eig()
180

181
    def test_symeig_removed_error(self, device):
182
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
183
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
184
            torch.symeig(a)
185
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
186
            a.symeig()
187

188
    def test_lstsq_removed_error(self, device):
189
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
190
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
191
            torch.lstsq(a, a)
192
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
193
            a.lstsq(a)
194

195
    @skipCUDAIfNoMagma
196
    @skipCPUIfNoLapack
197
    @skipIfTorchDynamo("flaky, needs investigation")
198
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
199
    def test_linalg_lstsq(self, device, dtype):
200
        from torch.testing._internal.common_utils import random_well_conditioned_matrix
201
        if self.device_type == 'cpu':
202
            drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
203
        else:
204
            drivers = ('gels', None)
205

206
        def check_solution_correctness(a, b, sol):
207
            sol2 = a.pinverse() @ b
208
            self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
209

210
        def check_correctness_ref(a, b, res, ref, driver="default"):
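            # Flatten the (possibly batched) problem to 3D and compare each matrix in the
            # batch against the reference lstsq implementation `ref` (SciPy or NumPy).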
211
            def apply_if_not_empty(t, f):
212
                if t.numel():
213
                    return f(t)
214
                else:
215
                    return t
216

217
            def select_if_not_empty(t, i):
218
                selected = apply_if_not_empty(t, lambda x: x.select(0, i))
219
                return selected
220

221
            m = a.size(-2)
222
            n = a.size(-1)
223
            nrhs = b.size(-1)
224
            batch_size = int(np.prod(a.shape[:-2]))
225
            if batch_size == 0:
226
                batch_size = 1
227
            a_3d = a.view(batch_size, m, n)
228
            b_3d = b.view(batch_size, m, nrhs)
229

230
            solution_3d = res.solution.view(batch_size, n, nrhs)
231
            residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
232
            rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
233
            singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])
234

235
            if a.numel() > 0:
236
                for i in range(batch_size):
237
                    sol, residuals, rank, singular_values = ref(
238
                        a_3d.select(0, i).numpy(),
239
                        b_3d.select(0, i).numpy()
240
                    )
241
                    # Singular values are None when lapack_driver='gelsy' in SciPy
242
                    if singular_values is None:
243
                        singular_values = []
244
                    self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
245
                    self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
246
                    self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
247

248
                    # SciPy and NumPy operate only on non-batched input and
249
                    # return an empty array with shape (0,) if rank(a) != n
250
                    # in PyTorch the batched inputs are supported and
251
                    # matrices in the batched input can have different ranks
252
                    # we compute residuals only if all matrices have rank == n
253
                    # see https://github.com/pytorch/pytorch/issues/56483
254
                    if m > n:
255
                        if torch.all(rank_1d == n):
256
                            self.assertEqual(
257
                                residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
258
                            )
259
                        else:
260
                            self.assertTrue(residuals_2d.numel() == 0)
261

262
            else:
263
                self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
264
                self.assertEqual(res.rank.shape, a.shape[:-2])
265

266
                # residuals are not always computed (and have non-zero shape)
267
                if m > n and driver != "gelsy":
268
                    self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
269
                else:
270
                    self.assertEqual(res.residuals.shape, (0, ))
271

272
                # singular_values are not always computed (and have non-zero shape)
273
                if driver == "default" or driver == "gelsd" or driver == "gelss":
274
                    self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
275
                else:
276
                    self.assertEqual(res.singular_values.shape, (0, ))
277

278
        def check_correctness_scipy(a, b, res, driver, cond):
279
            # SciPy provides 3 driver options: gelsd, gelss, gelsy
280
            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
281
                import scipy.linalg
282

283
                def scipy_ref(a, b):
284
                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
285
                check_correctness_ref(a, b, res, scipy_ref, driver=driver)
286

287
        def check_correctness_numpy(a, b, res, driver, rcond):
288
            # NumPy uses only gelsd routine
289
            if driver == 'gelsd':
290

291
                def numpy_ref(a, b):
292
                    return np.linalg.lstsq(a, b, rcond=rcond)
293
                check_correctness_ref(a, b, res, numpy_ref)
294

295
        ms = [2 ** i for i in range(5)]
296
        m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
297
        # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
298
        m_l_n_sizes = [(m // 2, m) for m in ms]
299
        include_m_l_n_case = (has_cusolver() or device == 'cpu')
300
        matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
301
        batches = [(), (2,), (2, 2), (2, 2, 2)]
302
        # We generate matrices with singular values sampled from a normal distribution,
        # so `cond=1.0` (the mean) cuts roughly half of the singular values; we then
        # check that torch.linalg.lstsq agrees with SciPy and NumPy.
        # If rcond is True, a driver-specific value is chosen below.
        # rcond == -1 (or any other negative value) forces LAPACK to use machine precision tolerance.
308
        rconds = (None, True, -1)
309

310
        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
311
            # keep the rcond value if it is None or -1, set the driver specific value if it is True
312
            if rcond and rcond != -1:
313
                if driver in ('gelss', 'gelsd'):
314
                    # SVD based algorithm; set to zero roughly half of all the singular values
315
                    rcond = 1.0
316
                else:
317
                    # driver == 'gelsy'
318
                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
319
                    # so we skip this case
320
                    continue
321

322
            # specifying rcond value has no effect for gels driver so no need to run the tests again
323
            if driver == 'gels' and rcond is not None:
324
                continue
325

326
            shape = batch + matrix_size
327
            a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
328
            b = torch.rand(*shape, dtype=dtype, device=device)
329

330
            m = a.size(-2)
331
            n = a.size(-1)
332
            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
333
            sol = res.solution
334

335
            # Only checks gelsd, gelss, gelsy drivers
336
            check_correctness_scipy(a, b, res, driver, rcond)
337

338
            # Only checks gelsd driver
339
            check_correctness_numpy(a, b, res, driver, rcond)
340

341
            # gels driver is not checked by comparing to NumPy or SciPy implementation
342
            # because NumPy and SciPy do not implement this driver
343
            if driver == 'gels' and rcond is None:
344
                check_solution_correctness(a, b, sol)
345

346
    @skipCUDAIfNoMagma
347
    @skipCPUIfNoLapack
348
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
349
    def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
350
        from torch.testing._internal.common_utils import random_well_conditioned_matrix
351

352
        def check_correctness(a, b):
353
            sol = torch.linalg.lstsq(a, b).solution
354
            sol2 = a.pinverse() @ b
355
            self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)
356

357
        ms = [2 ** i for i in range(5)]
358
        batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
359
        # the case when a single matrix is batch-broadcasted over the rhs
360
        for m, batch in itertools.product(ms, batches):
361
            a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
362
            b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
363
            check_correctness(a, b)
364

365
        # cases with broadcastable shapes
366
        for m in ms:
367
            a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
368
            b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
369
            check_correctness(a, b)
370

371
            # rhs are vectors, not matrices in this test
372
            b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
373
            # unsqueeze for b because `check_correctness` checks against
374
            # a.pinverse() @ b, which requires b to be a matrix
375
            check_correctness(a, b.unsqueeze(-1))
376

377
            a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
378
            b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
379
            check_correctness(a, b)
380

381
            # rhs are vectors, not matrices in this test
382
            b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
383
            check_correctness(a, b.unsqueeze(-1))
384

385
    @skipCPUIfNoLapack
386
    @skipCUDAIfNoMagma
387
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
388
    def test_linalg_lstsq_input_checks(self, device, dtype):
389
        # check empty inputs
390
        # empty batches
391
        a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
392
        b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
393
        self.assertEqual(
394
            torch.linalg.lstsq(a, b)[0],
395
            torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
396
        )
397
        # empty a and b
398
        a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
399
        b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
400
        self.assertEqual(
401
            torch.linalg.lstsq(a, b)[0],
402
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
403
        )
404
        # empty a and b
405
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
406
        b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
407
        self.assertEqual(
408
            torch.linalg.lstsq(a, b)[0],
409
            torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
410
        )
411
        # empty a but not b
412
        a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
413
        b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
414
        self.assertEqual(
415
            torch.linalg.lstsq(a, b)[0],
416
            torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
417
        )
418

419
        # empty a and b with zero rows (m < n)
        if torch.device(device).type == 'cpu':
            # only CPU since CUDA does not support underdetermined systems
422
            a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
423
            b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
424
            self.assertEqual(
425
                torch.linalg.lstsq(a, b)[0],
426
                torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
427
            )
428

429
        a = torch.rand(2, 3, dtype=dtype, device=device)
430
        b = torch.rand(3, dtype=dtype, device=device)
431

432
        with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
433
            torch.linalg.lstsq(b, b)
434

435
        with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
436
            torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))
437

438
        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
439
            torch.linalg.lstsq(a, b)
440

441
        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
442
            torch.linalg.lstsq(a, b.unsqueeze(-1))
443

444
        a = torch.randn(1, 1, 1, dtype=dtype, device=device)
445
        b = torch.randn(3, 1, dtype=dtype, device=device)
446

447
        with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
448
            torch.linalg.lstsq(a, b)
449

450
        def complement_device(device):
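            # Return a device different from `device` to exercise the mismatched-device error.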
451
            if device == 'cpu' and torch.cuda.is_available():
452
                return 'cuda'
453
            else:
454
                return 'cpu'
455

456
        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
457
        b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
458
        if a.device != b.device:
459
            with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
460
                torch.linalg.lstsq(a, b)
461

462
        b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
463
        with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
464
            torch.linalg.lstsq(a, b)
465

466
        a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
467
        b = torch.rand(2, 2, 2, dtype=dtype, device=device)
468

469
        if device != 'cpu':
470
            with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
471
                torch.linalg.lstsq(a, b, driver='fictitious_driver')
472
        # if on cpu
473
        else:
474
            with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
475
                torch.linalg.lstsq(a, b, driver='fictitious_driver')
476

477
        # cuSOLVER path supports underdetermined systems
478
        version = torch.testing._internal.common_cuda._get_torch_cuda_version()
479
        cusolver_not_available = (version < (10, 1))
480

481
        if device != 'cpu' and cusolver_not_available:
482
            a = torch.rand(2, 3, dtype=dtype, device=device)
483
            b = torch.rand(2, 1, dtype=dtype, device=device)
484
            with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
485
                torch.linalg.lstsq(a, b)
486

487
    @skipCUDAIfNoMagma
488
    @skipCPUIfNoLapack
489
    @dtypes(*floating_and_complex_types())
490
    def test_cholesky(self, device, dtype):
491
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
492

493
        def run_test(shape, batch, contiguous):
494
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
495
            if A.numel() > 0 and not contiguous:
496
                A = A.mT
497
                self.assertFalse(A.is_contiguous())
498
            expected_L = np.linalg.cholesky(A.cpu().numpy())
499
            actual_L = torch.linalg.cholesky(A)
500

501
            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
502
            # Let's compare the norms of matrices instead
503
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
504
                # axis is specified to calculate matrix norm for batched input
505
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
506
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
507
                # Compare the norms with standard tolerances
508
                self.assertEqual(actual_norm, expected_norm)
509
                # and individual values with a higher tolerance
510
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
511
            else:
512
                self.assertEqual(actual_L, expected_L)
513

514
        shapes = (0, 3, 5)
515
        batches = ((), (3, ), (2, 2))
516
        larger_input_case = [(100, (5, ), True)]
517
        for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
518
            run_test(shape, batch, contiguous)
519

520
        # check the out= variant
521
        A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
522
        out = torch.empty_like(A)
523
        ans = torch.linalg.cholesky(A, out=out)
524
        self.assertEqual(ans, out)
525
        expected = torch.linalg.cholesky(A)
526
        self.assertEqual(expected, out)
527

528
        # check the upper= variant
529
        expected = torch.linalg.cholesky(A).mH
530
        actual = torch.linalg.cholesky(A, upper=True)
531
        self.assertEqual(expected, actual)
532

533
    @skipCUDAIfNoMagma
534
    @skipCPUIfNoLapack
535
    @dtypes(*floating_and_complex_types())
536
    def test_cholesky_errors_and_warnings(self, device, dtype):
537
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
538

539
        # cholesky requires the input to be a square matrix or batch of square matrices
540
        A = torch.randn(2, 3, device=device, dtype=dtype)
541
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
542
            torch.linalg.cholesky(A)
543
        A = torch.randn(2, 2, 3, device=device, dtype=dtype)
544
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
545
            torch.linalg.cholesky(A)
546
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
547
            np.linalg.cholesky(A.cpu().numpy())
548

549
        # cholesky requires the input to be at least 2 dimensional tensor
550
        A = torch.randn(2, device=device, dtype=dtype)
551
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
552
            torch.linalg.cholesky(A)
553
        with self.assertRaisesRegex(np.linalg.LinAlgError,
554
                                    r'1-dimensional array given\. Array must be at least two-dimensional'):
555
            np.linalg.cholesky(A.cpu().numpy())
556

557
        # if the input matrix is not positive definite, an error should be raised
558
        A = torch.eye(3, 3, dtype=dtype, device=device)
559
        A[-1, -1] = 0  # Now A is not positive definite
560
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
561
            torch.linalg.cholesky(A)
562
        with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
563
            np.linalg.cholesky(A.cpu().numpy())
564

565
        # if at least one matrix in the batch is singular, an error should be raised
566
        A = torch.eye(3, 3, dtype=dtype, device=device)
567
        A = A.reshape((1, 3, 3))
568
        A = A.repeat(5, 1, 1)
569
        A[4, -1, -1] = 0  # Now A[4] is not positive definite
570
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
571
            torch.linalg.cholesky(A)
572

573
        # if out tensor with wrong shape is passed a warning is given
574
        A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
575
        out = torch.empty(2, 3, dtype=dtype, device=device)
576
        with warnings.catch_warnings(record=True) as w:
577
            # Trigger warning
578
            torch.linalg.cholesky(A, out=out)
579
            # Check warning occurs
580
            self.assertEqual(len(w), 1)
581
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
582

583
        # dtypes should be safely castable
584
        out = torch.empty(*A.shape, dtype=torch.int, device=device)
585
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
586
            torch.linalg.cholesky(A, out=out)
587

588
        # device should match
589
        if torch.cuda.is_available():
590
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
591
            out = torch.empty(0, device=wrong_device, dtype=dtype)
592
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
593
                torch.linalg.cholesky(A, out=out)
594

595
    # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py
596
    @slowTest
597
    @skipCUDAIfNoMagma
598
    @skipCPUIfNoLapack
599
    @dtypes(torch.double)
600
    def test_old_cholesky_batched_many_batches(self, device, dtype):
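        # Stress the batched Cholesky path with very large batch counts (262144 and 524288)
        # of small 2x2 symmetric positive-definite matrices.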
601
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix
602

603
        def cholesky_test_helper(n, batchsize, device, upper):
604
            A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
605
            chol_fact = torch.cholesky(A, upper=upper)
606
            if upper:
607
                # Correctness check
608
                self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
609
                # Upper triangular check
610
                self.assertEqual(chol_fact, chol_fact.triu())
611
            else:
612
                # Correctness check
613
                self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
614
                # Lower triangular check
615
                self.assertEqual(chol_fact, chol_fact.tril())
616

617
        for upper, batchsize in itertools.product([True, False], [262144, 524288]):
618
            cholesky_test_helper(2, batchsize, device, upper)
619

620
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
621
    @skipCUDAIfNoMagma
622
    @skipCPUIfNoLapack
623
    @dtypes(*floating_and_complex_types())
624
    def test_old_cholesky_batched(self, device, dtype):
625
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
626

627
        def cholesky_test_helper(n, batch_dims, upper):
628
            A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
629
            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
630
            cholesky_exp = cholesky_exp.reshape_as(A)
631
            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
632

633
        for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
634
            cholesky_test_helper(3, batchsize, upper)
635

636
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
637
    @skipCUDAIfNoMagma
638
    @skipCPUIfNoLapack
639
    @dtypes(*floating_and_complex_types())
640
    @tf32_on_and_off(0.01)
641
    @bf32_on_and_off(0.01)
642
    def test_old_cholesky(self, device, dtype):
643
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
644

645
        A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)
646

647
        # default Case
648
        C = torch.cholesky(A)
649
        B = torch.mm(C, C.t().conj())
650
        self.assertEqual(A, B, atol=1e-14, rtol=0)
651

652
        # test Upper Triangular
653
        U = torch.cholesky(A, True)
654
        B = torch.mm(U.t().conj(), U)
655
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')
656

657
        # test Lower Triangular
658
        L = torch.cholesky(A, False)
659
        B = torch.mm(L, L.t().conj())
660
        self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')
661

662
    @skipCUDAIfNoMagma
663
    @skipCPUIfNoLapack
664
    @dtypes(*floating_and_complex_types())
665
    def test_old_cholesky_empty(self, device, dtype):
666
        def run_test(upper):
667
            A = torch.empty(0, 0, dtype=dtype, device=device)
668
            chol = torch.cholesky(A, upper)
669
            chol_A = torch.matmul(chol, chol.t().conj())
670
            self.assertEqual(A, chol_A)
671
        for upper in [True, False]:
672
            run_test(upper)
673

674
    # Test for issue
675
    # https://github.com/pytorch/pytorch/issues/57032
676
    # torch.cholesky with upper=True for batched CUDA inputs was wrong
677
    # it was using the lower triangular part instead of the upper one
678
    @onlyCUDA
679
    @skipCUDAIfNoMagma
680
    @dtypes(*floating_and_complex_types())
681
    def test_old_cholesky_batched_upper(self, device, dtype):
682
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
683

684
        batchsize = 2
685
        A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
686
        A_triu = A.triu()  # fill the lower triangular part with zero
687

688
        U = torch.cholesky(A_triu, upper=True)
689

690
        reconstruct_A = U.mH @ U
691
        self.assertEqual(A, reconstruct_A)
692

693
    @skipCUDAIfNoMagmaAndNoCusolver
694
    @skipCPUIfNoLapack
695
    @dtypes(*floating_and_complex_types())
696
    def test_cholesky_ex(self, device, dtype):
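        # torch.linalg.cholesky_ex should match np.linalg.cholesky and report info == 0
        # for positive-definite inputs.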
697
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
698

699
        def run_test(n, batch):
700
            A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
701
            expected_L = np.linalg.cholesky(A.cpu().numpy())
702
            expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
703
            actual_L, actual_info = torch.linalg.cholesky_ex(A)
704

705
            # For fp32 individual entries in matrices can differ between PyTorch and NumPy
706
            # Let's compare the norms of matrices instead
707
            if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
708
                # axis is specified to calculate matrix norm for batched input
709
                expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
710
                actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
711
                # Compare the norms with standard tolerances
712
                self.assertEqual(actual_norm, expected_norm)
713
                # and individual values with a higher tolerance
714
                self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
715
            else:
716
                self.assertEqual(actual_L, expected_L)
717
            self.assertEqual(actual_info, expected_info)
718

719
        ns = (0, 3, 5)
720
        batches = ((), (2, ), (2, 1))
721
        for n, batch in itertools.product(ns, batches):
722
            run_test(n, batch)
723

724
    @skipCUDAIfNoMagmaAndNoCusolver
725
    @skipCPUIfNoLapack
726
    @dtypes(*floating_and_complex_types())
727
    def test_cholesky_ex_non_pd(self, device, dtype):
728
        # if the input matrix is not positive definite, a positive integer is returned in info
729
        A = torch.eye(3, 3, dtype=dtype, device=device)
730
        A[-1, -1] = 0  # Now A is singular
731
        _, info = torch.linalg.cholesky_ex(A)
732
        self.assertEqual(info, 3)
733
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
734
            torch.linalg.cholesky_ex(A, check_errors=True)
735

736
        # if at least one matrix in the batch is not positive definite,
        # the corresponding entry of the returned batched info tensor is a positive integer
738
        A = torch.eye(3, 3, dtype=dtype, device=device)
739
        A = A.reshape((1, 3, 3))
740
        A = A.repeat(5, 1, 1)
741
        A[3, -2, -2] = 0  # Now A[3] is singular
742
        _, info = torch.linalg.cholesky_ex(A)
743

744
        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
745
        expected_info[3] = 2
746
        self.assertEqual(info, expected_info)
747
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
748
            torch.linalg.cholesky_ex(A, check_errors=True)
749

750
    def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1):
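        # Shared helper: compare torch.addr against the NumPy reference
        # beta * m + alpha * np.outer(a, b) (just the outer product when beta == 0).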
751
        def check(m, a, b, beta, alpha):
752
            if dtype == torch.bfloat16:
753
                a_np = a.to(torch.double).cpu().numpy()
754
                b_np = b.to(torch.double).cpu().numpy()
755
                m_np = m.to(torch.double).cpu().numpy()
756
                exact_dtype = False
757
            else:
758
                a_np = a.cpu().numpy()
759
                b_np = b.cpu().numpy()
760
                m_np = m.cpu().numpy()
761
                exact_dtype = True
762
            if beta == 0:
763
                expected = alpha * np.outer(a_np, b_np)
764
            else:
765
                expected = beta * m_np + alpha * np.outer(a_np, b_np)
766

767
            res = torch.addr(m, a, b, beta=beta, alpha=alpha)
768
            self.assertEqual(res, expected, exact_dtype=exact_dtype)
769

770
            # Test out variant
771
            out = torch.empty_like(res)
772
            torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
773
            self.assertEqual(out, expected, exact_dtype=exact_dtype)
774

775
        m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
776
        a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
777
        b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
778

779
        check(m, a, b, beta, alpha)
780

781
        # test transpose
782
        m_transpose = torch.transpose(m, 0, 1)
783
        check(m_transpose, a, b, beta, alpha)
784

785
        # test 0 strided tensor
786
        zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
787
        check(m, zero_strided, b, beta, alpha)
788

789
        # test scalar
790
        m_scalar = torch.tensor(1, device=device, dtype=dtype)
791
        check(m_scalar, a, b, beta, alpha)
792

793
        # test nans and infs are not propagated to the output when beta == 0
794
        float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
795
        if beta == 0 and dtype in float_and_complex_dtypes:
796
            m[0][10] = m[10][10] = m[20][20] = float('inf')
797
            m[1][10] = m[11][10] = m[21][20] = float('nan')
798
        check(m, a, b, 0, alpha)
799

800
    @dtypes(torch.bool)
801
    def test_addr_bool(self, device, dtype):
802
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
803
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
804
        self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
805
        self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)
806

807
    @dtypes(*integral_types())
808
    def test_addr_integral(self, device, dtype):
809
        with self.assertRaisesRegex(RuntimeError,
810
                                    'argument beta must not be a floating point number.'):
811
            self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
812
        with self.assertRaisesRegex(RuntimeError,
813
                                    'argument alpha must not be a floating point number.'):
814
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
815
        with self.assertRaisesRegex(RuntimeError,
816
                                    'Boolean beta only supported for Boolean results.'):
817
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
818
        with self.assertRaisesRegex(RuntimeError,
819
                                    'Boolean alpha only supported for Boolean results.'):
820
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
821

822
        # when beta is zero
823
        self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
824
        # when beta is not zero
825
        self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)
826

827
    @precisionOverride({torch.bfloat16: 1e-1})
828
    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
829
    def test_addr_float_and_complex(self, device, dtype):
830
        with self.assertRaisesRegex(RuntimeError,
831
                                    'Boolean beta only supported for Boolean results.'):
832
            self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
833
        with self.assertRaisesRegex(RuntimeError,
834
                                    'Boolean alpha only supported for Boolean results.'):
835
            self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
836

837
        # when beta is zero
838
        self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
839
        # when beta is not zero
840
        self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
841
        if dtype in complex_types():
842
            self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))
843

844
    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
845
                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
846
    def test_outer_type_promotion(self, device, dtypes):
847
        a = torch.randn(5).to(device=device, dtype=dtypes[0])
848
        b = torch.randn(5).to(device=device, dtype=dtypes[1])
849
        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
850
            result = op(a, b)
851
            self.assertEqual(result.dtype, torch.result_type(a, b))
852

853
    # don't use @dtypes decorator to avoid generating ~1700 tests per device
854
    def test_addr_type_promotion(self, device):
855
        for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
856
            a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
857
            b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
858
            m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)
859

860
            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
861
                                                dtypes2)
862
            for op in (torch.addr, torch.Tensor.addr):
863
                result = op(m, a, b)
864
                self.assertEqual(result.dtype, desired_dtype)
865

866
    # Tests migrated from test_torch.py
867
    # 1) test the shape of the result tensor for empty input tensors
    # 2) test that a RuntimeError is raised for scalar input tensors
869
    def test_outer_ger_addr_legacy_tests(self, device):
870
        for size in ((0, 0), (0, 5), (5, 0)):
871
            a = torch.rand(size[0], device=device)
872
            b = torch.rand(size[1], device=device)
873

874
            self.assertEqual(torch.outer(a, b).shape, size)
875
            self.assertEqual(torch.ger(a, b).shape, size)
876

877
            m = torch.empty(size, device=device)
878
            self.assertEqual(torch.addr(m, a, b).shape, size)
879

880
        m = torch.randn(5, 6, device=device)
881
        a = torch.randn(5, device=device)
882
        b = torch.tensor(6, device=device)
883
        self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
884
        self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
885
        self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
886
        self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
887
        self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
888
        self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))
889

890
    # Tests torch.det and its alias, torch.linalg.det, vs. NumPy
891
    @skipCUDAIfNoMagma
892
    @skipCPUIfNoLapack
893
    @dtypes(torch.double, torch.cdouble)
894
    def test_det(self, device, dtype):
895
        tensors = (
896
            torch.randn((2, 2), device=device, dtype=dtype),
897
            torch.randn((129, 129), device=device, dtype=dtype),
898
            torch.randn((3, 52, 52), device=device, dtype=dtype),
899
            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
900

901

902
        ops = (torch.det, torch.Tensor.det,
903
               torch.linalg.det)
904
        for t in tensors:
905
            expected = np.linalg.det(t.cpu().numpy())
906
            for op in ops:
907
                actual = op(t)
908
                self.assertEqual(actual, expected)
909
                self.compare_with_numpy(op, np.linalg.det, t)
910

911
        # NOTE: det requires a 2D+ tensor
912
        t = torch.randn(1, device=device, dtype=dtype)
913
        with self.assertRaises(RuntimeError):
914
            op(t)
915

916
    @skipCUDAIfNoMagma
917
    @skipCPUIfNoLapack
918
    @dtypes(*floating_and_complex_types())
919
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
920
    def test_eigh(self, device, dtype):
921
        from torch.testing._internal.common_utils import random_hermitian_matrix
922

923
        def run_test(shape, batch, uplo):
924
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
925
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
926
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
927
            self.assertEqual(actual_w, expected_w)
928
            # sign of eigenvectors is not unique and therefore absolute values are compared
929
            self.assertEqual(abs(actual_v), abs(expected_v))
930
            # additionally we can multiply the eigenvector with a phase factor e^{i\phi} and then compare the values
931
            # let's choose the convention that the first element of the eigenvectors from torch and numpy be the same
932
            # for real inputs, this phase factor is plus or minus one
933
            if matrix.numel() > 0:
934
                phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
935
                actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
936
                self.assertEqual(actual_v_rotated, expected_v)
937

938
            # check the out= variant
939
            out_w = torch.empty_like(actual_w)
940
            out_v = torch.empty_like(actual_v)
941
            ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
942
            self.assertEqual(ans_w, out_w)
943
            self.assertEqual(ans_v, out_v)
944
            self.assertEqual(ans_w, actual_w)
945
            self.assertEqual(abs(ans_v), abs(actual_v))
946

947
        shapes = (0, 3, 5)
948
        batches = ((), (3, ), (2, 2))
949
        uplos = ["U", "L"]
950
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
951
            run_test(shape, batch, uplo)
952

953
    @skipCUDAIfNoMagma
954
    @skipCPUIfNoLapack
955
    @dtypes(*floating_and_complex_types())
956
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
957
    def test_eigh_lower_uplo(self, device, dtype):
958
        def run_test(shape, batch, uplo):
959
            # check lower case uplo
960
            # use non-symmetric input to check whether uplo argument is working as intended
961
            matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
962
            expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
963
            actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
964
            self.assertEqual(actual_w, expected_w)
965
            self.assertEqual(abs(actual_v), abs(expected_v))
966

967
        uplos = ["u", "l"]
968
        for uplo in uplos:
969
            run_test(3, (2, 2), uplo)
970

971
    @skipCUDAIfNoMagma
972
    @skipCPUIfNoLapack
973
    @dtypes(*floating_and_complex_types())
974
    def test_eigh_errors_and_warnings(self, device, dtype):
975
        from torch.testing._internal.common_utils import random_hermitian_matrix
976

977
        # eigh requires a square matrix
978
        t = torch.randn(2, 3, device=device, dtype=dtype)
979
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
980
            torch.linalg.eigh(t)
981

982
        # eigh requires 'uplo' parameter to be 'U' or 'L'
983
        t = torch.randn(3, 3, device=device, dtype=dtype)
984
        for uplo in ["a", "wrong"]:
985
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
986
                torch.linalg.eigh(t, UPLO=uplo)
987
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
988
                np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)
989

990
        # if non-empty out tensor with wrong shape is passed a warning is given
991
        a = random_hermitian_matrix(3, dtype=dtype, device=device)
992
        real_dtype = a.real.dtype if dtype.is_complex else dtype
993
        out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
994
        out_v = torch.empty(7, 7, dtype=dtype, device=device)
995
        with warnings.catch_warnings(record=True) as w:
996
            # Trigger warning
997
            torch.linalg.eigh(a, out=(out_w, out_v))
998
            # Check warning occurs
999
            self.assertEqual(len(w), 2)
1000
            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
1001
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1002

1003
        # dtypes should be safely castable
1004
        out_w = torch.empty(0, dtype=real_dtype, device=device)
1005
        out_v = torch.empty(0, dtype=torch.int, device=device)
1006
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1007
            torch.linalg.eigh(a, out=(out_w, out_v))
1008

1009
        out_w = torch.empty(0, dtype=torch.int, device=device)
1010
        out_v = torch.empty(0, dtype=dtype, device=device)
1011
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1012
            torch.linalg.eigh(a, out=(out_w, out_v))
1013

1014
        # device should match
1015
        if torch.cuda.is_available():
1016
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1017
            out_w = torch.empty(0, device=wrong_device, dtype=dtype)
1018
            out_v = torch.empty(0, device=device, dtype=dtype)
1019
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1020
                torch.linalg.eigh(a, out=(out_w, out_v))
1021
            out_w = torch.empty(0, device=device, dtype=dtype)
1022
            out_v = torch.empty(0, device=wrong_device, dtype=dtype)
1023
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1024
                torch.linalg.eigh(a, out=(out_w, out_v))
1025

1026
    @skipCPUIfNoLapack
1027
    @dtypes(torch.float, torch.double)
1028
    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
1029
    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
1030
        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
1031
        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
1032
        # but passes on cuda 12.1 update 1 or later.
1033
        a = torch.ones(512, 512, dtype=dtype, device=device)
1034
        a[0, 0] = 1.0e-5
1035
        a[-1, -1] = 1.0e5
1036

1037
        eigh_out = torch.linalg.eigh(a)
1038
        svd_out = torch.linalg.svd(a)
1039

1040
        # Matrix input a is too ill-conditioned.
1041
        # We'll just compare the first two singular values/eigenvalues. They are 1.0e5 and 511.0
1042
        # The precision override with tolerance of 1.0 makes sense since ill-conditioned inputs are hard to converge
1043
        # to exact values.
1044
        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1045
        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1046

1047
    @skipCUDAIfNoMagma
1048
    @skipCPUIfNoLapack
1049
    @dtypes(*floating_and_complex_types())
1050
    @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
1051
    def test_eigvalsh(self, device, dtype):
1052
        from torch.testing._internal.common_utils import random_hermitian_matrix
1053

1054
        def run_test(shape, batch, uplo):
1055
            matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
1056
            expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
1057
            actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
1058
            self.assertEqual(actual_w, expected_w)
1059

1060
            # check the out= variant
1061
            out = torch.empty_like(actual_w)
1062
            ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
1063
            self.assertEqual(ans, out)
1064
            self.assertEqual(ans, actual_w)
1065

1066
        shapes = (0, 3, 5)
1067
        batches = ((), (3, ), (2, 2))
1068
        uplos = ["U", "L"]
1069
        for shape, batch, uplo in itertools.product(shapes, batches, uplos):
1070
            run_test(shape, batch, uplo)
1071

1072
    @skipCUDAIfNoMagma
1073
    @skipCPUIfNoLapack
1074
    @dtypes(*floating_and_complex_types())
1075
    def test_eigvalsh_errors_and_warnings(self, device, dtype):
1076
        # eigvalsh requires a square matrix
1077
        t = torch.randn(2, 3, device=device, dtype=dtype)
1078
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1079
            torch.linalg.eigvalsh(t)
1080

1081
        # eigvalsh requires 'uplo' parameter to be 'U' or 'L'
1082
        t = torch.randn(3, 3, device=device, dtype=dtype)
1083
        for uplo in ["a", "wrong"]:
1084
            with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
1085
                torch.linalg.eigvalsh(t, UPLO=uplo)
1086
            with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
1087
                np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)
1088

1089
        # if non-empty out tensor with wrong shape is passed a warning is given
1090
        real_dtype = t.real.dtype if dtype.is_complex else dtype
1091
        out = torch.empty_like(t).to(real_dtype)
1092
        with warnings.catch_warnings(record=True) as w:
1093
            # Trigger warning
1094
            torch.linalg.eigvalsh(t, out=out)
1095
            # Check warning occurs
1096
            self.assertEqual(len(w), 1)
1097
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1098

1099
        # dtypes should be safely castable
1100
        out = torch.empty(0, dtype=torch.int, device=device)
1101
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1102
            torch.linalg.eigvalsh(t, out=out)
1103

1104
        # device should match
1105
        if torch.cuda.is_available():
1106
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1107
            out = torch.empty(0, device=wrong_device, dtype=dtype)
1108
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1109
                torch.linalg.eigvalsh(t, out=out)
1110

1111
    @dtypes(*floating_and_complex_types())
1112
    def test_kron(self, device, dtype):
1113

1114
        def run_test_case(a_shape, b_shape):
1115
            a = torch.rand(a_shape, dtype=dtype, device=device)
1116
            b = torch.rand(b_shape, dtype=dtype, device=device)
1117

1118
            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1119
            result = torch.kron(a, b)
1120
            self.assertEqual(result, expected)
1121

1122
            # check the out= variant
1123
            out = torch.empty_like(result)
1124
            ans = torch.kron(a, b, out=out)
1125
            self.assertEqual(ans, out)
1126
            self.assertEqual(ans, result)
1127

1128
        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
1129
        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
1130
            run_test_case(a_shape, b_shape)
1131

1132
    @dtypes(*floating_and_complex_types())
1133
    def test_kron_empty(self, device, dtype):
1134

1135
        def run_test_case(empty_shape):
1136
            a = torch.eye(3, dtype=dtype, device=device)
1137
            b = torch.empty(empty_shape, dtype=dtype, device=device)
1138
            result = torch.kron(a, b)
1139
            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1140
            self.assertEqual(result, expected)
1141

1142
            # NumPy doesn't work if the first argument is empty
1143
            result = torch.kron(b, a)
1144
            self.assertEqual(result.shape, expected.shape)
1145

1146
        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
1147
        for empty_shape in empty_shapes:
1148
            run_test_case(empty_shape)
1149

1150
    @dtypes(*floating_and_complex_types())
1151
    def test_kron_errors_and_warnings(self, device, dtype):
1152
        # if non-empty out tensor with wrong shape is passed a warning is given
1153
        a = torch.eye(3, dtype=dtype, device=device)
1154
        b = torch.ones((2, 2), dtype=dtype, device=device)
1155
        out = torch.empty_like(a)
1156
        with warnings.catch_warnings(record=True) as w:
1157
            # Trigger warning
1158
            torch.kron(a, b, out=out)
1159
            # Check warning occurs
1160
            self.assertEqual(len(w), 1)
1161
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1162

1163
        # dtypes should match
1164
        out = torch.empty_like(a).to(torch.int)
1165
        with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
1166
            torch.kron(a, b, out=out)
1167

1168
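    # A minimal sketch (not collected as a test; the helper name is arbitrary) of what
    # torch.kron computes in the 2x2 case checked above: the block matrix whose
    # (i, j)-th block is a[i, j] * b.
    def _kron_block_sketch(self):
        a = torch.tensor([[1., 2.], [3., 4.]])
        b = torch.eye(2)
        # build the Kronecker product by hand, block row by block row
        manual = torch.cat([torch.cat([a[i, j] * b for j in range(2)], dim=1)
                            for i in range(2)], dim=0)
        return torch.allclose(torch.kron(a, b), manual)
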
    # This test confirms that torch.linalg.norm's dtype argument works
1169
    # as expected, according to the function's documentation
1170
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1171
    def test_norm_dtype(self, device, dtype):
1172
        make_arg = partial(make_tensor, dtype=dtype, device=device)
1173

1174
        def run_test_case(input_size, ord, keepdim, to_dtype):
1175
            msg = (
1176
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1177
                f'dtype={dtype}, to_dtype={to_dtype}')
1178
            input = make_arg(input_size)
1179
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1180
            self.assertEqual(result.dtype, input.real.dtype, msg=msg)
1181

1182
            result_out = torch.empty((0), dtype=result.dtype, device=device)
1183
            torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
1184
            self.assertEqual(result, result_out, msg=msg)
1185

1186
            result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
1187
            result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
1188
            self.assertEqual(result, result_with_dtype, msg=msg)
1189

1190
            result_out_with_dtype = torch.empty_like(result_with_dtype)
1191
            torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
1192
            self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)
1193

1194
        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1195

1196
        # For these orders we compute the 10th power and 10th root of numbers.
        # We skip them for half-precision types because they make the checks above too badly conditioned.
1198
        if dtype != torch.float16 and dtype != torch.bfloat16:
1199
            ord_vector.extend([0.1, -0.1])
1200
        ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
1201
        S = 10
1202

1203
        if dtype == torch.cfloat:
1204
            norm_dtypes = (torch.cfloat, torch.cdouble)
1205
        elif dtype == torch.cdouble:
1206
            norm_dtypes = (torch.cdouble,)
1207
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1208
            norm_dtypes = (torch.float, torch.double)
1209
        elif dtype == torch.double:
1210
            norm_dtypes = (torch.double,)
1211
        else:
1212
            raise RuntimeError("Unsupported dtype")
1213

1214
        for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
1215
            run_test_case((S,) , ord, keepdim, norm_dtype)
1216

1217
        for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
1218
            if ord in [2, -2, 'nuc']:
1219
                # We need torch.svdvals
1220
                if dtype == torch.float16 or dtype == torch.bfloat16:
1221
                    continue
1222

1223
                # We need LAPACK or equivalent
1224
                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1225
                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1226
                    continue
1227
            run_test_case((S, S) , ord, keepdim, norm_dtype)
1228

1229
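    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # dtype= semantics exercised by test_norm_dtype above: passing dtype= to
    # torch.linalg.norm should match casting the input to that dtype first.
    def _norm_dtype_sketch(self):
        x = torch.randn(5, dtype=torch.float32)
        with_dtype = torch.linalg.norm(x, 2, dtype=torch.float64)
        upcast_first = torch.linalg.norm(x.to(torch.float64), 2)
        return torch.allclose(with_dtype, upcast_first)
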
    # This test confirms that torch.linalg.norm returns correct results for bfloat16 and half inputs.
1230
    @dtypes(torch.bfloat16, torch.float16)
1231
    def test_norm_bfloat16_and_half(self, device, dtype):
1232
        make_arg = partial(make_tensor, dtype=dtype, device=device)
1233

1234
        def run_test_case(input_size, ord, keepdim):
1235
            msg = (
1236
                f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1237
                f'dtype={dtype}')
1238
            input = make_arg(input_size).fill_(1)
1239
            result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
1240
            result = torch.linalg.norm(input, ord, keepdim=keepdim)
1241
            self.assertEqual(result_ref, result, msg=msg)
1242

1243
        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1244
        for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
1245
            run_test_case((S,) , ord, keepdim, )
1246

1247
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1248
    def test_vector_norm(self, device, dtype):
1249
        if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
1250
            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
1251
        # This test compares torch.linalg.vector_norm's output with
        # torch.linalg.norm given a flattened tensor
1254
        ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1255
        input_sizes = [
1256
            (1, ),
1257
            (10, ),
1258
            (4, 5),
1259
            (3, 4, 5),
1260
            (0, ),
1261
            (0, 10),
1262
            (0, 0),
1263
            (10, 0, 10),
1264
        ]
1265

1266
        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
1267
            if dim is None:
1268
                input_maybe_flat = input.flatten(0, -1)
1269
            else:
1270
                input_maybe_flat = input
1271

1272
            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
1273
            if keepdim and dim is None:
1274
                result = result.reshape([1] * input.dim())
1275
            return result
1276

1277
        def run_test_case(input, ord, dim, keepdim, norm_dtype):
1278
            if (input.numel() == 0 and
1279
                (ord < 0. or ord == inf) and
1280
               (dim is None or input.shape[dim] == 0)):
1281
                # The operation does not have an identity.
1282
                error_msg = "linalg.vector_norm cannot compute"
1283
                with self.assertRaisesRegex(RuntimeError, error_msg):
1284
                    torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
1285
            else:
1286
                msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
1287
                       f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
1288
                result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1289
                result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1290
                if dtype.is_complex:
1291
                    result_dtype_reference = result_dtype_reference.real
1292
                self.assertEqual(result_dtype, result_dtype_reference, msg=msg)
1293

1294
                if norm_dtype is not None:
1295
                    ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
1296
                    actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1297
                    self.assertEqual(ref, actual, msg=msg)
1298

1299
        if dtype == torch.cfloat:
1300
            norm_dtypes = (None, torch.cfloat, torch.cdouble)
1301
        elif dtype == torch.cdouble:
1302
            norm_dtypes = (None, torch.cdouble)
1303
        elif dtype in (torch.float16, torch.bfloat16, torch.float):
1304
            norm_dtypes = (None, torch.float, torch.double)
1305
        elif dtype == torch.double:
1306
            norm_dtypes = (None, torch.double)
1307
        else:
1308
            raise RuntimeError("Unsupported dtype")
1309

1310
        for amp in [False, True]:
1311
            with torch.autocast(device_type=device, enabled=amp):
1312
                for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
1313
                    input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1314
                    for dim in [None, random.randint(0, len(input_size) - 1)]:
1315
                        run_test_case(
1316
                            input,
1317
                            ord,
1318
                            dim,
1319
                            keepdim,
1320
                            norm_dtype)
1321

1322
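    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # relationship checked by test_vector_norm above: with dim=None,
    # torch.linalg.vector_norm treats the input as one long vector, matching
    # torch.linalg.norm on the flattened tensor.
    def _vector_norm_flatten_sketch(self):
        x = torch.randn(3, 4, 5)
        return torch.allclose(torch.linalg.vector_norm(x, 3),
                              torch.linalg.norm(x.flatten(), 3))
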
    def test_vector_norm_dim_tuple_arg(self, device):
1323
        test_cases = [
1324
            # input size, dim, error, error message
1325
            ((4, ), (0, ), None, None),
1326
            ((4, ), (1, ), IndexError, r'Dimension out of range'),
1327
            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
1328
            ((4, 3), (0, -1), None, None),
1329
            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1330
            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
1331
            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
1332
            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
1333
        ]
1334
        for input_size, dim_tuple, error, error_msg in test_cases:
1335
            input = torch.randn(input_size, device=device)
1336
            # vector_norm should accept a tuple or a list for dim arg
1337
            for dim in [dim_tuple, list(dim_tuple)]:
1338
                if error is None:
1339
                    torch.linalg.vector_norm(input, dim=dim)
1340
                else:
1341
                    with self.assertRaises(error):
1342
                        torch.linalg.vector_norm(input, dim=dim)
1343

1344
    # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
1345
    # their vector norm results match
1346
    @dtypes(torch.float, torch.double)
1347
    def test_norm_vector(self, device, dtype):
1348
        def run_test_case(input, p, dim, keepdim):
1349
            result = torch.linalg.norm(input, ord, dim, keepdim)
1350
            input_numpy = input.cpu().numpy()
1351
            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1352

1353
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1354
            self.assertEqual(result, result_numpy, msg=msg)
1355

1356
            result_out = torch.empty_like(result)
1357
            torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
1358
            self.assertEqual(result, result_out, msg=msg)
1359

1360
        ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
1361
        S = 10
1362
        test_cases = [
1363
            # input size, p settings, dim
1364
            ((S, ), ord_vector, None),
1365
            ((S, ), ord_vector, 0),
1366
            ((S, S, S), ord_vector, 0),
1367
            ((S, S, S), ord_vector, 1),
1368
            ((S, S, S), ord_vector, 2),
1369
            ((S, S, S), ord_vector, -1),
1370
            ((S, S, S), ord_vector, -2),
1371
        ]
1372
        L = 1_000_000
1373
        if dtype == torch.double:
1374
            test_cases.append(((L, ), ord_vector, None))
1375
        for keepdim in [True, False]:
1376
            for input_size, ord_settings, dim in test_cases:
1377
                input = torch.randn(*input_size, dtype=dtype, device=device)
1378
                for ord in ord_settings:
1379
                    run_test_case(input, ord, dim, keepdim)
1380

1381
    # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to
1382
    # ensure that their matrix norm results match.
1383
    @skipMeta  # https://github.com/pytorch/pytorch/issues/54082
1384
    @skipCUDAIfNoMagma
1385
    @dtypes(torch.float, torch.double)
1386
    @precisionOverride({torch.float32: 2e-4})
1387
    def test_norm_matrix(self, device, dtype):
1388
        make_arg = partial(make_tensor, dtype=dtype, device=device)
1389

1390
        def run_test_case(input, ord, dim, keepdim):
1391
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1392
            result = torch.linalg.norm(input, ord, dim, keepdim)
1393
            input_numpy = input.cpu().numpy()
1394
            result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1395

1396
            result = torch.linalg.norm(input, ord, dim, keepdim)
1397
            self.assertEqual(result, result_numpy, msg=msg)
1398
            if ord is not None and dim is not None:
1399
                result = torch.linalg.matrix_norm(input, ord, dim, keepdim)
1400
                self.assertEqual(result, result_numpy, msg=msg)
1401

1402
        ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
1403
        S = 10
1404
        test_cases = [
1405
            # input size, dim
1406
            ((S, S), None),
1407
            ((S, S), (0, 1)),
1408
            ((S, S), (1, 0)),
1409
            ((S, S, S, S), (2, 0)),
1410
            ((S, S, S, S), (-1, -2)),
1411
            ((S, S, S, S), (-1, -3)),
1412
            ((S, S, S, S), (-3, 2)),
1413
        ]
1414

1415
        for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix):
1416
            if ord in [2, -2, 'nuc']:
1417
                # We need torch.svdvals
1418
                if dtype == torch.float16 or dtype == torch.bfloat16:
1419
                    continue
1420
                # We need LAPACK or equivalent
1421
                if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1422
                   (torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1423
                    continue
1424
            run_test_case(make_arg(shape), ord, dim, keepdim)
1425

1426

1427
    @onlyCUDA
1428
    @dtypes(torch.bfloat16, torch.float16)
1429
    def test_norm_fused_type_promotion(self, device, dtype):
1430
        x = torch.randn(10, device=device, dtype=dtype)
1431

1432
        def profile_and_check(fn, x, kwargs):
1433
            with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
1434
                fn(x, **kwargs, dtype=torch.float)
1435
            # smoke check that profiler returned some events
1436
            self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events()))
1437
            # test that there was no explicit copy
1438
            self.assertFalse("aten::to" in (e.name for e in p.events()))
1439

1440
        for f, kwargs, in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})):
1441
            profile_and_check(f, x, kwargs)
1442

1443
    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1444
    @skipCPUIfNoLapack
1445
    @skipCUDAIfNoMagma
1446
    @dtypes(*floating_and_complex_types())
1447
    @precisionOverride({torch.float32: 1e-3})
1448
    def test_cond(self, device, dtype):
1449
        def run_test_case(input, p):
1450
            result = torch.linalg.cond(input, p)
1451
            result_numpy = np.linalg.cond(input.cpu().numpy(), p)
1452
            self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False)
1453
            self.assertEqual(result.shape, result_numpy.shape)
1454

1455
            # test out= variant
1456
            out = torch.empty_like(result)
1457
            ans = torch.linalg.cond(input, p, out=out)
1458
            self.assertEqual(ans, out)
1459
            self.assertEqual(ans, result)
1460

1461
        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1462
        input_sizes = [(32, 32), (2, 3, 3, 3)]
1463
        for input_size in input_sizes:
1464
            input = torch.randn(*input_size, dtype=dtype, device=device)
1465
            for p in norm_types:
1466
                run_test_case(input, p)
1467

1468
        # test empty batch sizes
1469
        input_sizes = [(0, 3, 3), (0, 2, 5, 5)]
1470
        for input_size in input_sizes:
1471
            input = torch.randn(*input_size, dtype=dtype, device=device)
1472
            for p in norm_types:
1473
                run_test_case(input, p)
1474

1475
        # test non-square input
1476
        input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)]
1477
        for input_size in input_sizes:
1478
            input = torch.randn(*input_size, dtype=dtype, device=device)
1479
            for p in [2, -2, None]:
1480
                run_test_case(input, p)
1481

1482
        # test for singular input
1483
        a = torch.eye(3, dtype=dtype, device=device)
1484
        a[-1, -1] = 0  # make 'a' singular
1485
        for p in norm_types:
1486
            try:
1487
                run_test_case(a, p)
1488
            except np.linalg.LinAlgError:
1489
                # Numpy may fail to converge for some BLAS backends (although this is very rare)
1490
                # See the discussion in https://github.com/pytorch/pytorch/issues/67675
1491
                pass
1492

1493
        # test for 0x0 matrices. NumPy doesn't work for such inputs; we return 0
1494
        input_sizes = [(0, 0), (2, 5, 0, 0)]
1495
        for input_size in input_sizes:
1496
            input = torch.randn(*input_size, dtype=dtype, device=device)
1497
            for p in ['fro', 2]:
1498
                expected_dtype = a.real.dtype if dtype.is_complex else dtype
1499
                expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
1500
                actual = torch.linalg.cond(input, p)
1501
                self.assertEqual(actual, expected)
1502

1503
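    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # quantity computed by torch.linalg.cond for p=2, as exercised above: the ratio of
    # the largest to the smallest singular value of the input.
    def _cond_svdvals_sketch(self):
        a = torch.randn(4, 4, dtype=torch.float64)
        s = torch.linalg.svdvals(a)
        return torch.allclose(torch.linalg.cond(a, 2), s.max() / s.min())
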
    @skipMeta  # https://github.com/pytorch/pytorch/issues/53739
1504
    @skipCPUIfNoLapack
1505
    @skipCUDAIfNoMagma
1506
    @dtypes(*floating_and_complex_types())
1507
    @precisionOverride({torch.float32: 1e-3})
1508
    def test_cond_errors_and_warnings(self, device, dtype):
1509
        norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1510

1511
        # cond expects the input to be at least 2-dimensional
1512
        a = torch.ones(3, dtype=dtype, device=device)
1513
        for p in norm_types:
1514
            with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
1515
                torch.linalg.cond(a, p)
1516

1517
        # for some norm types cond expects the input to be square
1518
        a = torch.ones(3, 2, dtype=dtype, device=device)
1519
        norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
1520
        for p in norm_types:
1521
            with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
1522
                torch.linalg.cond(a, p)
1523

1524
        # if non-empty out tensor with wrong shape is passed a warning is given
1525
        a = torch.ones((2, 2), dtype=dtype, device=device)
1526
        for p in ['fro', 2]:
1527
            real_dtype = a.real.dtype if dtype.is_complex else dtype
1528
            out = torch.empty(a.shape, dtype=real_dtype, device=device)
1529
            with warnings.catch_warnings(record=True) as w:
1530
                # Trigger warning
1531
                torch.linalg.cond(a, p, out=out)
1532
                # Check warning occurs
1533
                self.assertEqual(len(w), 1)
1534
                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1535

1536
        # dtypes should be safely castable
1537
        out = torch.empty(0, dtype=torch.int, device=device)
1538
        for p in ['fro', 2]:
1539
            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
1540
                torch.linalg.cond(a, p, out=out)
1541

1542
        # device should match
1543
        if torch.cuda.is_available():
1544
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1545
            out = torch.empty(0, dtype=dtype, device=wrong_device)
1546
            for p in ['fro', 2]:
1547
                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1548
                    torch.linalg.cond(a, p, out=out)
1549

1550
        # For batched input, if at least one matrix in the batch is not invertible,
        # we can't get the result for the other (possibly invertible) matrices in the batch
        # without an explicit for loop. This should change once at::inverse supports silent errors.
        # NumPy works fine in this case because it can silence the error and return inverse results
        # that are possibly filled with NaNs.
1555
        batch_dim = 3
1556
        a = torch.eye(3, 3, dtype=dtype, device=device)
1557
        a = a.reshape((1, 3, 3))
1558
        a = a.repeat(batch_dim, 1, 1)
1559
        a[1, -1, -1] = 0  # now a[1] is singular
1560
        for p in [1, -1, inf, -inf, 'fro', 'nuc']:
1561
            result = torch.linalg.cond(a, p)
1562
            self.assertEqual(result[1], float('inf'))
1563

1564
        # check invalid norm type
1565
        a = torch.ones(3, 3, dtype=dtype, device=device)
1566
        for p in ['wrong_norm', 5]:
1567
            with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
1568
                torch.linalg.cond(a, p)
1569

1570
    # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments
1571
    # to ensure that they both throw errors
1572
    @dtypes(torch.float, torch.double)
1573
    def test_norm_errors(self, device, dtype):
1574
        def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex):
1575
            test_case_info = (
1576
                f'test case input.size()={input.size()}, ord={ord}, dim={dim}, '
1577
                f'keepdim={keepdim}, dtype={dtype}')
1578

1579
            with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info):
1580
                torch.linalg.norm(input, ord, dim, keepdim)
1581

1582
            input_numpy = input.cpu().numpy()
1583

1584
            msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"'
1585
            with self.assertRaises(Exception, msg=test_case_info):
1586
                np.linalg.norm(input_numpy, ord, dim, keepdim)
1587

1588
        S = 10
1589
        error_test_cases = [
1590
            # input size, p settings, dim, error type, error regex
1591
            ((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'),
1592
            ((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'),
1593
            ((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'),
1594
            ((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'),
1595
            ((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'),
1596
            ((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'),
1597
            ((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'),
1598
            ((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'),
1599
            ((S, ), [0], (4, ), IndexError, r'Dimension out of range'),
1600
            ((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'),
1601
            ((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."),
1602
            ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"),
1603
        ]
1604
        for keepdim in [True, False]:
1605
            for input_size, ord_settings, dim, error_type, error_regex in error_test_cases:
1606
                input = torch.randn(*input_size, dtype=dtype, device=device)
1607
                for ord in ord_settings:
1608
                    run_error_test_case(input, ord, dim, keepdim, error_type, error_regex)
1609

1610
    # Test complex number inputs for linalg.norm
1611
    @skipCUDAIfNoMagma
1612
    @skipCPUIfNoLapack
1613
    @dtypes(torch.cfloat, torch.cdouble)
1614
    @precisionOverride({torch.cfloat: 5e-4})
1615
    def test_norm_complex(self, device, dtype):
1616
        def gen_error_message(input_size, ord, keepdim, dim=None):
1617
            return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}"
1618

1619
        vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf]
1620
        matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf]
1621

1622
        # Test supported ords
1623
        for keepdim in [False, True]:
1624
            # vector norm
1625
            x = torch.randn(25, device=device, dtype=dtype)
1626
            xn = x.cpu().numpy()
1627
            for ord in vector_ords:
1628
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1629
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1630
                msg = gen_error_message(x.size(), ord, keepdim)
1631
                self.assertEqual(res.shape, expected.shape, msg=msg)
1632
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1633

1634
                res_out = torch.tensor([], device=device, dtype=res.dtype)
1635
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1636
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1637
                self.assertEqual(res_out, expected, msg=msg)
1638

1639
            # matrix norm
1640
            x = torch.randn(25, 25, device=device, dtype=dtype)
1641
            xn = x.cpu().numpy()
1642
            for ord in matrix_ords:
1643
                res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1644
                expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1645
                msg = gen_error_message(x.size(), ord, keepdim)
1646
                self.assertEqual(res.shape, expected.shape, msg=msg)
1647
                self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1648

1649
                res_out = torch.tensor([], device=device, dtype=res.dtype)
1650
                torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1651
                self.assertEqual(res_out.shape, expected.shape, msg=msg)
1652
                self.assertEqual(res_out, expected, msg=msg)
1653

1654
    # Test that linalg.vector_norm gives the same result as numpy when inputs
    # contain extreme values (inf, -inf, nan)
1656
    def test_vector_norm_extreme_values(self, device):
1657
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1658
        vectors = []
1659
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1660
            vectors.append(list(pair))
1661
        for vector in vectors:
1662
            x = torch.tensor(vector, device=device)
1663
            x_n = x.cpu().numpy()
1664
            for ord in vector_ords:
1665
                msg = f'ord={ord}, vector={vector}'
1666
                result = torch.linalg.vector_norm(x, ord=ord)
1667
                result_n = np.linalg.norm(x_n, ord=ord)
1668
                self.assertEqual(result, result_n, msg=msg)
1669

1670
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1671
    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
1672
        input_sizes_and_dims = [
1673
            ((6, 1), -1),
1674
            ((3, 1, 2, 1), (1, 3)),
1675
            ((1,), None),
1676
        ]
1677
        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
1678
        keepdims = [True, False]
1679

1680
        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
1681
            input_size = input_size_and_dim[0]
1682
            dim = input_size_and_dim[1]
1683
            if type(dim) is tuple and ord == 0:
1684
                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
1685
                continue
1686
            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1687
            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
1688
            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
1689

1690
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1691
            self.assertEqual(result, result_numpy, msg=msg)
1692

1693
    @skipCUDAIfNoMagmaAndNoCusolver
1694
    @skipCPUIfNoLapack
1695
    @dtypes(torch.float, torch.double)
1696
    @precisionOverride({torch.float32: 2e-5})
1697
    def test_matrix_norm(self, device, dtype):
1698
        # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm
1699
        A = make_tensor((2, 2, 2), dtype=dtype, device=device)
1700

1701
        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'):
1702
            torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device))
1703
        with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'):
1704
            torch.linalg.matrix_norm(A, dim=(0,))
1705
        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1706
            torch.linalg.matrix_norm(A, ord=0)
1707
        with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1708
            torch.linalg.matrix_norm(A, ord=3.0)
1709

1710
        # Test dim=None behavior
1711
        ref = torch.linalg.norm(A, dim=(-2, -1))
1712
        res = torch.linalg.matrix_norm(A)
1713
        self.assertEqual(ref, res)
1714

1715
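    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # dim=None default checked in test_matrix_norm above: torch.linalg.matrix_norm
    # computes the Frobenius norm over the last two dimensions of a batched input.
    def _matrix_norm_default_sketch(self):
        a = torch.randn(3, 4, 5)
        fro = torch.sqrt((a * a).sum(dim=(-2, -1)))
        return torch.allclose(torch.linalg.matrix_norm(a), fro)
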
    # Test that linalg.norm gives the same result as numpy when inputs
    # contain extreme values (inf, -inf, nan)
1717
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
1718
    @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
1719
    @skipCUDAIfNoMagma
1720
    @skipCPUIfNoLapack
1721
    def test_norm_extreme_values(self, device):
1722
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
1723
        # matrix_ords 'nuc', 2, -2 are skipped currently
1724
        # See issue https://github.com/pytorch/pytorch/issues/71911
1725
        matrix_ords = ['fro', 1, inf, -1, -inf]
1726
        vectors = []
1727
        matrices = []
1728
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
1729
            vectors.append(list(pair))
1730
            matrices.append([[pair[0], pair[1]]])
1731
            matrices.append([[pair[0]], [pair[1]]])
1732
        for vector in vectors:
1733
            x = torch.tensor(vector).to(device)
1734
            x_n = x.cpu().numpy()
1735
            for ord in vector_ords:
1736
                msg = f'ord={ord}, vector={vector}'
1737
                result = torch.linalg.norm(x, ord=ord)
1738
                result_n = np.linalg.norm(x_n, ord=ord)
1739
                self.assertEqual(result, result_n, msg=msg)
1740

1741
        # TODO: Remove this function once the broken cases are fixed
1742
        def is_broken_matrix_norm_case(ord, x):
1743
            if self.device_type == 'cuda':
1744
                if x.size() == torch.Size([1, 2]):
1745
                    if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1:
1746
                        # These cases are broken because of an issue with svd
1747
                        # https://github.com/pytorch/pytorch/issues/43567
1748
                        return True
1749
                if ord in ['nuc', 2, -2]:
1750
                    # These cases are broken because of another issue with svd
1751
                    # https://github.com/pytorch/pytorch/issues/52633
1752
                    return True
1753
            return False
1754

1755
        for matrix in matrices:
1756
            x = torch.tensor(matrix).to(device)
1757
            x_n = x.cpu().numpy()
1758
            for ord in matrix_ords:
1759
                msg = f'ord={ord}, matrix={matrix}'
1760
                if is_broken_matrix_norm_case(ord, x):
1761
                    continue
1762
                else:
1763
                    result_n = np.linalg.norm(x_n, ord=ord)
1764
                    result = torch.linalg.norm(x, ord=ord)
1765
                    self.assertEqual(result, result_n, msg=msg)
1766

1767
    # Test degenerate shape results match numpy for linalg.norm vector norms
1768
    @skipCUDAIfNoMagma
1769
    @skipCPUIfNoLapack
1770
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1771
    def test_norm_vector_degenerate_shapes(self, device, dtype):
1772
        def run_test_case(input, ord, dim, keepdim):
1773
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1774
            if (input.numel() == 0 and
1775
                (ord < 0. or ord == inf) and
1776
               (dim is None or input.shape[dim] == 0)):
1777
                with self.assertRaises(RuntimeError):
1778
                    torch.linalg.norm(input, ord, dim, keepdim)
1779
            else:
1780
                input_numpy = input.cpu().numpy()
1781
                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1782
                result = torch.linalg.norm(input, ord, dim, keepdim)
1783
                self.assertEqual(result, result_numpy, msg=msg)
1784

1785
        ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1786
        S = 10
1787
        test_cases = [
1788
            # input size, dim
1789
            ((0, ), None),
1790
            ((0, S), 0),
1791
            ((0, S), 1),
1792
            ((S, 0), 0),
1793
            ((S, 0), 1),
1794
        ]
1795
        for keepdim in [True, False]:
1796
            for input_size, dim in test_cases:
1797
                input = torch.randn(*input_size, dtype=dtype, device=device)
1798
                for ord in ord_vector:
1799
                    run_test_case(input, ord, dim, keepdim)
1800

1801
    # Test degenerate shape results match numpy for linalg.norm matrix norms
1802
    @skipCUDAIfNoMagma
1803
    @skipCPUIfNoLapack
1804
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1805
    def test_norm_matrix_degenerate_shapes(self, device, dtype):
1806
        def run_test_case(input, ord, dim, keepdim, should_error):
1807
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1808
            input_numpy = input.cpu().numpy()
1809
            ops = [torch.linalg.norm]
1810

1811
            if ord is not None and dim is not None:
1812
                ops.append(torch.linalg.matrix_norm)
1813

1814
            if should_error:
1815
                with self.assertRaises(ValueError):
1816
                    np.linalg.norm(input_numpy, ord, dim, keepdim)
1817
                for op in ops:
1818
                    with self.assertRaises(IndexError):
1819
                        op(input, ord, dim, keepdim)
1820
            else:
1821
                result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1822
                for op in ops:
1823
                    result = op(input, ord, dim, keepdim)
1824
                    self.assertEqual(result, result_numpy, msg=msg)
1825

1826
        ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None]
1827
        S = 10
1828
        test_cases = [
1829
            # input size, p settings that cause error, dim
1830
            ((0, 0), [1, 2, inf, -1, -2, -inf], None),
1831
            ((0, S), [2, inf, -2, -inf], None),
1832
            ((S, 0), [1, 2, -1, -2], None),
1833
            ((S, S, 0), [], (0, 1)),
1834
            ((1, S, 0), [], (0, 1)),
1835
            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)),
1836
            ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)),
1837
        ]
1838

1839
        for keepdim in [True, False]:
1840
            for input_size, error_ords, dim in test_cases:
1841
                input = torch.randn(*input_size, dtype=dtype, device=device)
1842
                for ord in ord_matrix:
1843
                    run_test_case(input, ord, dim, keepdim, ord in error_ords)
1844

1845
    def test_norm_fastpaths(self, device):
1846
        x = torch.randn(3, 5, device=device)
1847

1848
        # slow path
1849
        result = torch.linalg.norm(x, 4.5, 1)
1850
        expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
1851
        self.assertEqual(result, expected)
1852

1853
        # fast 0-norm
1854
        result = torch.linalg.norm(x, 0, 1)
1855
        expected = (x != 0).type_as(x).sum(1)
1856
        self.assertEqual(result, expected)
1857

1858
        # fast 1-norm
1859
        result = torch.linalg.norm(x, 1, 1)
1860
        expected = x.abs().sum(1)
1861
        self.assertEqual(result, expected)
1862

1863
        # fast 2-norm
1864
        result = torch.linalg.norm(x, 2, 1)
1865
        expected = torch.sqrt(x.pow(2).sum(1))
1866
        self.assertEqual(result, expected)
1867

1868
        # fast 3-norm
1869
        result = torch.linalg.norm(x, 3, 1)
1870
        expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
1871
        self.assertEqual(result, expected)
1872

1873
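    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # general p-norm formula that the fast paths above specialize:
    # ||x||_p = (sum_i |x_i| ** p) ** (1 / p).
    def _pnorm_reference_sketch(self):
        x = torch.randn(3, 5)
        p = 3.0
        reference = x.abs().pow(p).sum(dim=1).pow(1.0 / p)
        return torch.allclose(torch.linalg.norm(x, p, dim=1), reference)
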
    @skipCPUIfNoLapack
1874
    @skipCUDAIfNoMagma
1875
    # NumPy computes only in float64 and complex128 precision;
    # for float32 or complex64 inputs the results might differ significantly from float64 or complex128.
1877
    @dtypes(torch.float64, torch.complex128)
1878
    def test_eig_numpy(self, device, dtype):
1879
        def run_test(shape, *, symmetric=False):
1880
            from torch.testing._internal.common_utils import random_symmetric_matrix
1881

1882
            if not dtype.is_complex and symmetric:
1883
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
1884
                # unlike NumPy the result is not cast to float32 or float64 dtype in this case
1885
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1886
            else:
1887
                a = make_tensor(shape, dtype=dtype, device=device)
1888

1889
            actual = torch.linalg.eig(a)
1890

1891
            # compare with NumPy
1892
            # the eigenvalues are not necessarily ordered
1893
            # so order of NumPy and PyTorch can be different
1894
            expected = np.linalg.eig(a.cpu().numpy())
1895

1896
            # sort NumPy output
1897
            ind = np.argsort(expected[0], axis=-1)[::-1]
1898
            expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1))
1899

1900
            # sort PyTorch output
1901
            # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead
1902
            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
1903
            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
1904
            ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1]
1905
            actual_np = [x.cpu().numpy() for x in actual]
1906
            sorted_actual = (
1907
                np.take_along_axis(actual_np[0], ind, axis=-1),
1908
                np.take_along_axis(actual_np[1], ind[:, None], axis=-1))
1909

1910
            self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False)
1911
            self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False)
1912

1913
        shapes = [(0, 0),  # Empty matrix
1914
                  (5, 5),  # Single matrix
1915
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1916
                  (2, 5, 5),  # 3-dim tensors
1917
                  (2, 1, 5, 5)]  # 4-dim tensors
1918
        for shape in shapes:
1919
            run_test(shape)
1920
            run_test(shape, symmetric=True)
1921

1922
    @onlyCUDA
1923
    @skipCUDAIfNoMagma
1924
    @dtypes(*floating_and_complex_types())
1925
    def test_eig_compare_backends(self, device, dtype):
1926
        def run_test(shape, *, symmetric=False):
1927
            from torch.testing._internal.common_utils import random_symmetric_matrix
1928

1929
            if not dtype.is_complex and symmetric:
1930
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
1931
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1932
            else:
1933
                a = make_tensor(shape, dtype=dtype, device=device)
1934

1935
            actual = torch.linalg.eig(a)
1936

1937
            complementary_device = 'cpu'
1938

1939
            # compare with CPU
1940
            expected = torch.linalg.eig(a.to(complementary_device))
1941
            self.assertEqual(expected[0], actual[0])
1942
            self.assertEqual(expected[1], actual[1])
1943

1944
        shapes = [(0, 0),  # Empty matrix
1945
                  (5, 5),  # Single matrix
1946
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
1947
                  (2, 5, 5),  # 3-dim tensors
1948
                  (2, 1, 5, 5)]  # 4-dim tensors
1949
        for shape in shapes:
1950
            run_test(shape)
1951
            run_test(shape, symmetric=True)
1952

1953
    @slowTest
1954
    @onlyCUDA
1955
    @skipCUDAIfNoMagma
1956
    @dtypes(torch.float32)
1957
    def test_eig_check_magma(self, device, dtype):
1958
        # For CUDA inputs only matrices of size larger than 2048x2048 actually call MAGMA library
1959
        shape = (2049, 2049)
1960
        a = make_tensor(shape, dtype=dtype, device=device)
1961
        w, v = torch.linalg.eig(a)
1962
        # check correctness using eigendecomposition identity
1963
        self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)
1964

1965
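    # A minimal sketch (not collected as a test; the helper name is arbitrary) of the
    # eigendecomposition identity used in test_eig_check_magma above:
    # A @ V == V * w, i.e. each eigenvector column is scaled by its eigenvalue.
    def _eig_identity_sketch(self):
        a = torch.randn(5, 5, dtype=torch.float64)
        w, v = torch.linalg.eig(a)
        return torch.allclose(a.to(v.dtype) @ v, v * w)
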
    @skipCUDAIfNoMagma
1966
    @skipCPUIfNoLapack
1967
    @dtypes(*floating_and_complex_types())
1968
    def test_eig_errors_and_warnings(self, device, dtype):
1969
        # eig requires the input to be at least 2 dimensional tensor
1970
        a = make_tensor(2, dtype=dtype, device=device)
1971
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
1972
            torch.linalg.eig(a)
1973

1974
        # eig requires a square matrix
1975
        a = make_tensor((2, 3), dtype=dtype, device=device)
1976
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1977
            torch.linalg.eig(a)
1978

1979
        # if out tensor with floating dtype is passed for complex output an error is thrown
1980
        if not dtype.is_complex:
1981
            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
1982
            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
1983
            out0 = torch.empty(0, device=device, dtype=dtype)
1984
            out1 = torch.empty(0, device=device, dtype=dtype)
1985
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
1986
                torch.linalg.eig(a, out=(out0, out1))
1987

1988
            out0 = torch.empty(0, device=device, dtype=torch.complex128)
1989
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"):
1990
                torch.linalg.eig(a, out=(out0, out1))
1991

1992
        # dtypes should be safely castable
1993
        a = make_tensor((3, 3), dtype=dtype, device=device)
1994
        out0 = torch.empty(0, dtype=torch.int, device=device)
1995
        out1 = torch.empty(0, dtype=torch.int, device=device)
1996
        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
1997
            torch.linalg.eig(a, out=(out0, out1))
1998

1999
        out0 = torch.empty(0, dtype=torch.complex128, device=device)
2000
        with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
2001
            torch.linalg.eig(a, out=(out0, out1))
2002

2003
        # if non-empty out tensor with wrong shape is passed a warning is given
2004
        a = make_tensor((3, 3), dtype=dtype, device=device)
2005
        out0 = torch.empty(1, device=device, dtype=torch.complex128)
2006
        out1 = torch.empty(1, device=device, dtype=torch.complex128)
2007
        with warnings.catch_warnings(record=True) as w:
2008
            # Trigger warning
2009
            torch.linalg.eig(a, out=(out0, out1))
2010
            # Check warning occurs
2011
            self.assertEqual(len(w), 2)
2012
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2013
            self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
2014

2015
        # device should match
2016
        if torch.cuda.is_available():
2017
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2018
            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2019
            out_v = torch.empty(0, device=device, dtype=torch.complex128)
2020
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2021
                torch.linalg.eig(a, out=(out_w, out_v))
2022
            out_w = torch.empty(0, device=device, dtype=torch.complex128)
2023
            out_v = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2024
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2025
                torch.linalg.eig(a, out=(out_w, out_v))
2026

2027
    @skipCPUIfNoLapack
2028
    @skipCUDAIfNoMagma
2029
    @dtypes(*floating_and_complex_types())
2030
    def test_eig_with_nan(self, device, dtype):
2031
        for val in [np.inf, np.nan]:
2032
            for batch_dim in [(), (10,)]:
2033
                a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype)
2034
                a[..., -1, -1] = val
2035

2036
                with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"):
2037
                    torch.linalg.eig(a)
2038

2039
    @skipCPUIfNoLapack
2040
    @skipCUDAIfNoMagma
2041
    # NumPy computes only in float64 and complex128 precision;
    # for float32 or complex64 inputs the results might differ significantly from float64 or complex128.
2043
    @dtypes(torch.float64, torch.complex128)
2044
    def test_eigvals_numpy(self, device, dtype):
2045
        def run_test(shape, *, symmetric=False):
2046
            from torch.testing._internal.common_utils import random_symmetric_matrix
2047

2048
            if not dtype.is_complex and symmetric:
2049
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
2050
                # unlike NumPy the result is not cast to float32 or float64 dtype in this case
2051
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2052
            else:
2053
                a = make_tensor(shape, dtype=dtype, device=device)
2054

2055
            actual = torch.linalg.eigvals(a)
2056

2057
            # compare with NumPy
2058
            # the eigenvalues are not necessarily ordered
2059
            # so order of NumPy and PyTorch can be different
2060
            expected = np.linalg.eigvals(a.cpu().numpy())
2061

2062
            # sort NumPy output
2063
            ind = np.argsort(expected, axis=-1)[::-1]
2064
            expected = np.take_along_axis(expected, ind, axis=-1)
2065

2066
            # sort PyTorch output
2067
            # torch.argsort doesn't work with complex inputs, NumPy sorting on CPU is used instead
2068
            # RuntimeError: _th_sort not supported on CUDAType for ComplexDouble
2069
            # RuntimeError: "sorting_kernel_method_name" not implemented for 'ComplexDouble'
2070
            ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1]
2071
            actual_np = actual.cpu().numpy()
2072
            sorted_actual = np.take_along_axis(actual_np, ind, axis=-1)
2073

2074
            self.assertEqual(expected, sorted_actual, exact_dtype=False)
2075

2076
        shapes = [(0, 0),  # Empty matrix
2077
                  (5, 5),  # Single matrix
2078
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2079
                  (2, 5, 5),  # 3-dim tensors
2080
                  (2, 1, 5, 5)]  # 4-dim tensors
2081
        for shape in shapes:
2082
            run_test(shape)
2083
            run_test(shape, symmetric=True)
2084

2085
    @onlyCUDA
2086
    @skipCUDAIfNoMagma
2087
    @dtypes(*floating_and_complex_types())
2088
    def test_eigvals_compare_backends(self, device, dtype):
2089
        def run_test(shape, *, symmetric=False):
2090
            from torch.testing._internal.common_utils import random_symmetric_matrix
2091

2092
            if not dtype.is_complex and symmetric:
2093
                # for symmetric real-valued inputs eigenvalues and eigenvectors have imaginary part equal to zero
2094
                a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2095
            else:
2096
                a = make_tensor(shape, dtype=dtype, device=device)
2097

2098
            actual = torch.linalg.eigvals(a)
2099

2100
            complementary_device = 'cpu'
2101

2102
            # compare with CPU
2103
            expected = torch.linalg.eigvals(a.to(complementary_device))
2104
            self.assertEqual(expected, actual)
2105

2106
            # check out= variant
2107
            complex_dtype = dtype
2108
            if not dtype.is_complex:
2109
                complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64
2110
            out = torch.empty(0, dtype=complex_dtype, device=device)
2111
            ans = torch.linalg.eigvals(a, out=out)
2112
            self.assertEqual(ans, out)
2113
            self.assertEqual(expected.to(complex_dtype), out)
2114

2115
            # check non-contiguous out
2116
            if a.numel() > 0:
2117
                out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2]
2118
                self.assertFalse(out.is_contiguous())
2119
                ans = torch.linalg.eigvals(a, out=out)
2120
                self.assertEqual(ans, out)
2121
                self.assertEqual(expected.to(complex_dtype), out)
2122

2123
        shapes = [(0, 0),  # Empty matrix
2124
                  (5, 5),  # Single matrix
2125
                  (0, 0, 0), (0, 5, 5),  # Zero batch dimension tensors
2126
                  (2, 5, 5),  # 3-dim tensors
2127
                  (2, 1, 5, 5)]  # 4-dim tensors
2128
        for shape in shapes:
2129
            run_test(shape)
2130
            run_test(shape, symmetric=True)
2131

2132
    @skipCUDAIfNoMagma
2133
    @skipCPUIfNoLapack
2134
    @dtypes(*floating_and_complex_types())
2135
    def test_eigvals_errors_and_warnings(self, device, dtype):
2136
        # eig requires the input to be at least 2 dimensional tensor
2137
        a = make_tensor(2, dtype=dtype, device=device)
2138
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
2139
            torch.linalg.eigvals(a)
2140

2141
        # eig requires a square matrix
2142
        a = make_tensor((2, 3), dtype=dtype, device=device)
2143
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2144
            torch.linalg.eigvals(a)
2145

2146
        # if out tensor with floating dtype is passed for complex output an error is thrown
2147
        if not dtype.is_complex:
2148
            # The characteristic equation is p(lambda) = lambda^2 - 2lambda + 5 = 0, with roots lambda = 1[+-]2i
2149
            a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
2150
            out = torch.empty(0, device=device, dtype=dtype)
2151
            with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
2152
                torch.linalg.eigvals(a, out=out)
2153

2154
        # dtypes should be safely castable
2155
        a = make_tensor((3, 3), dtype=dtype, device=device)
2156
        out = torch.empty(0, dtype=torch.int, device=device)
2157
        with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
2158
            torch.linalg.eigvals(a, out=out)
2159

2160
        # if non-empty out tensor with wrong shape is passed a warning is given
2161
        out = torch.empty(1, device=device, dtype=torch.complex128)
2162
        with warnings.catch_warnings(record=True) as w:
2163
            # Trigger warning
2164
            torch.linalg.eigvals(a, out=out)
2165
            # Check warning occurs
2166
            self.assertEqual(len(w), 1)
2167
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2168

2169
        # device should match
2170
        if torch.cuda.is_available():
2171
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2172
            out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2173
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2174
                torch.linalg.eigvals(a, out=out_w)
2175

2176
    @skipCUDAIfNoMagma
2177
    @skipCPUIfNoLapack
2178
    def test_norm_old(self, device):
2179
        def gen_error_message(input_size, p, keepdim, dim=None):
2180
            return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2181

2182
        # The 'nuc' norm uses SVD, and thus its precision is much lower than that of other norms.
        # test_svd uses @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4}),
        # and here we do the same thing for the nuc norm.
2185
        class PrecisionContext:
2186
            def __init__(self, test, norm):
2187
                self.norm = norm
2188
                self.saved_overrides = getattr(test, 'precision_overrides', None)
2189
                self.target_test = test
2190

2191
            def __enter__(self):
2192
                if 'nuc' != self.norm:
2193
                    return None
2194
                self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4}
2195
                return self.target_test.precision_overrides
2196

2197
            def __exit__(self, type, value, tb) -> bool:
2198
                if 'nuc' != self.norm:
2199
                    return False  # do not suppress exceptions raised inside the `with` block
2200
                if self.saved_overrides is None:
2201
                    delattr(self.target_test, 'precision_overrides')
2202
                else:
2203
                    self.target_test.precision_overrides = self.saved_overrides
2204
                return False  # do not suppress exceptions raised inside the `with` block
2205

2206
        for keepdim in [False, True]:
2207
            # full reduction
2208
            x = torch.randn(25, device=device)
2209
            xn = x.cpu().numpy()
2210
            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]:
2211
                res = x.norm(p, keepdim=keepdim).cpu()
2212
                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2213
                self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim))
2214

2215
            # one dimension
2216
            x = torch.randn(25, 25, device=device)
2217
            xn = x.cpu().numpy()
2218
            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]:
2219
                dim = 1
2220
                res = x.norm(p, dim, keepdim=keepdim).cpu()
2221
                expected = np.linalg.norm(xn, p, dim, keepdims=keepdim)
2222
                msg = gen_error_message(x.size(), p, keepdim, dim)
2223
                self.assertEqual(res.shape, expected.shape, msg=msg)
2224
                self.assertEqual(res, expected, msg=msg)
2225

2226
            # matrix norm
2227
            for p in ['fro', 'nuc']:
2228
                res = x.norm(p, keepdim=keepdim).cpu()
2229
                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2230
                msg = gen_error_message(x.size(), p, keepdim)
2231
                with PrecisionContext(self, p):
2232
                    self.assertEqual(res.shape, expected.shape, msg=msg)
2233
                    self.assertEqual(res, expected, msg=msg)
2234

2235
            # zero dimensions
2236
            x = torch.randn((), device=device)
2237
            xn = x.cpu().numpy()
2238
            res = x.norm(keepdim=keepdim).cpu()
2239
            expected = np.linalg.norm(xn, keepdims=keepdim)
2240
            msg = gen_error_message(x.size(), None, keepdim)
2241
            self.assertEqual(res.shape, expected.shape, msg=msg)
2242
            self.assertEqual(res, expected, msg=msg)
2243

2244
            # larger tensor sanity check
2245
            self.assertEqual(
2246
                2 * torch.norm(torch.ones(10000), keepdim=keepdim),
2247
                torch.norm(torch.ones(40000), keepdim=keepdim))
2248

2249
            # matrix norm with non-square >2-D tensors, all combinations of reduction dims
2250
            x = torch.randn(5, 6, 7, 8, device=device)
2251
            xn = x.cpu().numpy()
2252
            for p in ['fro', 'nuc']:
2253
                for dim in itertools.product(*[list(range(4))] * 2):
2254
                    if dim[0] == dim[1]:
2255
                        continue
2256
                    res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu()
2257
                    expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim)
2258
                    msg = gen_error_message(x.size(), p, keepdim, dim)
2259
                    with PrecisionContext(self, p):
2260
                        self.assertEqual(res.shape, expected.shape, msg=msg)
2261
                        self.assertEqual(res, expected, msg=msg)
2262

2263
    # Test that torch.norm with p=+/-inf propagates NaN
2264
    def test_norm_old_nan_propagation(self, device):
2265
        ords = [inf, -inf]
2266
        for pair in itertools.product([0.0, nan, 1.0], repeat=2):
2267
            x = torch.tensor(list(pair), device=device)
2268
            for ord in ords:
2269
                result = torch.norm(x, p=ord)
2270
                result_check = torch.linalg.norm(x, ord=ord)
2271
                self.assertEqual(result, result_check)
2272

2273
    @skipCUDAIfNoMagma
2274
    @skipCPUIfNoLapack
2275
    def test_norm_complex_old(self, device):
2276
        def gen_error_message(input_size, p, keepdim, dim=None):
2277
            return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2278

2279
        for keepdim in [False, True]:
2280
            # vector norm
2281
            x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device)
2282
            xn = x.cpu().numpy()
2283
            for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]:
2284
                res = x.norm(p, keepdim=keepdim).cpu()
2285
                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2286
                msg = gen_error_message(x.size(), p, keepdim)
2287
                self.assertEqual(res.shape, expected.shape, msg=msg)
2288
                self.assertEqual(res, expected, msg=msg)
2289

2290
            # matrix norm
2291
            x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device)
2292
            xn = x.cpu().numpy()
2293
            for p in ['nuc', 'fro']:
2294
                res = x.norm(p, keepdim=keepdim).cpu()
2295
                expected = np.linalg.norm(xn, p, keepdims=keepdim)
2296
                msg = gen_error_message(x.size(), p, keepdim)
2297
                self.assertEqual(res.shape, expected.shape, msg=msg)
2298
                self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4)
2299

2300
    # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations
2301
    @dtypes(torch.float)
2302
    def test_norm_fro_2_equivalence_old(self, device, dtype):
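        # The Frobenius norm reduces via sqrt(sum(|x|**2)) over the reduced dimensions, which is
        # exactly the element-wise 2-norm, so the two settings are expected to agree for every
        # dim/keepdim combination exercised below.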
2303
        input_sizes = [
2304
            (0,),
2305
            (10,),
2306
            (0, 0),
2307
            (4, 30),
2308
            (0, 45),
2309
            (100, 0),
2310
            (45, 10, 23),
2311
            (0, 23, 59),
2312
            (23, 0, 37),
2313
            (34, 58, 0),
2314
            (0, 0, 348),
2315
            (0, 3434, 0),
2316
            (0, 0, 0),
2317
            (5, 3, 8, 1, 3, 5)]
2318

2319
        for input_size in input_sizes:
2320
            a = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
2321

2322
            # Try full reduction
2323
            dim_settings = [None]
2324

2325
            # Try all possible 1-D reductions
2326
            dim_settings += list(range(-a.dim(), a.dim()))
2327

2328
            def wrap_dim(dim, ndims):
2329
                assert (dim < ndims) and (dim >= -ndims)
2330
                if dim >= 0:
2331
                    return dim
2332
                else:
2333
                    return dim + ndims
2334

2335
            # Try all possible 2-D reductions
2336
            dim_settings += [
2337
                (d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2)
2338
                if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())]
2339

2340
            for dim in dim_settings:
2341
                for keepdim in [True, False]:
2342
                    a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim)
2343
                    a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim)
2344
                    self.assertEqual(a_norm_fro, a_norm_2)
2345

2346
    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2347
    @skipCUDAIfNoMagma
2348
    @skipCPUIfNoLapack
2349
    def test_nuclear_norm_axes_small_brute_force_old(self, device):
2350
        def check_single_nuclear_norm(x, axes):
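            # The nuclear norm is the sum of the singular values of the matrix spanned by the two
            # reduced dimensions; np.linalg.norm(..., "nuc") serves as the reference implementation.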
2351
            if self.device_type != 'cpu' and randrange(100) < 95:
2352
                return  # too many cpu <==> device copies
2353

2354
            a = np.array(x.cpu(), copy=False)
2355
            expected = np.linalg.norm(a, "nuc", axis=axes)
2356

2357
            ans = torch.norm(x, "nuc", dim=axes)
2358
            self.assertTrue(ans.is_contiguous())
2359
            self.assertEqual(ans.shape, expected.shape)
2360
            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2361

2362
            out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
2363
            ans = torch.norm(x, "nuc", dim=axes, out=out)
2364
            self.assertIs(ans, out)
2365
            self.assertTrue(ans.is_contiguous())
2366
            self.assertEqual(ans.shape, expected.shape)
2367
            self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2368

2369
        for n in range(1, 3):
2370
            for m in range(1, 3):
2371
                for axes in itertools.permutations([0, 1], 2):
2372
                    # 2d, inner dimensions C
2373
                    x = torch.randn(n, m, device=device)
2374
                    check_single_nuclear_norm(x, axes)
2375

2376
                    # 2d, inner dimensions Fortran
2377
                    x = torch.randn(m, n, device=device).mT
2378
                    check_single_nuclear_norm(x, axes)
2379

2380
                    # 2d, inner dimensions non-contiguous
2381
                    x = torch.randn(n, 2 * m, device=device)[:, ::2]
2382
                    check_single_nuclear_norm(x, axes)
2383

2384
                    # 2d, all dimensions non-contiguous
2385
                    x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
2386
                    check_single_nuclear_norm(x, axes)
2387

2388
                for o in range(1, 3):
2389
                    for axes in itertools.permutations([0, 1, 2], 2):
2390
                        # 3d, inner dimensions C
2391
                        x = torch.randn(o, n, m, device=device)
2392
                        check_single_nuclear_norm(x, axes)
2393

2394
                        # 3d, inner dimensions Fortran
2395
                        x = torch.randn(o, m, n, device=device).mT
2396
                        check_single_nuclear_norm(x, axes)
2397

2398
                        # 3d, inner dimensions non-contiguous
2399
                        x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
2400
                        check_single_nuclear_norm(x, axes)
2401

2402
                        # 3d, all dimensions non-contiguous
2403
                        x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
2404
                        check_single_nuclear_norm(x, axes)
2405

2406
                    for r in range(1, 3):
2407
                        for axes in itertools.permutations([0, 1, 2, 3], 2):
2408
                            # 4d, inner dimensions C
2409
                            x = torch.randn(r, o, n, m, device=device)
2410
                            check_single_nuclear_norm(x, axes)
2411

2412
                            # 4d, inner dimensions Fortran
2413
                            x = torch.randn(r, o, n, m, device=device).mT
2414
                            check_single_nuclear_norm(x, axes)
2415

2416
                            # 4d, inner dimensions non-contiguous
2417
                            x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
2418
                            check_single_nuclear_norm(x, axes)
2419

2420
                            # 4d, all dimensions non-contiguous
2421
                            x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
2422
                            check_single_nuclear_norm(x, axes)
2423

2424
    @skipCUDAIfNoMagma
2425
    def test_nuclear_norm_exceptions_old(self, device):
2426
        for lst in [], [1], [1, 2]:
2427
            x = torch.tensor(lst, dtype=torch.double, device=device)
2428
            for axes in (), (0,):
2429
                self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
2430
            self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1))
2431

2432
        x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
2433
        self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0))
2434
        self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))
2435

2436
    @skipCUDAIfNoCusolver
2437
    @skipCPUIfNoLapack
2438
    @dtypes(torch.double, torch.cdouble)
2439
    def test_svd_lowrank(self, device, dtype):
2440
        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
2441

2442
        def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options):
2443
            density = options.pop('density', 1)
2444
            if isinstance(matrix_size, int):
2445
                rows = columns = matrix_size
2446
            else:
2447
                rows, columns = matrix_size
2448
            if density == 1:
2449
                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
2450
                a = a_input
2451
            else:
2452
                assert batches == ()
2453
                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
2454
                a = a_input.to_dense()
2455

2456
            q = min(rows, columns)  # equivalent to min(*size), but avoids relying on the enclosing loop variable
2457
            u, s, v = svd_lowrank(a_input, q=q, **options)
2458

2459
            # check that (u, s, v) is an SVD of a
2460
            u, s, v = u[..., :q], s[..., :q], v[..., :q]
2461
            A = (u * s.unsqueeze(-2)).matmul(v.mH)
2462
            self.assertEqual(A, a, rtol=1e-7, atol=2e-7)
2463

2464
            # check that svd_lowrank produces the same singular values as the full torch.linalg.svd
2465
            U, S, Vh = torch.linalg.svd(a, full_matrices=False)
2466
            V = Vh.mH
2467
            self.assertEqual(s, S)
2468

2469
            if density == 1:
2470
                # actual_rank is known only for dense inputs
2471
                #
2472
                # check if pairs (u, U) and (v, V) span the same
2473
                # subspaces, respectively
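                # (u and U have orthonormal columns, so the singular values of u^H @ U are the
                #  cosines of the principal angles between the two column spaces; their product,
                #  |det(u^H @ U)|, equals 1 exactly when the two subspaces coincide)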
2474
                u, v = u[..., :actual_rank], v[..., :actual_rank]
2475
                U, V = U[..., :actual_rank], V[..., :actual_rank]
2476
                expected_ones = u.mH.matmul(U).det().abs()
2477
                self.assertEqual(expected_ones, torch.ones_like(expected_ones))
2478
                self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones))
2479

2480
        all_batches = [(), (1,), (3,), (2, 3)]
2481
        for actual_rank, size, all_batches in [  # noqa: B020
2482
                (2, (17, 4), all_batches),
2483
                (4, (17, 4), all_batches),
2484
                (4, (17, 17), all_batches),
2485
                (10, (100, 40), all_batches),
2486
                (7, (1000, 1000), [()]),
2487
        ]:
2488
            # dense input
2489
            for batches in all_batches:
2490
                run_subtest(actual_rank, size, batches, device, torch.svd_lowrank)
2491
                if size != size[::-1]:
2492
                    run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank)
2493

2494
        # sparse input
2495
        for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]:
2496
            for density in [0.005, 0.1]:
2497
                run_subtest(None, size, (), device, torch.svd_lowrank, density=density)
2498

2499
        # jitting support
2500
        jitted = torch.jit.script(torch.svd_lowrank)
2501
        actual_rank, size, batches = 2, (17, 4), ()
2502
        run_subtest(actual_rank, size, batches, device, jitted)
2503

2504
    @skipCUDAIfNoMagmaAndNoCusolver
2505
    @skipCPUIfNoLapack
2506
    @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4})
2507
    @setLinalgBackendsToDefaultFinally
2508
    @dtypes(*floating_and_complex_types())
2509
    @serialTest()
2510
    def test_svd(self, device, dtype):
2511
        # tests linalg.svd, svd, linalg.svdvals
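        # Note: torch.linalg.svd returns Vh (= V^H) and takes full_matrices, while the legacy
        # torch.svd returns V and takes some (some=True corresponds to full_matrices=False);
        # either way A should be reconstructed as U @ diag(S) @ V^H, which is what is checked below.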
2512
        make_arg = partial(make_tensor, dtype=dtype, device=device)
2513

2514
        backends = ["default"]
2515

2516
        if torch.device(device).type == 'cuda':
2517
            if torch.cuda.has_magma:
2518
                backends.append("magma")
2519
            if has_cusolver() or has_hipsolver():
2520
                backends.append("cusolver")
2521

2522
        ns = (12, 4, 2, 0)
2523
        batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
2524
        drivers = (None, 'gesvd', 'gesvdj', 'gesvda')
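        # The driver argument is only honored by the cuSOLVER backend: 'gesvd' (QR-based),
        # 'gesvdj' (Jacobi-based) and 'gesvda' (approximate, aimed at tall-and-thin matrices);
        # that is why non-None drivers are only exercised when backend == 'cusolver' below.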
2525

2526
        for backend in backends:
2527
            torch.backends.cuda.preferred_linalg_library(backend)
2528

2529
            for batch, m, n, driver in product(batches, ns, ns, drivers):
2530
                if not (backend == 'cusolver' or driver is None):
2531
                    # only test cases below and skip otherwise:
2532
                    # - backend == 'cusolver' (driver can be anything)
2533
                    # - backend != 'cusolver' (driver should only be None)
2534
                    continue
2535

2536
                shape = batch + (m, n)
2537
                k = min(m, n)
2538
                A = make_arg(shape)
2539
                U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver)
2540
                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A)
2541

2542
                U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver)
2543
                self.assertEqual(S_f, S)
2544
                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A)
2545

2546
                S_s = torch.linalg.svdvals(A, driver=driver)
2547
                self.assertEqual(S_s, S)
2548

2549
                U, S, V = torch.svd(A, some=True)
2550
                self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A)
2551

2552
                U_f, S_f, V_f = torch.svd(A, some=False)
2553
                self.assertEqual(S_f, S)
2554
                self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A)
2555

2556
                S_s = torch.svd(A, compute_uv=False).S
2557
                self.assertEqual(S_s, S)
2558

2559
    @skipCUDAIfNoMagmaAndNoCusolver
2560
    @skipCPUIfNoLapack
2561
    @dtypes(torch.complex128)
2562
    def test_invariance_error_spectral_decompositions(self, device, dtype):
2563
        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
2564
        A = make_arg((3, 3))
2565
        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2566
            U, _, Vh = torch.linalg.svd(A, full_matrices=False)
2567
            (U + Vh).sum().abs().backward()
2568

2569
        A = make_arg((3, 3))
2570
        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2571
            V = torch.linalg.eig(A).eigenvectors
2572
            V.sum().abs().backward()
2573

2574
        A = make_arg((3, 3))
2575
        A = A + A.mH
2576
        with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2577
            Q = torch.linalg.eigh(A).eigenvectors
2578
            Q.sum().abs().backward()
2579

2580
    @skipCUDAIfNoCusolver  # MAGMA backend doesn't work in this case
2581
    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
2582
    @skipCPUIfNoLapack
2583
    @dtypes(*floating_and_complex_types())
2584
    def test_svd_memory_allocation(self, device, dtype):
2585
        # test for https://github.com/pytorch/pytorch/issues/61949
2586
        # the problem was that tensors of incorrect size were allocated and then narrowed
2587
        m = 3
2588
        n = 2**20
2589
        a = make_tensor((m, n), dtype=dtype, device=device)
2590
        # the following should run without errors
2591
        S = torch.linalg.svdvals(a)
2592
        result = torch.linalg.svd(a, full_matrices=False)
2593
        self.assertEqual(result.S, S)
2594

2595
    def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
2596
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2597

2598
        b = torch.randn(*b_dims, dtype=dtype, device=device)
2599
        A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device)
2600
        L = torch.cholesky(A, upper=upper)
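        # A is Hermitian positive-definite, so A = L @ L.mH (or A = U.mH @ U when upper=True);
        # the tests below feed this factor to cholesky_solve to solve A x = b.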
2601
        return b, A, L
2602

2603
    @skipCUDAIfNoMagma
2604
    @skipCPUIfNoLapack
2605
    @dtypes(*floating_and_complex_types())
2606
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2607
                        torch.float64: 1e-8, torch.complex128: 1e-8})
2608
    def test_cholesky_solve(self, device, dtype):
2609
        for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]):
2610
            b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype)
2611
            x = torch.cholesky_solve(b, L, upper=upper)
2612
            self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
2613

2614
    @skipCUDAIfNoMagma
2615
    @skipCPUIfNoLapack
2616
    @dtypes(*floating_and_complex_types())
2617
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2618
                        torch.float64: 1e-8, torch.complex128: 1e-8})
2619
    def test_cholesky_solve_batched(self, device, dtype):
2620
        def cholesky_solve_batch_helper(A_dims, b_dims, upper):
2621
            b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2622
            x_exp_list = []
2623
            for i in range(b_dims[0]):
2624
                x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper))
2625
            x_exp = torch.stack(x_exp_list)  # Stacked output
2626
            x_act = torch.cholesky_solve(b, L, upper=upper)  # Actual output
2627
            self.assertEqual(x_act, x_exp)  # Equality check
2628
            Ax = np.matmul(A.cpu(), x_act.cpu())
2629
            self.assertEqual(b, Ax)  # Correctness check
2630

2631
        for upper, batchsize in itertools.product([True, False], [1, 3, 4]):
2632
            cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper)
2633

2634
    @slowTest
2635
    @skipCUDAIfNoMagma
2636
    @skipCPUIfNoLapack
2637
    @dtypes(*floating_and_complex_types())
2638
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2639
                        torch.float64: 1e-8, torch.complex128: 1e-8})
2640
    def test_cholesky_solve_batched_many_batches(self, device, dtype):
2641
        for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]):
2642
            for upper in [True, False]:
2643
                b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2644
                x = torch.cholesky_solve(b, L, upper)
2645
                Ax = torch.matmul(A, x)
2646
                self.assertEqual(Ax, b.expand_as(Ax))
2647

2648
    @skipCUDAIfNoMagma
2649
    @skipCPUIfNoLapack
2650
    @dtypes(*floating_and_complex_types())
2651
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2652
                        torch.float64: 1e-8, torch.complex128: 1e-8})
2653
    def test_cholesky_solve_batched_broadcasting(self, device, dtype):
2654
        from numpy.linalg import solve
2655
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2656

2657
        def run_test(A_dims, b_dims, upper):
2658
            A_matrix_size = A_dims[-1]
2659
            A_batch_dims = A_dims[:-2]
2660
            A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims,
2661
                                           dtype=dtype, device='cpu')
2662
            b = torch.randn(*b_dims, dtype=dtype, device='cpu')
2663
            x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device)
2664
            A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device)
2665
            L = torch.linalg.cholesky(A, upper=upper)
2666
            x = torch.cholesky_solve(b, L, upper=upper)
2667
            self.assertEqual(x, x_exp)
2668
            # https://github.com/pytorch/pytorch/issues/42695
2669
            x = torch.cholesky_solve(b, L, upper=upper, out=x)
2670
            self.assertEqual(x, x_exp)
2671

2672
        # test against numpy.linalg.solve
2673
        for upper in [True, False]:
2674
            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper)  # no broadcasting
2675
            run_test((2, 1, 3, 4, 4), (4, 6), upper)  # broadcasting b
2676
            run_test((4, 4), (2, 1, 3, 4, 2), upper)  # broadcasting A
2677
            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper)  # broadcasting A & b
2678

2679
    @skipCUDAIfNoMagma
2680
    @skipCPUIfNoLapack
2681
    @dtypes(*floating_and_complex_types())
2682
    def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
2683
        # dtypes should be safely castable
2684
        a = torch.eye(2, dtype=dtype, device=device)
2685
        b = torch.randn(2, 1, dtype=dtype, device=device)
2686
        out = torch.empty(0, dtype=torch.int, device=device)
2687
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2688
            torch.cholesky_solve(b, a, out=out)
2689

2690
        # device should match
2691
        if torch.cuda.is_available():
2692
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2693
            out = torch.empty(0, dtype=dtype, device=wrong_device)
2694
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2695
                torch.cholesky_solve(b, a, out=out)
2696

2697
        # if out tensor with wrong shape is passed a warning is given
2698
        with warnings.catch_warnings(record=True) as w:
2699
            out = torch.empty(1, dtype=dtype, device=device)
2700
            # Trigger warning
2701
            torch.cholesky_solve(b, a, out=out)
2702
            # Check warning occurs
2703
            self.assertEqual(len(w), 1)
2704
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2705

2706
    @skipCUDAIfNoMagma
2707
    @skipCPUIfNoLapack
2708
    @dtypes(torch.double)
2709
    def test_cholesky_solve_backward(self, device, dtype):
2710
        b_dims = (5, 2)
2711
        L_dims = (5, 5)
2712

2713
        for test_L_grad in (False, True):
2714
            b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True)
2715
            L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad)
2716
            if test_L_grad:
2717
                torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L))
2718
            else:
2719
                torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,))
2720

2721
    @skipCUDAIfNoMagmaAndNoCusolver
2722
    @skipCPUIfNoLapack
2723
    @dtypes(*floating_and_complex_types())
2724
    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2725
                        torch.float64: 1e-8, torch.complex128: 1e-8})
2726
    def test_inverse(self, device, dtype):
2727
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2728
        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2729

2730
        def run_test(torch_inverse, matrix, batches, n):
2731
            matrix_inverse = torch_inverse(matrix)
2732

2733
            # Compare against NumPy output
2734
            # NumPy uses 'gesv' LAPACK routine solving the equation A A_inv = I
2735
            # But in PyTorch 'getrf' (LU factorization) + 'getrs' (triangular solves) are used. As such, there may be some element-wise differences
2736
            expected = np.linalg.inv(matrix.cpu().numpy())
2737
            self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision)
2738

2739
            # Additional correctness tests, check matrix*matrix_inverse == identity
2740
            identity = torch.eye(n, dtype=dtype, device=device)
2741
            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu()))
2742
            self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu()))
2743

2744
            # check the out= variant
2745
            # prepare the expected out tensor
2746
            matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device)
2747
            matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format)
2748
            matrix_inverse_out = matrix_inverse_out_t.mT
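            # (the mT/clone/mT round trip above yields a tensor whose per-matrix layout is
            #  column-major, i.e. batched Fortran order, matching LAPACK's native layout)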
2749
            ans = torch_inverse(matrix, out=matrix_inverse_out)
2750
            self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0)
2751
            self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)
2752

2753
            # batched matrices: 3+ dimensional tensors, check matrix_inverse same as single-inverse for each matrix
2754
            if matrix.ndim > 2 and batches[0] != 0:
2755
                expected_inv_list = []
2756
                p = int(np.prod(batches))  # use `p` instead of -1, so that the test works for empty input as well
2757
                for mat in matrix.contiguous().view(p, n, n):
2758
                    expected_inv_list.append(torch_inverse(mat))
2759
                expected_inv = torch.stack(expected_inv_list).view(*batches, n, n)
2760
                if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]:
2761
                    # single-inverse is done using cuSOLVER, while batched inverse is done using MAGMA
2762
                    # individual values can be significantly different for fp32, hence rather high rtol is used
2763
                    # the important thing is that torch_inverse passes above checks with identity
2764
                    self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2)
2765
                else:
2766
                    self.assertEqual(matrix_inverse, expected_inv)
2767

2768
        # helper function for testing torch.linalg.inv_ex
2769
        def test_inv_ex(input, out=None):
2770
            if out is not None:
2771
                info = torch.empty(0, dtype=torch.int32, device=device)
2772
                return torch.linalg.inv_ex(input, out=(out, info)).inverse
2773
            return torch.linalg.inv_ex(input).inverse
2774

2775
        for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]:
2776
            for batches, n in itertools.product(
2777
                [[], [0], [2], [2, 1]],
2778
                [0, 5]
2779
            ):
2780
                matrices = make_arg(*batches, n, n)
2781
                run_test(torch_inverse, matrices, batches, n)
2782

2783
                # test non-contiguous input
2784
                run_test(torch_inverse, matrices.mT, batches, n)
2785
                if n > 0:
2786
                    run_test(
2787
                        torch_inverse,
2788
                        make_arg(*batches, 2 * n, 2 * n)
2789
                        .view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
2790
                        batches, n
2791
                    )
2792

2793
    @skipCUDAIfNoMagmaAndNoCusolver
2794
    @skipCPUIfNoLapack
2795
    @dtypes(*floating_and_complex_types())
2796
    def test_inv_ex_info_device(self, device, dtype):
2797
        A = torch.eye(3, 3, dtype=dtype, device=device)
2798
        info = torch.linalg.inv_ex(A).info
2799
        self.assertTrue(info.device == A.device)
2800

2801
    @skipCUDAIfNoMagmaAndNoCusolver
2802
    @skipCPUIfNoLapack
2803
    @dtypes(*floating_and_complex_types())
2804
    def test_inv_ex_singular(self, device, dtype):
2805
        # if the input matrix is not invertible, info with positive integer is returned
2806
        A = torch.eye(3, 3, dtype=dtype, device=device)
2807
        A[-1, -1] = 0  # Now A is singular
2808
        info = torch.linalg.inv_ex(A).info
2809
        self.assertEqual(info, 3)
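        # (this mirrors LAPACK's getrf convention: info == i (1-based) means the i-th diagonal
        #  element of the U factor is exactly zero; zeroing A[-1, -1] of the 3x3 identity gives 3)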
2810
        with self.assertRaisesRegex(torch.linalg.LinAlgError,
2811
                                    r'diagonal element 3 is zero, the inversion could not be completed'):
2812
            torch.linalg.inv_ex(A, check_errors=True)
2813

2814
        # if at least one matrix in the batch is singular,
2815
        # batched info with positive integer for the corresponding matrix is returned
2816
        A = torch.eye(3, 3, dtype=dtype, device=device)
2817
        A = A.reshape((1, 3, 3))
2818
        A = A.repeat(5, 1, 1)
2819
        A[3, -2, -2] = 0  # Now A[3] is singular
2820
        info = torch.linalg.inv_ex(A).info
2821

2822
        expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
2823
        expected_info[3] = 2
2824
        self.assertEqual(info, expected_info)
2825
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
2826
            torch.linalg.inv_ex(A, check_errors=True)
2827

2828
    @slowTest
2829
    @skipCUDAIfNoMagmaAndNoCusolver
2830
    @skipCPUIfNoLapack
2831
    @dtypes(*floating_and_complex_types())
2832
    @precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2833
                        torch.float64: 1e-5, torch.complex128: 1e-5})
2834
    def test_inverse_many_batches(self, device, dtype):
2835
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2836
        make_arg = partial(make_fullrank, device=device, dtype=dtype)
2837

2838
        def test_inverse_many_batches_helper(torch_inverse, b, n):
2839
            matrices = make_arg(b, n, n)
2840
            matrices_inverse = torch_inverse(matrices)
2841

2842
            # Compare against NumPy output
2843
            expected = np.linalg.inv(matrices.cpu().numpy())
2844
            self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)
2845

2846
        for torch_inverse in [torch.inverse, torch.linalg.inv]:
2847
            test_inverse_many_batches_helper(torch_inverse, 5, 256)
2848
            test_inverse_many_batches_helper(torch_inverse, 3, 512)
2849

2850
    @skipCUDAIfNoMagmaAndNoCusolver
2851
    @skipCPUIfNoLapack
2852
    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2853
    @dtypes(*floating_and_complex_types())
2854
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
2855
    def test_inverse_errors(self, device, dtype):
2856
        # inverse expects batches of square matrices as input
2857
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2858
            torch.inverse(torch.randn(2, 3, 4, 3))
2859

2860
        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
2861
        def run_test_singular_input(batch_dim, n):
2862
            x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
2863
            x[n, -1, -1] = 0
2864
            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
2865
                torch.inverse(x)
2866

2867
        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
2868
            run_test_singular_input(*params)
2869

2870
    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
2871
    @skipCUDAIfNoMagmaAndNoCusolver
2872
    @skipCPUIfNoLapack
2873
    @onlyNativeDeviceTypes   # TODO: XLA doesn't raise exception
2874
    @dtypes(*floating_and_complex_types())
2875
    def test_inverse_errors_large(self, device, dtype):
2876
        # Test batched inverse of singular matrices reports errors without crashing (gh-51930)
2877
        x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
2878
        x[:] = torch.eye(616, dtype=dtype, device=device)
2879
        x[..., 10, 10] = 0
2880
        with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
2881
            torch.inverse(x)
2882

2883
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
2884
    @skipCUDAIfNoMagma
2885
    @skipCPUIfNoLapack
2886
    @dtypes(*floating_and_complex_types())
2887
    def test_pinv(self, device, dtype):
2888
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2889

2890
        def run_test_main(A, hermitian):
2891
            # Testing against definition for pseudo-inverses
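            # The four Moore-Penrose conditions checked below are:
            #   A @ A+ @ A == A,            A+ @ A @ A+ == A+,
            #   (A @ A+)^H == A @ A+,       (A+ @ A)^H == A+ @ A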
2892
            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
2893
            np_A = A.cpu().numpy()
2894
            np_A_pinv = A_pinv.cpu().numpy()
2895
            if A.numel() > 0:
2896
                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision)
2897
                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision)
2898
                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1))
2899
                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1))
2900
            else:
2901
                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
2902

2903
            # Check out= variant
2904
            out = torch.empty_like(A_pinv)
2905
            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
2906
            self.assertEqual(ans, out)
2907
            self.assertEqual(ans, A_pinv)
2908

2909
        def run_test_numpy(A, hermitian):
2910
            # Check against NumPy output
2911
            # Test float rcond, and specific value for each matrix
2912
            rconds = [float(torch.rand(1)), ]
2913
            # Test different types of rcond tensor
2914
            for rcond_type in all_types():
2915
                rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type))
2916
            # Test broadcasting of rcond
2917
            if A.ndim > 2:
2918
                rconds.append(torch.rand(A.shape[-3], device=device))
2919
            for rcond in rconds:
2920
                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
2921
                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
2922
                self.assertEqual(actual, torch_rtol)
2923
                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
2924
                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
2925
                self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5)
2926

2927
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2928
                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
2929
                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
2930
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
2931
            A = torch.randn(*sizes, dtype=dtype, device=device)
2932
            hermitian = False
2933
            run_test_main(A, hermitian)
2934
            run_test_numpy(A, hermitian)
2935

2936
        # Check hermitian = True
2937
        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
2938
                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
2939
            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
2940
            hermitian = True
2941
            run_test_main(A, hermitian)
2942
            run_test_numpy(A, hermitian)
2943

2944
    @skipCUDAIfNoMagma
2945
    @skipCPUIfNoLapack
2946
    @dtypes(*floating_and_complex_types())
2947
    def test_pinv_errors_and_warnings(self, device, dtype):
2948
        # pinv requires at least 2D tensor
2949
        a = torch.randn(1, device=device, dtype=dtype)
2950
        with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
2951
            torch.linalg.pinv(a)
2952

2953
        # if non-empty out tensor with wrong shape is passed a warning is given
2954
        a = torch.randn(3, 3, dtype=dtype, device=device)
2955
        out = torch.empty(7, 7, dtype=dtype, device=device)
2956
        with warnings.catch_warnings(record=True) as w:
2957
            # Trigger warning
2958
            torch.linalg.pinv(a, out=out)
2959
            # Check warning occurs
2960
            self.assertEqual(len(w), 1)
2961
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2962

2963
        # dtypes of out and input should be safely castable
2964
        out = torch.empty_like(a).to(torch.int)
2965
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2966
            torch.linalg.pinv(a, out=out)
2967

2968
        if torch.cuda.is_available():
2969
            # device of out and input should match
2970
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2971
            out = torch.empty_like(a).to(wrong_device)
2972
            with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"):
2973
                torch.linalg.pinv(a, out=out)
2974

2975
            # device of rcond and input should match
2976
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2977
            rcond = torch.full((), 1e-2, device=wrong_device)
2978
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
2979
                torch.linalg.pinv(a, rcond=rcond)
2980

2981
        # rcond can't be complex
2982
        rcond = torch.full((), 1j, device=device)
2983
        with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"):
2984
            torch.linalg.pinv(a, rcond=rcond)
2985

2986
        # atol can't be complex
2987
        atol = torch.full((), 1j, device=device)
2988
        with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"):
2989
            torch.linalg.pinv(a, atol=atol)
2990

2991
        # rtol can't be complex
2992
        rtol = torch.full((), 1j, device=device)
2993
        with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"):
2994
            torch.linalg.pinv(a, rtol=rtol)
2995

2996
    @skipCUDAIfNoMagmaAndNoCusolver
2997
    @skipCPUIfNoLapack
2998
    @dtypes(*floating_and_complex_types())
2999
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
3000
    def test_inv_errors_and_warnings(self, device, dtype):
3001
        # inv expects batches of square matrices as input
3002
        a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
3003
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
3004
            torch.linalg.inv(a)
3005

3006
        # inv requires the input to be at least 2 dimensional tensor
3007
        a = torch.randn(2, device=device, dtype=dtype)
3008
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
3009
            torch.linalg.inv(a)
3010

3011
        # if input is not invertible, RuntimeError is raised mentioning the first non-invertible batch
3012
        def run_test_singular_input(batch_dim, n):
3013
            a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
3014
            a[n, -1, -1] = 0
3015
            with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
3016
                torch.linalg.inv(a)
3017

3018
        for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
3019
            run_test_singular_input(*params)
3020

3021
        # dtypes should match
3022
        a = torch.eye(2, dtype=dtype, device=device)
3023
        out = torch.empty(0, dtype=torch.int, device=device)
3024
        with self.assertRaisesRegex(RuntimeError, "but got int instead"):
3025
            torch.linalg.inv(a, out=out)
3026

3027
        # device should match
3028
        if torch.cuda.is_available():
3029
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3030
            out = torch.empty(0, device=wrong_device, dtype=dtype)
3031
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3032
                torch.linalg.inv(a, out=out)
3033

3034
        # if out tensor with wrong shape is passed a warning is given
3035
        with warnings.catch_warnings(record=True) as w:
3036
            a = torch.eye(2, dtype=dtype, device=device)
3037
            out = torch.empty(1, dtype=dtype, device=device)
3038
            # Trigger warning
3039
            torch.linalg.inv(a, out=out)
3040
            # Check warning occurs
3041
            self.assertEqual(len(w), 1)
3042
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3043

3044
        # if an out tensor in batched column-major format but with the wrong shape is passed, a warning is given
3045
        with warnings.catch_warnings(record=True) as w:
3046
            a = torch.eye(2, dtype=dtype, device=device)
3047
            out = torch.empty(3, 3, dtype=dtype, device=device)
3048
            out = out.mT.clone(memory_format=torch.contiguous_format)
3049
            out = out.mT
3050
            self.assertTrue(out.mT.is_contiguous())
3051
            # Trigger warning
3052
            torch.linalg.inv(a, out=out)
3053
            # Check warning occurs
3054
            self.assertEqual(len(w), 1)
3055
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3056

3057
    def solve_test_helper(self, A_dims, b_dims, device, dtype):
3058
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3059
        make_A = partial(make_fullrank, device=device, dtype=dtype)
3060

3061
        b = torch.randn(*b_dims, dtype=dtype, device=device)
3062
        A = make_A(*A_dims)
3063
        return b, A
3064

3065
    @skipCUDAIfNoMagma
3066
    @skipCPUIfNoLapack
3067
    @dtypes(*floating_and_complex_types())
3068
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
3069
    def test_solve(self, device, dtype):
3070
        def run_test(n, batch, rhs):
3071
            A_dims = (*batch, n, n)
3072
            b_dims = (*batch, n, *rhs)
3073
            b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
3074

3075
            # Correctness test
3076
            x = torch.linalg.solve(A, b)
3077
            if rhs == ():
3078
                Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
3079
                Ax.squeeze_(-1)
3080
            else:
3081
                Ax = np.matmul(A.cpu(), x.cpu())
3082
            self.assertEqual(b.expand_as(Ax), Ax)
3083

3084
            # Check against NumPy
3085
            expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
3086
            self.assertEqual(x, expected)
3087

3088
        batches = [(), (0, ), (3, ), (2, 3)]
3089
        ns = [0, 5, 32]
3090
        nrhs = [(), (1, ), (5, )]
3091
        for n, batch, rhs in itertools.product(ns, batches, nrhs):
3092
            run_test(n, batch, rhs)
3093

3094
    @skipCUDAIfNoMagmaAndNoCusolver
3095
    @skipCPUIfNoLapack
3096
    @dtypes(*floating_and_complex_types())
3097
    def test_solve_batched_broadcasting(self, device, dtype):
3098
        from numpy.linalg import solve
3099

3100
        def run_test(A_dims, B_dims):
3101
            A_matrix_size = A_dims[-1]
3102
            A_batch_dims = A_dims[:-2]
3103
            B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
3104
            actual = torch.linalg.solve(A, B)
3105
            expected = solve(A.cpu().numpy(), B.cpu().numpy())
3106
            self.assertEqual(actual, expected)
3107

3108
        # test against numpy.linalg.solve
3109
        run_test((5, 5), (2, 0, 5, 3))  # broadcasting with 0 batch dim
3110
        run_test((2, 0, 5, 5), (5, 3))  # broadcasting with 0 batch dim
3111
        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting B
3112
        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
3113
        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & B
3114

3115
    @skipCUDAIfNoMagma
3116
    @skipCPUIfNoLapack
3117
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3118
    @precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
3119
    def test_tensorsolve(self, device, dtype):
3120
        def run_test(a_shape, dims):
3121
            a = torch.randn(a_shape, dtype=dtype, device=device)
3122
            b = torch.randn(a_shape[:2], dtype=dtype, device=device)
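            # tensorsolve moves the dimensions listed in `dims` to the end of `a`, reshapes the
            # result into a square matrix and solves; the returned x should satisfy
            # torch.tensordot(a, x, dims=x.ndim) == b (NumPy's tensorsolve is the reference).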
3123
            result = torch.linalg.tensorsolve(a, b, dims=dims)
3124
            expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
3125
            self.assertEqual(result, expected)
3126

3127
            # check the out= variant
3128
            out = torch.empty_like(result)
3129
            ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
3130
            self.assertEqual(ans, out)
3131
            self.assertEqual(ans, result)
3132

3133
        a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
3134
        dims = [None, (0, 2)]
3135
        for a_shape, d in itertools.product(a_shapes, dims):
3136
            run_test(a_shape, d)
3137

3138
    @skipCUDAIfNoMagma
3139
    @skipCPUIfNoLapack
3140
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3141
    def test_tensorsolve_empty(self, device, dtype):
3142
        # Check for empty inputs. NumPy does not work for these cases.
3143
        a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3144
        b = torch.empty(a.shape[:2], dtype=dtype, device=device)
3145
        x = torch.linalg.tensorsolve(a, b)
3146
        self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
3147

3148
    @skipCUDAIfNoMagma
3149
    @skipCPUIfNoLapack
3150
    @dtypes(torch.float32)
3151
    def test_tensorsolve_errors_and_warnings(self, device, dtype):
3152
        # tensorsolve expects an input that can be reshaped into a square matrix
3153
        a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4))
3154
        b = torch.randn(8, 4, dtype=dtype, device=device)
3155
        self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
3156
        with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
3157
            torch.linalg.tensorsolve(a, b)
3158

3159
        # if non-empty out tensor with wrong shape is passed a warning is given
3160
        out = torch.empty_like(a)
3161
        b = torch.randn(6, 4, dtype=dtype, device=device)
3162
        with warnings.catch_warnings(record=True) as w:
3163
            # Trigger warning
3164
            torch.linalg.tensorsolve(a, b, out=out)
3165
            # Check warning occurs
3166
            self.assertEqual(len(w), 1)
3167
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3168

3169
        # dtypes should be safely castable
3170
        out = torch.empty_like(a).to(torch.int)
3171
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3172
            torch.linalg.tensorsolve(a, b, out=out)
3173

3174
        # device should match
3175
        if torch.cuda.is_available():
3176
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3177
            out = torch.empty(0, dtype=dtype, device=wrong_device)
3178
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3179
                torch.linalg.tensorsolve(a, b, out=out)
3180

3181
    @skipCUDAIfNoMagma
3182
    @skipCPUIfNoLapack
3183
    @dtypes(*floating_and_complex_types())
3184
    @precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
3185
    def test_tensorinv(self, device, dtype):
3186

3187
        def run_test(a_shape, ind):
3188
            a = torch.randn(a_shape, dtype=dtype, device=device)
3189
            a_numpy = a.cpu().numpy()
3190
            result = torch.linalg.tensorinv(a, ind=ind)
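            # tensorinv treats `a` as a linear map from its trailing (ndim - ind) dimensions to its
            # first `ind` dimensions and returns the inverse map, whose shape is
            # a.shape[ind:] + a.shape[:ind], with NumPy's tensorinv as the reference.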
3191
            expected = np.linalg.tensorinv(a_numpy, ind=ind)
3192
            self.assertEqual(result, expected)
3193

3194
            # check the out= variant
3195
            out = torch.empty_like(result)
3196
            ans = torch.linalg.tensorinv(a, ind=ind, out=out)
3197
            self.assertEqual(ans, out)
3198
            self.assertEqual(ans, result)
3199

3200
        # compare to NumPy output
3201
        run_test((12, 3, 4), ind=1)
3202
        run_test((3, 8, 24), ind=2)
3203
        run_test((18, 3, 3, 2), ind=1)
3204
        run_test((1, 4, 2, 2), ind=2)
3205
        run_test((2, 3, 5, 30), ind=3)
3206
        run_test((24, 2, 2, 3, 2), ind=1)
3207
        run_test((3, 4, 2, 3, 2), ind=2)
3208
        run_test((1, 2, 3, 2, 3), ind=3)
3209
        run_test((3, 2, 1, 2, 12), ind=4)
3210

3211
    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3212
    @skipCUDAIfNoMagma
3213
    @skipCPUIfNoLapack
3214
    @dtypes(*floating_and_complex_types())
3215
    def test_tensorinv_empty(self, device, dtype):
3216
        for ind in range(1, 4):
3217
            # Check for empty inputs. NumPy does not work for these cases.
3218
            a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3219
            a_inv = torch.linalg.tensorinv(a, ind=ind)
3220
            self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
3221

3222
    @skipMeta  # See https://github.com/pytorch/pytorch/issues/53739
3223
    @skipCUDAIfNoMagma
3224
    @skipCPUIfNoLapack
3225
    @dtypes(*floating_and_complex_types())
3226
    def test_tensorinv_errors_and_warnings(self, device, dtype):
3227

3228
        def check_shape(a_shape, ind):
3229
            # tensorinv requires the input to satisfy
3230
            # prod(a.shape[ind:]) == prod(a.shape[:ind])
3231
            a = torch.randn(a_shape, dtype=dtype, device=device)
3232
            with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"):
3233
                torch.linalg.tensorinv(a, ind=ind)
3234

3235
        def check_ind(a_shape, ind):
3236
            a = torch.randn(a_shape, dtype=dtype, device=device)
3237
            with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"):
3238
                torch.linalg.tensorinv(a, ind=ind)
3239

3240
        def check_out(a_shape, ind):
3241
            # if non-empty out tensor with wrong shape is passed a warning is given
3242
            a = torch.randn(a_shape, dtype=dtype, device=device)
3243
            out = torch.empty_like(a)
3244
            with warnings.catch_warnings(record=True) as w:
3245
                # Trigger warning
3246
                torch.linalg.tensorinv(a, ind=ind, out=out)
3247
                # Check warning occurs
3248
                self.assertEqual(len(w), 1)
3249
                self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3250

3251
            # dtypes should be safely castable
3252
            out = torch.empty(0, dtype=torch.int, device=device)
3253
            with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3254
                torch.linalg.tensorinv(a, ind=ind, out=out)
3255

3256
            # device should match
3257
            if torch.cuda.is_available():
3258
                wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3259
                out = torch.empty(0, dtype=dtype, device=wrong_device)
3260
                with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3261
                    torch.linalg.tensorinv(a, ind=ind, out=out)
3262

3263
        # test for invalid shape
3264
        check_shape((2, 3, 4), ind=1)
3265
        check_shape((1, 2, 3, 4), ind=3)
3266

3267
        # test for invalid ind
3268
        check_ind((12, 3, 4), ind=-1)
3269
        check_ind((18, 3, 3, 2), ind=0)
3270

3271
        # test for invalid out tensor
3272
        check_out((12, 3, 4), ind=1)
3273
        check_out((3, 8, 24), ind=2)
3274

3275
    @skipCUDAIfNoMagma
3276
    @skipCPUIfNoLapack
3277
    @dtypes(*floating_and_complex_types())
3278
    def test_tensorinv_singular_input(self, device, dtype):
3279

3280
        def check_singular_input(a_shape, ind):
3281
            prod_ind_end = np.prod(a_shape[ind:])
3282
            a = torch.eye(prod_ind_end, dtype=dtype, device=device)
3283
            a[-1, -1] = 0   # Now `a` is singular
3284
            a = a.reshape(a_shape)
3285
            with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
3286
                torch.linalg.tensorinv(a, ind=ind)
3287

3288
        # test for non-invertible input
3289
        check_singular_input((12, 3, 4), ind=1)
3290
        check_singular_input((3, 6, 18), ind=2)
3291

3292
    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
3293
        def check(x, y):
3294
            # Compare with numpy
3295
            res = torch_fn(x, y)
3296
            if x.dtype == torch.bfloat16:
3297
                ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
3298
            else:
3299
                ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
3300
            if res.dtype == torch.bfloat16:
3301
                self.assertEqual(res.cpu(), ref.bfloat16())
3302
            else:
3303
                self.assertEqual(res.cpu(), ref)
3304

3305
            # Test out variant
3306
            out = torch.empty_like(res)
3307
            torch_fn(x, y, out=out)
3308
            self.assertEqual(out, res)
3309

3310
        # Empty
3311
        x = torch.tensor([], dtype=dtype, device=device)
3312
        y = torch.tensor([], dtype=dtype, device=device)
3313
        check(x, y)
3314

3315
        # Contiguous
3316
        x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3317
        y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3318
        check(x, y)
3319

3320
        # 0 strided
3321
        y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
3322
        check(x, y)
3323

3324
        # 2 strided
3325
        check(x[::2], y[::2])
3326

3327
    @dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16)
3328
    @dtypesIfCUDA(torch.float, torch.cfloat)
3329
    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
3330
    def test_dot_vs_numpy(self, device, dtype):
3331
        self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)
3332

3333
    @dtypes(torch.float, torch.cfloat)
3334
    @precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
3335
    def test_vdot_vs_numpy(self, device, dtype):
3336
        self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)
3337

3338
    def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
3339
        def check(x, y, regex):
3340
            with self.assertRaisesRegex(RuntimeError, regex):
3341
                torch_fn(x, y)
3342

3343
        if complex_dtypes:
3344
            x = torch.randn(1, dtype=torch.cfloat, device=device)
3345
            y = torch.randn(3, dtype=torch.cdouble, device=device)
3346
        else:
3347
            x = torch.randn(1, dtype=torch.float, device=device)
3348
            y = torch.randn(3, dtype=torch.double, device=device)
3349

3350
        check(x, y, 'dot : expected both vectors to have same dtype')
3351
        check(x.reshape(1, 1), y, '1D tensors expected')
3352
        check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')
3353

3354
        if self.device_type != 'cpu':
3355
            x_cpu = x.expand(3).cpu()
3356
            check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device')
3357

3358
    @onlyNativeDeviceTypes
3359
    def test_vdot_invalid_args(self, device):
3360
        self._test_dot_vdot_invalid_args(device, torch.vdot)
3361
        self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)
3362

3363
    @onlyNativeDeviceTypes
3364
    def test_dot_invalid_args(self, device):
3365
        self._test_dot_vdot_invalid_args(device, torch.dot)
3366
        self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)
3367

3368
    @skipCUDAIfNoMagma
3369
    @skipCPUIfNoLapack
3370
    @dtypes(*floating_and_complex_types())
3371
    def test_matrix_rank(self, device, dtype):
3372
        matrix_rank = torch.linalg.matrix_rank
3373

3374
        def run_test(shape0, shape1, batch):
3375
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3376
            rank_a = matrix_rank(a)
3377

3378
            self.assertEqual(rank_a, matrix_rank(a.mH))
3379
            aaH = torch.matmul(a, a.mH)
3380
            rank_aaH = matrix_rank(aaH)
3381
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3382
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3383
            aHa = torch.matmul(a.mH, a)
3384
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3385

3386
            # check against NumPy
3387
            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
3388
            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
3389

3390
            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
3391
            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
3392

3393
            # hermitian flag for NumPy was added in 1.14.0
3394
            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
3395
                self.assertEqual(rank_aaH_hermitian,
3396
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
3397
                self.assertEqual(matrix_rank(aaH, 0.01, True),
3398
                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
3399

3400
            # check out= variant
3401
            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
3402
            ans = matrix_rank(a, out=out)
3403
            self.assertEqual(ans, out)
3404
            self.assertEqual(ans, rank_a)
3405

3406
        shapes = (3, 13)
3407
        batches = ((), (0, ), (4, ), (3, 5, ))
3408
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3409
            run_test(shape0, shape1, batch)
3410

3411
    @skipCUDAIfNoMagma
3412
    @skipCPUIfNoLapack
3413
    @dtypes(*floating_and_complex_types())
3414
    def test_matrix_rank_atol(self, device, dtype):
3415

3416
        def run_test_atol(shape0, shape1, batch):
3417
            a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
3418
            # Check against NumPy output
3419
            # Test float tol, and specific value for each matrix
3420
            tolerances = [float(torch.rand(1)), ]
3421
            # Test different types of tol tensor
3422
            for tol_type in all_types():
3423
                tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
3424
            # Test broadcasting of tol
3425
            if a.ndim > 2:
3426
                tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
3427
            for tol in tolerances:
3428
                actual = torch.linalg.matrix_rank(a, atol=tol)
3429
                actual_tol = torch.linalg.matrix_rank(a, tol=tol)
3430
                self.assertEqual(actual, actual_tol)
3431
                numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
3432
                expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
3433
                self.assertEqual(actual, expected)
3434

3435
        shapes = (3, 13)
3436
        batches = ((), (0, ), (4, ), (3, 5, ))
3437
        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3438
            run_test_atol(shape0, shape1, batch)
3439

3440
    @skipCUDAIfNoMagma
3441
    @skipCPUIfNoLapack
3442
    @dtypes(torch.float64)
3443
    def test_matrix_rank_atol_rtol(self, device, dtype):
3444
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3445
        make_arg = partial(make_fullrank, device=device, dtype=dtype)
3446

3447
        # creates a matrix with singular values rank=n and singular values in range [2/3, 3/2]
3448
        # the singular values are 1 + 1/2, 1 - 1/3, 1 + 1/4, 1 - 1/5, ...
3449
        n = 9
3450
        a = make_arg(n, n)
3451

3452
        # test float and tensor variants
3453
        for tol_value in [0.81, torch.tensor(0.81, device=device)]:
3454
            # using rtol (relative tolerance) takes into account the largest singular value (1.5 in this case)
3455
            result = torch.linalg.matrix_rank(a, rtol=tol_value)
3456
            self.assertEqual(result, 2)  # there are 2 singular values above 1.5*0.81 = 1.215
3457

3458
            # atol is used directly to compare with singular values
3459
            result = torch.linalg.matrix_rank(a, atol=tol_value)
3460
            self.assertEqual(result, 7)  # there are 7 singular values above 0.81
3461

3462
            # when both are specified the maximum tolerance is used
3463
            result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
3464
            self.assertEqual(result, 2)  # there are 2 singular values above max(0.81, 1.5*0.81)
3465

3466
    @skipCUDAIfNoMagma
3467
    @skipCPUIfNoLapack
3468
    @skipCUDAVersionIn([(11, 6), (11, 7)])  # https://github.com/pytorch/pytorch/issues/75391
3469
    @dtypes(*floating_and_complex_types())
3470
    def test_matrix_rank_empty(self, device, dtype):
3471
        matrix_rank = torch.linalg.matrix_rank
3472

3473
        # NumPy doesn't work for input with no elements
3474
        def run_test(shape0, shape1, batch):
3475
            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3476
            rank_a = matrix_rank(a)
3477
            expected = torch.zeros(batch, dtype=torch.int64, device=device)
3478

3479
            self.assertEqual(rank_a, matrix_rank(a.mH))
3480

3481
            aaH = torch.matmul(a, a.mH)
3482
            rank_aaH = matrix_rank(aaH)
3483
            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3484
            self.assertEqual(rank_aaH, rank_aaH_hermitian)
3485

3486
            aHa = torch.matmul(a.mH, a)
3487
            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3488

3489
            self.assertEqual(rank_a, expected)
3490
            self.assertEqual(matrix_rank(a, 0.01), expected)
3491

3492
            self.assertEqual(rank_aaH, expected)
3493
            self.assertEqual(matrix_rank(aaH, 0.01), expected)
3494

3495
            self.assertEqual(rank_aaH_hermitian, expected)
3496
            self.assertEqual(matrix_rank(aaH, 0.01, True), expected)
3497

3498
        batches = ((), (4, ), (3, 5, ))
3499
        for batch in batches:
3500
            run_test(0, 0, batch)
3501
            run_test(0, 3, batch)
3502
            run_test(3, 0, batch)
3503

3504
    @skipCUDAIfNoMagma
3505
    @skipCPUIfNoLapack
3506
    @dtypes(*floating_and_complex_types())
3507
    def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
3508
        # dtypes should be safely castable
3509
        a = torch.eye(2, dtype=dtype, device=device)
3510
        out = torch.empty(0, dtype=torch.bool, device=device)
3511
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"):
3512
            torch.linalg.matrix_rank(a, out=out)
3513

3514
        # device should match
3515
        if torch.cuda.is_available():
3516
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3517
            out = torch.empty(0, dtype=dtype, device=wrong_device)
3518
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3519
                torch.linalg.matrix_rank(a, out=out)
3520

3521
        # if out tensor with wrong shape is passed a warning is given
3522
        with warnings.catch_warnings(record=True) as w:
3523
            out = torch.empty(3, dtype=dtype, device=device)
3524
            # Trigger warning
3525
            torch.linalg.matrix_rank(a, out=out)
3526
            # Check warning occurs
3527
            self.assertEqual(len(w), 1)
3528
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3529

3530
    @skipCUDAIfNoMagma
3531
    @skipCPUIfNoLapack
3532
    @dtypes(*floating_and_complex_types())
3533
    def test_matrix_rank_basic(self, device, dtype):
3534
        matrix_rank = torch.linalg.matrix_rank
3535

3536
        a = torch.eye(10, dtype=dtype, device=device)
3537
        self.assertEqual(matrix_rank(a).item(), 10)
3538
        self.assertEqual(matrix_rank(a, hermitian=True).item(), 10)
3539

3540
        a[5, 5] = 0
3541
        self.assertEqual(matrix_rank(a).item(), 9)
3542
        self.assertEqual(matrix_rank(a, hermitian=True).item(), 9)
3543

3544
    @onlyNativeDeviceTypes
3545
    @dtypes(torch.double)
3546
    # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot which this is an "alias" for.
3547
    def test_chain_matmul(self, device, dtype):
3548
        # chain_matmul accepts a single input tensor while multi_dot does not
3549
        t = make_tensor((2, 2), dtype=dtype, device=device)
3550
        self.assertEqual(t, torch.chain_matmul(t))
3551
        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
3552
            torch.chain_matmul()
3553

3554
        # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
3555
        # be either 1D or 2D
3556
        with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
3557
            torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device))
3558

3559
    @onlyNativeDeviceTypes
3560
    @dtypes(torch.double, torch.cdouble)
3561
    def test_multi_dot(self, device, dtype):
3562
        def check(*shapes):
3563
            tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes]
3564
            np_arrays = [tensor.cpu().numpy() for tensor in tensors]
3565
            res = torch.linalg.multi_dot(tensors).cpu()
3566
            ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
3567
            self.assertEqual(res, ref)
3568

3569
        # test for inputs with empty dimensions
3570
        check([0], [0])
3571
        check([2], [2, 0])
3572
        check([1, 0], [0])
3573
        check([0, 2], [2, 1])
3574
        check([2, 2], [2, 0])
3575
        check([2, 0], [0, 3])
3576
        check([0, 0], [0, 1])
3577
        check([4, 2], [2, 0], [0, 3], [3, 2])
3578

3579
        # test variable output shapes
3580
        check([2], [2])
3581
        check([1, 2], [2])
3582
        check([2], [2, 1])
3583
        check([1, 2], [2, 1])
3584
        check([3, 2], [2, 4])
3585

3586
        # test multiple input tensors
3587
        check([3], [3, 4], [4, 2], [2, 5], [5])
3588
        check([1, 2], [2, 2], [2, 3], [3, 1])
3589

3590
        # test large tensors
3591
        check([10, 100], [100, 5], [5, 50])
3592
        check([10, 20], [20, 30], [30, 5])
3593

3594
    @onlyNativeDeviceTypes
3595
    @dtypes(torch.float)
3596
    def test_multi_dot_errors(self, device, dtype):
3597
        def check(tensors, out, msg):
3598
            with self.assertRaisesRegex(RuntimeError, msg):
3599
                torch.linalg.multi_dot(tensors, out=out)
3600

3601
        a = make_tensor(2, dtype=dtype, device=device)
3602

3603
        check([], None, "expected at least 2 tensors")
3604
        check([a], None, "expected at least 2 tensors")
3605

3606
        check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
3607
        check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
3608

3609
        check([a, a, a], None, "tensor 1 must be 2D")
3610
        check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D")
3611

3612
        check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype")
3613
        check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
3614

3615
        if self.device_type == 'cuda':
3616
            check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device")
3617
            check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
3618

3619
        check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied")
3620
        check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied")
3621

3622
    @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
3623
    @skipCUDAIfNoCusolver
3624
    @skipCPUIfNoLapack
3625
    @dtypes(*floating_and_complex_types())
3626
    def test_qr(self, device, dtype):
3627
        def run_test(tensor_dims, some):
3628
            A = torch.randn(*tensor_dims, dtype=dtype, device=device)
3629
            Q, R = torch.qr(A, some=some)
3630

3631
            # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n)
3632
            m, n = tensor_dims[-2:]
3633
            n_columns = m if (not some) and m > n else min(m, n)
3634
            self.assertEqual(Q.size(-2), m)
3635
            self.assertEqual(R.size(-1), n)
3636
            self.assertEqual(Q.size(-1), n_columns)
3637

3638
            A_ = A.cpu().numpy()
3639
            Q_ = Q.cpu().numpy()
3640
            R_ = R.cpu().numpy()
3641

3642
            # Check1: A = QR
3643
            self.assertEqual(A_, np.matmul(Q_, R_))
3644

3645
            # Check2: A = QR (with out)
3646
            Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan)
3647
            torch.qr(A, some=some, out=(Q_out, R_out))
3648
            Q_out_ = Q_out.cpu().numpy()
3649
            R_out_ = R_out.cpu().numpy()
3650
            self.assertEqual(A_, np.matmul(Q_out_, R_out_))
3651

3652
            # Check3: Q == Q_out, R == R_out
3653
            self.assertEqual(Q_, Q_out_)
3654
            self.assertEqual(R_, R_out_)
3655

3656
            # Check4: Q^{T}Q = I, triu(R) = R
3657
            eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy()
3658
            self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye)
3659
            self.assertEqual(R.triu(), R)
3660

3661
        tensor_dims_list = [(0, 5), (0, 0), (5, 0),  # Empty Tensors
3662
                            (2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5),  # Batched empty Tensors
3663
                            (3, 5), (5, 5), (5, 3),  # Single matrix
3664
                            (7, 3, 5), (7, 5, 5), (7, 5, 3),  # 3-dim Tensors
3665
                            (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)]  # 4-dim Tensors
3666
        for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]):
3667
            run_test(tensor_dims, some)
3668

3669
    @skipCUDAIfNoCusolver
3670
    @skipCPUIfNoLapack
3671
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3672
    def test_qr_vs_numpy(self, device, dtype):
3673
        """
3674
        test torch.linalg.qr vs numpy.linalg.qr
3675
        """
3676
        sizes_to_test = [
3677
            (7, 5),
3678
            (5, 7),
3679
            (5, 0),    # empty
3680
            (0, 5),    # empty
3681
        ]
3682
        for size in sizes_to_test:
3683
            t = torch.randn(size, device=device, dtype=dtype)
3684
            np_t = t.cpu().numpy()
3685
            for mode in ['reduced', 'complete']:
3686
                exp_q, exp_r = np.linalg.qr(np_t, mode=mode)
3687
                q, r = torch.linalg.qr(t, mode=mode)
3688
                self.assertEqual(q, exp_q)
3689
                self.assertEqual(r, exp_r)
3690
            #
3691
            # for mode='r' we need a special logic because numpy returns only r
3692
            exp_r = np.linalg.qr(np_t, mode='r')
3693
            q, r = torch.linalg.qr(t, mode='r')
3694
            # check that q is empty
3695
            self.assertEqual(q.shape, (0,))
3696
            self.assertEqual(q.dtype, t.dtype)
3697
            self.assertEqual(q.device, t.device)
3698
            # check r
3699
            self.assertEqual(r, exp_r)
3700

3701
    @skipCUDAIfNoCusolver
3702
    @skipCPUIfNoLapack
3703
    @dtypes(torch.float)
3704
    def test_linalg_qr_autograd_errors(self, device, dtype):
3705
        # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
3706
        # without 'q' you cannot compute the backward pass. Check that
3707
        # linalg_qr_backward complains cleanly in that case.
3708
        inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
3709
        q, r = torch.linalg.qr(inp, mode='r')
3710
        self.assertEqual(q.shape, (0,))  # empty tensor
3711
        b = torch.sum(r)
3712
        with self.assertRaisesRegex(RuntimeError,
3713
                                    "The derivative of linalg.qr depends on Q"):
3714
            b.backward()
3715
        inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
3716
        q, r = torch.linalg.qr(inp, mode='complete')
3717
        b = torch.sum(r)
3718
        with self.assertRaisesRegex(RuntimeError,
3719
                                    "The QR decomposition is not differentiable when mode='complete' and nrows > ncols"):
3720
            b.backward()
3721

3722
    @skipCUDAIfNoCusolver
3723
    @skipCPUIfNoLapack
3724
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3725
    def test_qr_batched(self, device, dtype):
3726
        """
3727
        test torch.linalg.qr vs numpy.linalg.qr. We need some special logic
3728
        because numpy does not support batched qr
3729
        """
3730
        def np_qr_batched(a, mode):
3731
            """poor's man batched version of np.linalg.qr"""
3732
            all_q = []
3733
            all_r = []
3734
            for matrix in a:
3735
                result = np.linalg.qr(matrix, mode=mode)
3736
                if mode == 'r':
3737
                    all_r.append(result)
3738
                else:
3739
                    q, r = result
3740
                    all_q.append(q)
3741
                    all_r.append(r)
3742
            if mode == 'r':
3743
                return np.array(all_r)
3744
            else:
3745
                return np.array(all_q), np.array(all_r)
3746

3747
        t = torch.randn((3, 7, 5), device=device, dtype=dtype)
3748
        np_t = t.cpu().numpy()
3749
        for mode in ['reduced', 'complete']:
3750
            exp_q, exp_r = np_qr_batched(np_t, mode=mode)
3751
            q, r = torch.linalg.qr(t, mode=mode)
3752
            self.assertEqual(q, exp_q)
3753
            self.assertEqual(r, exp_r)
3754
        # for mode='r' we need a special logic because numpy returns only r
3755
        exp_r = np_qr_batched(np_t, mode='r')
3756
        q, r = torch.linalg.qr(t, mode='r')
3757
        # check that q is empty
3758
        self.assertEqual(q.shape, (0,))
3759
        self.assertEqual(q.dtype, t.dtype)
3760
        self.assertEqual(q.device, t.device)
3761
        # check r
3762
        self.assertEqual(r, exp_r)
3763

3764
    @skipCUDAIfNoCusolver
3765
    @skipCPUIfNoLapack
3766
    @dtypes(torch.float)
3767
    def test_qr_error_cases(self, device, dtype):
3768
        t1 = torch.randn(5, device=device, dtype=dtype)
3769
        with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'):
3770
            torch.linalg.qr(t1)
3771
        t2 = torch.randn((5, 7), device=device, dtype=dtype)
3772
        with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
3773
            torch.linalg.qr(t2, mode='hello')
3774

3775
    def _check_einsum(self, *args, np_args=None):
3776
        if np_args is None:
3777
            np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
3778
        ref = np.einsum(*np_args)
3779
        res = torch.einsum(*args)
3780
        self.assertEqual(ref, res)
3781

3782
        # Check that the other variations for opt_einsum work too
3783
        if TEST_OPT_EINSUM:
3784
            with opt_einsum.flags(enabled=False):
3785
                res = torch.einsum(*args)
3786
                self.assertEqual(ref, res)
3787

3788
            with opt_einsum.flags(enabled=True, strategy='greedy'):
3789
                res = torch.einsum(*args)
3790
                self.assertEqual(ref, res)
3791

3792
            with opt_einsum.flags(enabled=True, strategy='optimal'):
3793
                res = torch.einsum(*args)
3794
                self.assertEqual(ref, res)
3795

3796
    @dtypes(torch.double, torch.cdouble)
3797
    def test_einsum(self, device, dtype):
3798
        # Test cases from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f
3799
        x = make_tensor((5,), dtype=dtype, device=device)
3800
        y = make_tensor((7,), dtype=dtype, device=device)
3801
        A = make_tensor((3, 5), dtype=dtype, device=device)
3802
        B = make_tensor((2, 5), dtype=dtype, device=device)
3803
        C = make_tensor((2, 3, 5), dtype=dtype, device=device)
3804
        D = make_tensor((2, 5, 7), dtype=dtype, device=device)
3805
        E = make_tensor((7, 9), dtype=dtype, device=device)
3806
        F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device)
3807
        G = make_tensor((5, 4, 6), dtype=dtype, device=device)
3808
        H = make_tensor((4, 4), dtype=dtype, device=device)
3809
        I = make_tensor((2, 3, 2), dtype=dtype, device=device)
3810

3811
        # Vector operations
3812
        self._check_einsum('i->', x)                     # sum
3813
        self._check_einsum('i,i->', x, x)                # dot
3814
        self._check_einsum('i,i->i', x, x)               # vector element-wisem mul
3815
        self._check_einsum('i,j->ij', x, y)              # outer
3816

3817
        # Matrix operations
3818
        self._check_einsum("ij->ji", A)                  # transpose
3819
        self._check_einsum("ij->j", A)                   # row sum
3820
        self._check_einsum("ij->i", A)                   # col sum
3821
        self._check_einsum("ij,ij->ij", A, A)            # matrix element-wise mul
3822
        self._check_einsum("ij,j->i", A, x)              # matrix vector multiplication
3823
        self._check_einsum("ij,kj->ik", A, B)            # matmul
3824
        self._check_einsum("ij,ab->ijab", A, E)          # matrix outer product
3825

3826
        # Tensor operations
3827
        self._check_einsum("Aij,Ajk->Aik", C, D)         # batch matmul
3828
        self._check_einsum("ijk,jk->i", C, A)            # tensor matrix contraction
3829
        self._check_einsum("aij,jk->aik", D, E)          # tensor matrix contraction
3830
        self._check_einsum("abCd,dFg->abCFg", F, G)      # tensor tensor contraction
3831
        self._check_einsum("ijk,jk->ik", C, A)           # tensor matrix contraction with double indices
3832
        self._check_einsum("ijk,jk->ij", C, A)           # tensor matrix contraction with double indices
3833
        self._check_einsum("ijk,ik->j", C, B)            # non contiguous
3834
        self._check_einsum("ijk,ik->jk", C, B)           # non contiguous with double indices
3835

3836
        # Test diagonals
3837
        self._check_einsum("ii", H)                      # trace
3838
        self._check_einsum("ii->i", H)                   # diagonal
3839
        self._check_einsum('iji->j', I)                  # non-contiguous trace
3840
        self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device))
3841

3842
        # Test ellipsis
3843
        self._check_einsum("i...->...", H)
3844
        self._check_einsum("ki,...k->i...", A.t(), B)
3845
        self._check_einsum("k...,jk->...", A.t(), B)
3846
        self._check_einsum('...ik, ...j -> ...ij', C, x)
3847
        self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device))
3848
        self._check_einsum('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device))
3849

3850
        # torch.bilinear with noncontiguous tensors
3851
        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3852
        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3853
        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3854
        self._check_einsum("bn,anm,bm->ba", l, w, r)
3855

3856
        # with strided tensors
3857
        self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2])
3858

3859
        # test multiple inputs
3860
        self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F)
3861

3862
    @dtypes(torch.double, torch.cdouble)
3863
    def test_einsum_sublist_format(self, device, dtype):
3864
        x = make_tensor((5,), dtype=dtype, device=device)
3865
        y = make_tensor((7,), dtype=dtype, device=device)
3866
        A = make_tensor((3, 5), dtype=dtype, device=device)
3867
        B = make_tensor((2, 5), dtype=dtype, device=device)
3868
        C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device)
3869

3870
        self._check_einsum(x, [0])
3871
        self._check_einsum(x, [0], [])
3872
        self._check_einsum(x, [0], y, [1], [0, 1])
3873
        self._check_einsum(A, [0, 1], [1, 0])
3874
        self._check_einsum(A, [0, 1], x, [1], [0])
3875
        self._check_einsum(A, [0, 1], B, [2, 1])
3876
        self._check_einsum(A, [0, 1], B, [2, 1], [0, 2])
3877
        self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
3878
        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0])
3879
        self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
3880
        self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
3881

3882
        # torch.bilinear with noncontiguous tensors
3883
        l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3884
        r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3885
        w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3886
        self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
3887

3888
    @dtypes(torch.double, torch.cdouble)
3889
    def test_einsum_random(self, device, dtype):
3890
        def convert_label(label):
3891
            if label == ...:
3892
                return '...'
3893
            elif label < 26:
3894
                return chr(ord('A') + label)
3895
            else:
3896
                return chr(ord('a') + label - 26)
3897

3898
        def convert_sublist(sublist):
3899
            return ''.join(convert_label(label) for label in sublist)
3900

3901
        def test(n=10,                       # how many tests to generate
3902
                 n_labels=5,                 # how many labels available
3903
                 min_ops=1, max_ops=4,       # min and max number of operands per test
3904
                 min_dims=1, max_dims=3,     # min and max number of dimensions per operand
3905
                 min_size=1, max_size=8,     # min and max size of each dimension
3906
                 max_out_dim=3,              # max number of dimensions for the output
3907
                 enable_diagonals=True,      # controls if labels can be repeated for diagonals
3908
                 ellipsis_prob=0.5,          # probability of including ellipsis in operand
3909
                 broadcasting_prob=0.1):     # probability of turning some dim sizes 1 for broadcasting
3910

3911
            all_labels = torch.arange(52)
3912

3913
            assert 0 <= n
3914
            assert 0 <= n_labels < len(all_labels)
3915
            assert 0 < min_ops <= max_ops
3916
            assert 0 <= min_dims <= max_dims
3917
            assert 0 <= min_size <= max_size
3918
            assert 0 <= max_out_dim
3919
            assert enable_diagonals or max_dims <= n_labels
3920

3921
            for _ in range(n):
3922

3923
                # Select a subset of labels for this test and give them random sizes
3924
                possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]]
3925
                labels_size = torch.randint_like(all_labels, min_size, max_size + 1)
3926
                ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,))
3927

3928
                operands = []
3929
                sublists = []
3930

3931
                ell_size = 0
3932
                valid_labels = set()
3933

3934
                # create random input operands
3935
                for _ in range(random.randint(min_ops, max_ops)):
3936
                    n_dim = random.randint(min_dims, max_dims)
3937
                    labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals)
3938
                    labels = possible_labels[labels_idx]
3939
                    valid_labels.update(labels.tolist())
3940
                    shape = labels_size[labels]
3941

3942
                    # turn some dimensions to size 1 for testing broadcasting
3943
                    mask = Binomial(probs=broadcasting_prob).sample((n_dim,))
3944
                    broadcast_labels = torch.unique(labels[mask == 1])
3945
                    shape[(labels[..., None] == broadcast_labels).any(-1)] = 1
3946

3947
                    labels = labels.tolist()
3948
                    shape = shape.tolist()
3949

3950
                    # include ellipsis if not all dimensions were assigned a label already
3951
                    if n_dim < max_dims and torch.rand(1) < ellipsis_prob:
3952
                        ell_num_dim = random.randint(1, max_dims - n_dim)
3953
                        ell_size = max(ell_size, ell_num_dim)
3954
                        ell_shape = ellipsis_shape[-ell_num_dim:]
3955
                        # again, turn some dimensions to size 1 for broadcasting
3956
                        mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,))
3957
                        ell_shape[mask == 1] = 1
3958
                        ell_index = random.randint(0, n_dim)
3959
                        shape[ell_index:ell_index] = ell_shape
3960
                        labels.insert(ell_index, ...)
3961

3962
                    operands.append(make_tensor(shape, dtype=dtype, device=device))
3963
                    sublists.append(labels)
3964

3965
                # NumPy has a bug with the sublist format so for now we compare PyTorch sublist
3966
                # implementation against the equation format implementation of NumPy
3967
                # see https://github.com/numpy/numpy/issues/10926
3968
                np_operands = [op.cpu().numpy() for op in operands]
3969

3970
                # test equation format
3971
                equation = ','.join(convert_sublist(l) for l in sublists)
3972
                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3973

3974
                # test sublist format
3975
                args = list(itertools.chain.from_iterable(zip(operands, sublists)))
3976
                self._check_einsum(*args, np_args=(equation, *np_operands))
3977

3978
                # generate an explicit output
3979
                out_sublist = []
3980
                num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size)
3981
                if num_out_labels > 0:
3982
                    out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels)
3983
                    out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist()
3984
                out_sublist.insert(random.randint(0, num_out_labels), ...)
3985

3986
                # test equation format with explicit output
3987
                equation += '->' + convert_sublist(out_sublist)
3988
                self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3989

3990
                # test sublist format with explicit output
3991
                args.append(out_sublist)
3992
                self._check_einsum(*args, np_args=(equation, *np_operands))
3993

3994
        test(500)
3995

3996
    def test_einsum_corner_cases(self, device):
3997
        def check(equation, *operands, expected_output):
3998
            tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
3999
                       else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands]
4000
            output = torch.einsum(equation, tensors)
4001
            self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
4002

4003
        # Test equation variantions
4004
        check(' ', 1, expected_output=1)
4005
        check(' -> ', 1, expected_output=1)
4006
        check(' , ', 2, 2, expected_output=4)
4007
        check(' , , ', 2, 2, 2, expected_output=8)
4008
        check(' , -> ', 2, 2, expected_output=4)
4009
        check(' i ', [1], expected_output=[1])
4010
        check(' i -> ', [1], expected_output=1)
4011
        check(' i -> i ', [1], expected_output=[1])
4012
        check(' i , i ', [2], [2], expected_output=4)
4013
        check(' i , i -> i ', [2], [2], expected_output=[4])
4014

4015
        # Test tensors with 0 size dimensions
4016
        check('i', [], expected_output=[])
4017
        check(' i j -> j', [[], []], expected_output=[])
4018
        check('ij->i', [[], []], expected_output=[0., 0.])
4019
        check(' i j k  ,  k  -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []])
4020

4021
        # Test broadcasting
4022
        check('i,j', [2], [1, 2], expected_output=[[2, 4]])
4023
        check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]])
4024

4025
        # Test ellipsis broadcasting
4026
        check('...', 1, expected_output=1)
4027
        check('...->', 1, expected_output=1)
4028
        check('...->...', 1, expected_output=1)
4029
        check('...', [1], expected_output=[1])
4030
        check('...->', [1], expected_output=1)
4031
        check('z...->z', [1], expected_output=[1])
4032
        check('Z...->...Z', [1], expected_output=[1])
4033
        check('...a->', [[2], [4]], expected_output=6)
4034
        check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
4035

4036
    def test_einsum_error_cases(self, device):
4037
        def check(*args, regex, exception=RuntimeError):
4038
            with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
4039
                torch.einsum(*args)
4040

4041
        x = make_tensor((2,), dtype=torch.float32, device=device)
4042
        y = make_tensor((2, 3), dtype=torch.float32, device=device)
4043

4044
        check('', [], regex=r'at least one operand', exception=ValueError)
4045
        check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
4046
        check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
4047
        check('1', [x], regex=r'invalid subscript given at index 0')
4048
        check(',', [x], regex=r'fewer operands were provided than specified in the equation')
4049
        check('', [x, x], regex=r'more operands were provided than specified in the equation')
4050
        check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
4051
              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4052
        check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
4053
              r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4054
        check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
4055
              r'of dimensions \(1\) for operand 0')
4056
        check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
4057
        check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
4058
        check('a->1', [x], regex=r'invalid subscript given at index 3')
4059
        check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
4060
        check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
4061
        check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
4062
        check('...,...', [x, y], regex=r'does not broadcast')
4063
        check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast')
4064
        check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously'
4065
              r' seen size 2')
4066

4067
        check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4068
        check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4069

4070
    def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
4071
        make_arg = partial(make_tensor, dtype=dtype, device=device)
4072
        make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device)
4073
        b, n, k = shape
4074
        for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
4075
            # expand means that we generate a batch of matrices with a stride of zero in the batch dimension
4076
            if (conj_a or conj_b) and not dtype.is_complex:
4077
                continue
4078
            # We just expand on the batch size
4079
            if (expand_a or expand_b) and b == 1:
4080
                continue
4081

4082
            size_a = (b, n, n) if left else (b, k, k)
4083
            size_b = (b, n, k) if not tr_b else (b, k, n)
4084

4085
            # If expand_a or expand_b, we'll expand them to the correct size later
4086
            if b == 1 or expand_a:
4087
                size_a = size_a[1:]
4088
            if b == 1 or expand_b:
4089
                size_b = size_b[1:]
4090

4091
            if well_conditioned:
4092
                PLU = torch.linalg.lu(make_fullrank(*size_a))
4093
                if uni:
4094
                    # A = L from PLU
4095
                    A = PLU[1].transpose(-2, -1).contiguous()
4096
                else:
4097
                    # A = U from PLU
4098
                    A = PLU[2].contiguous()
4099
            else:
4100
                A = make_arg(size_a)
4101
                A.triu_()
4102

4103
            diag = A.diagonal(0, -2, -1)
4104
            if uni:
4105
                diag.fill_(1.)
4106
            else:
4107
                diag[diag.abs() < 1e-6] = 1.
4108

4109
            B = make_arg(size_b)
4110

4111
            if tr_a:
4112
                A.transpose_(-2, -1)
4113
            if tr_b:
4114
                B.transpose_(-2, -1)
4115
            if conj_a:
4116
                A = A.conj()
4117
            if conj_b:
4118
                B = B.conj()
4119
            if expand_a:
4120
                A = A.expand(b, *size_a)
4121
            if expand_b:
4122
                B = B.expand(b, n, k)
4123
            yield A, B, left, not tr_a, uni
4124

4125
    def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
4126
        X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4127
        if left:
4128
            self.assertEqual(A @ X, B)
4129
        else:
4130
            self.assertEqual(X @ A, B)
4131
        out = B
4132
        # B may be expanded
4133
        if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
4134
            out = B.clone()
4135
        torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
4136
        self.assertEqual(X, out)
4137

4138
    # Tolerances dictated by widest acceptable range on CPU before failure
4139
    @dtypes(*floating_and_complex_types())
4140
    @precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1,
4141
                        torch.float64: 1e-8,
4142
                        torch.complex64: 1e-1,
4143
                        torch.complex128: 1e-8})
4144
    def test_linalg_solve_triangular(self, device, dtype):
4145
        # This exercises the API + BLAS CPU + batched cuBLAS
4146
        ks = (3, 1, 0)
4147
        ns = (5, 0)
4148
        bs = (1, 2, 0)
4149

4150
        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4151
        for b, n, k in product(bs, ns, ks):
4152
            for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True):
4153
                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4154

4155
    @slowTest
4156
    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
4157
    @onlyCUDA
4158
    @skipCUDAIfNoMagma  # Magma needed for the PLU decomposition
4159
    @dtypes(*floating_and_complex_types())
4160
    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4161
                        torch.float64: 1e-8, torch.complex128: 1e-8})
4162
    def test_linalg_solve_triangular_large(self, device, dtype):
4163
        # Exercises magma and cublas
4164
        magma = (9, 513, 1)
4165
        iterative_cublas = (2, 64, 1)
4166

4167
        gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4168
        for shape in (magma, iterative_cublas):
4169
            for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
4170
                self._test_linalg_solve_triangular(A, B, upper, left, uni)
4171

4172
    @dtypes(*floating_and_complex_types())
4173
    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4174
                        torch.float64: 1e-8, torch.complex128: 1e-8})
4175
    def test_linalg_solve_triangular_broadcasting(self, device, dtype):
4176
        make_arg = partial(make_tensor, dtype=dtype, device=device)
4177

4178
        sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
4179
                 ((2, 1, 3, 4, 4), (4, 6)),
4180
                 ((4, 4), (2, 1, 3, 4, 2)),
4181
                 ((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
4182
        for size_A, size_B in sizes:
4183
            for left, upper, uni in itertools.product([True, False], repeat=3):
4184
                A = make_arg(size_A)
4185
                if upper:
4186
                    A.triu_()
4187
                else:
4188
                    A.tril_()
4189
                diag = A.diagonal(0, -2, -1)
4190
                if uni:
4191
                    diag.fill_(1.)
4192
                else:
4193
                    diag[diag.abs() < 1e-6] = 1.
4194
                B = make_arg(size_B)
4195
                if not left:
4196
                    B.transpose_(-2, -1)
4197

4198
                X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4199
                if left:
4200
                    B_other = A @ X
4201
                else:
4202
                    B_other = X @ A
4203

4204
                self.assertEqual(*torch.broadcast_tensors(B, B_other))
4205

4206
    def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
4207
                                     device, dtype):
4208
        triangle_function = torch.triu if upper else torch.tril
4209
        b = torch.randn(*b_dims, dtype=dtype, device=device)
4210
        A = torch.randn(*A_dims, dtype=dtype, device=device)
4211
        # create positive definite matrix
4212
        A = torch.matmul(A, A.mT)
4213
        A_triangular = triangle_function(A)
4214
        if unitriangular:
4215
            A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.)
4216
        return b, A_triangular
4217

4218
    @skipCUDAIfNoMagma
4219
    @skipCPUIfNoLapack
4220
    @skipIfTorchDynamo("flaky, needs investigation")
4221
    @dtypes(*floating_and_complex_types())
4222
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4223
                        torch.float64: 1e-8, torch.complex128: 1e-8})
4224
    def test_triangular_solve(self, device, dtype):
4225
        ks = [0, 1, 3]
4226
        ns = [0, 5]
4227
        for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns,
4228
                                                                         itertools.product([True, False], repeat=3)):
4229
            b, A = self.triangular_solve_test_helper((n, n), (n, k), upper,
4230
                                                     unitriangular, device, dtype)
4231
            x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4232
            if transpose:
4233
                self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu()))
4234
            else:
4235
                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
4236

4237
    @skipCPUIfNoLapack
4238
    @skipCUDAIfNoMagma
4239
    @dtypes(*floating_and_complex_types())
4240
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4241
                        torch.float64: 1e-8, torch.complex128: 1e-8})
4242
    def test_triangular_solve_batched(self, device, dtype):
4243
        def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4244
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4245
                                                     unitriangular, device, dtype)
4246
            x_exp_list = []
4247
            for i in range(b_dims[0]):
4248
                x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper,
4249
                                                         unitriangular=unitriangular,
4250
                                                         transpose=transpose)[0])
4251
            x_exp = torch.stack(x_exp_list)  # Stacked output
4252
            x_act = torch.triangular_solve(b, A, upper=upper,
4253
                                           unitriangular=unitriangular,
4254
                                           transpose=transpose)[0]  # Actual output
4255
            self.assertEqual(x_act, x_exp)  # Equality check
4256
            if transpose:
4257
                A = A.mT
4258

4259
            Ax = np.matmul(A.cpu(), x_act.cpu())
4260
            self.assertEqual(b, Ax)
4261

4262
        def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4263
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4264
                                                     unitriangular, device, dtype)
4265
            x = torch.triangular_solve(b, A, upper=upper,
4266
                                       unitriangular=unitriangular,
4267
                                       transpose=transpose)[0]
4268
            self.assertTrue(x.shape == b.shape)
4269

4270
        for upper, unitriangular, transpose in itertools.product([True, False], repeat=3):
4271
            batchsize = 3
4272
            triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4273
                                          upper, unitriangular, transpose)
4274

4275
            # test empty input
4276
            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10),
4277
                                          upper, unitriangular, transpose)
4278
            triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0),
4279
                                          upper, unitriangular, transpose)
4280

4281
            # test zero batch case
4282
            batchsize = 0
4283
            triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4284
                                               upper, unitriangular, transpose)
4285

4286

4287
    @slowTest
4288
    @skipCUDAIfNoMagma
4289
    @skipCPUIfNoLapack
4290
    @dtypes(*floating_and_complex_types())
4291
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4292
                        torch.float64: 1e-8, torch.complex128: 1e-8})
4293
    def test_triangular_solve_batched_many_batches(self, device, dtype):
4294
        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4295
            # test batched A case
4296
            b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1),
4297
                                                     upper, unitriangular, device, dtype)
4298
            x, _ = torch.triangular_solve(b, A,
4299
                                          upper=upper, transpose=transpose, unitriangular=unitriangular)
4300
            if transpose:
4301
                A = A.mT
4302

4303
            Ax = torch.matmul(A, x)
4304

4305
            rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision
4306
            self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol)
4307

4308
            # test batched b case
4309
            b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1),
4310
                                                     upper, unitriangular, device, dtype)
4311
            x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose,
4312
                                          unitriangular=unitriangular)
4313
            if transpose:
4314
                A = A.mT
4315

4316
            self.assertEqual(torch.matmul(A, x), b)
4317

4318
    @skipCUDAIfNoMagma
4319
    @skipCPUIfNoLapack
4320
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
4321
    @skipIfTorchDynamo("flaky, needs investigation")
4322
    @dtypes(*floating_and_complex_types())
4323
    def test_triangular_solve_batched_broadcasting(self, device, dtype):
4324
        from scipy.linalg import solve_triangular as tri_solve
4325

4326
        def scipy_tri_solve_batched(A, B, upper, trans, diag):
4327
            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4328
            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4329
            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4330
                                                     torch.Size(batch_dims_B)))
4331
            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4332
            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4333
            flat_A = expand_A.reshape((-1,) + single_dim_A)
4334
            flat_B = expand_B.reshape((-1,) + single_dim_B)
4335
            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
4336
                                for a, b in zip(flat_A, flat_B)])
4337
            return flat_X.reshape(expand_B.shape)
4338

4339
        def run_test(A_dims, b_dims, device, upper, transpose, unitriangular):
4340
            b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4341
                                                     unitriangular, device, dtype)
4342
            x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(),
4343
                                                            upper, transpose, unitriangular))
4344
            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
4345

4346
            self.assertEqual(x, x_exp.to(device))
4347

4348
        for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4349
            # test against scipy.linalg.solve_triangular
4350
            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular)  # no broadcasting
4351
            run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular)  # broadcasting b
4352
            run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)  # broadcasting A
4353
            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)  # broadcasting A & b
4354

4355
    @onlyCUDA
4356
    @dtypes(torch.float)
4357
    def test_triangular_solve_large(self, device, dtype):
4358
        # Repro for https://github.com/pytorch/pytorch/issues/79191
4359
        A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_()
4360
        B = torch.randn(1, 2, 524281, device=device, dtype=dtype)
4361
        X = torch.linalg.solve_triangular(A, B, upper=False)
4362
        self.assertEqual(A @ X, B)
4363

4364
    @skipCUDAIfNoMagma
4365
    @skipCPUIfNoLapack
4366
    @dtypes(*floating_and_complex_types())
4367
    def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
4368
        # dtypes should be safely castable
4369
        a = torch.eye(2, dtype=dtype, device=device)
4370
        b = torch.randn(2, 1, dtype=dtype, device=device)
4371
        out = torch.empty_like(b).to(torch.int)
4372
        clone_a = torch.empty_like(a)
4373
        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4374
            torch.triangular_solve(b, a, out=(out, clone_a))
4375

4376
        out = torch.empty_like(b)
4377
        clone_a = clone_a.to(torch.int)
4378
        with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4379
            torch.triangular_solve(b, a, out=(out, clone_a))
4380

4381
        # device should match
4382
        if torch.cuda.is_available():
4383
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
4384
            out = torch.empty(0, dtype=dtype, device=wrong_device)
4385
            clone_a = torch.empty_like(a)
4386
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4387
                torch.triangular_solve(b, a, out=(out, clone_a))
4388
            out = torch.empty(0, dtype=dtype, device=device)
4389
            clone_a = torch.empty_like(a).to(wrong_device)
4390
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4391
                torch.triangular_solve(b, a, out=(out, clone_a))
4392

4393
        # Trigger the WARN_ONCE deprecation warning
4394
        torch.triangular_solve(b, a)
4395

4396
        # if an out tensor with the wrong shape is passed, a warning is given
4397
        with warnings.catch_warnings(record=True) as w:
4398
            out = torch.empty(1, dtype=dtype, device=device)
4399
            clone_a = torch.empty(1, dtype=dtype, device=device)
4400
            # Trigger warning
4401
            torch.triangular_solve(b, a, out=(out, clone_a))
4402
            # Check warning occurs
4403
            self.assertEqual(len(w), 2)
4404
            self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
4405
            self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
4406

4407

4408
    def check_single_matmul(self, x, y):
4409

4410
        def assertEqual(answer, expected):
4411
            if x.dtype.is_floating_point or x.dtype.is_complex:
4412
                k = max(x.shape[-1], 1)  # Scale the atol with the size of the matrix
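                # (each output element is a sum of k products, so the accumulated rounding
                #  error grows roughly linearly with k; the scaled atol below is a heuristic)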
4413
                self.assertEqual(answer, expected,
4414
                                 msg=f"{x.shape} x {y.shape} = {answer.shape}",
4415
                                 atol=k * 5e-5,
4416
                                 rtol=1e-4)
4417
            else:
4418
                self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}")
4419

4420
        # test x @ y
4421
        expected = np.matmul(x.cpu(), y.cpu())
4422
        ans = torch.matmul(x, y)
4423
        self.assertTrue(ans.is_contiguous())
4424
        assertEqual(ans, expected)
4425

4426
        # test out
4427
        out = torch.empty_like(ans)
4428
        ans = torch.matmul(x, y, out=out)
4429
        self.assertIs(ans, out)
4430
        self.assertTrue(ans.is_contiguous())
4431
        assertEqual(ans, expected)
4432

4433
    def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
4434
        """
4435
        Generates sequences of tuples (x, y) with size(x) = x_dim and
4436
        size(y) <= y_dim that are compatible wrt. matmul
4437
        """
4438
        assert x_dim >= 1
4439
        assert y_dim >= 2
4440
        x = x_dim
4441
        for y in range(1, y_dim + 1):
4442
            for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
4443
                                     product(range(matrix_size), repeat=min(y, 2))):
4444
                if x == 1:
4445
                    size_x = mn[:1]
4446
                    size_y = batch + mn
4447
                    yield size_x, size_y
4448
                else:
4449
                    for k in range(matrix_size):
4450
                        size_x = (k,) + mn[:1]
4451
                        if x > 2:
4452
                            size_x = batch[-(x - 2):] + size_x
4453
                        size_y = mn
4454
                        if y > 2:
4455
                            size_y = batch[-(y - 2):] + size_y
4456
                        yield size_x, size_y
4457

4458
    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4459
    @dtypes(torch.int64, torch.float, torch.complex64)
4460
    @setBlasBackendsToDefaultFinally
4461
    def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
4462
        for backend in ["cublas", "cublaslt"]:
4463
            if torch.device(device).type == 'cuda':
4464
                torch.backends.cuda.preferred_blas_library(backend)
4465

4466
            make_arg = partial(make_tensor, device=device, dtype=dtype)
4467

4468
            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4469
                x = make_arg(size_x, noncontiguous=nctg_x)
4470
                y = make_arg(size_y, noncontiguous=nctg_y)
4471
                self.check_single_matmul(x, y)
4472

4473
    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4474
    @dtypes(torch.int64, torch.float, torch.complex64)
4475
    @setBlasBackendsToDefaultFinally
4476
    def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
4477
        for backend in ["cublas", "cublaslt"]:
4478
            if torch.device(device).type == 'cuda':
4479
                torch.backends.cuda.preferred_blas_library(backend)
4480

4481
            make_arg = partial(make_tensor, device=device, dtype=dtype)
4482

4483
            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
4484
                x = make_arg(size_x, noncontiguous=nctg_x)
4485
                y = make_arg(size_y, noncontiguous=nctg_y)
4486
                self.check_single_matmul(x, y)
4487

4488
    @dtypesIfCUDA(torch.float, torch.complex64)  # Integer matmul is only supported on CPU
4489
    @dtypes(torch.int64, torch.float, torch.complex64)
4490
    @setBlasBackendsToDefaultFinally
4491
    def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
4492
        for backend in ["cublas", "cublaslt"]:
4493
            if torch.device(device).type == 'cuda':
4494
                torch.backends.cuda.preferred_blas_library(backend)
4495

4496
            make_arg = partial(make_tensor, device=device, dtype=dtype)
4497

4498
            for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)):
4499
                x = make_arg(size_x, noncontiguous=nctg_x)
4500
                y = make_arg(size_y, noncontiguous=nctg_y)
4501
                self.check_single_matmul(x, y)
4502

4503
    @onlyCUDA
4504
    @dtypes(*floating_types_and(torch.half))
4505
    def test_matmul_small_brute_force_tunableop(self, device, dtype):
4506
        # disable tunableop buffer rotation for all tests everywhere, as it can be slow
4507
        import os
4508
        os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
4509
        set_tunableop_defaults()
4510

4511
        torch.cuda.tunable.enable()
4512
        # set these to single iterations to keep it short but still exercise the code
4513
        torch.cuda.tunable.set_max_tuning_duration(1)
4514
        torch.cuda.tunable.set_max_tuning_iterations(1)
4515

4516
        make_arg = partial(make_tensor, device=device, dtype=dtype)
4517

4518
        for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4519
            x = make_arg(size_x, noncontiguous=nctg_x)
4520
            y = make_arg(size_y, noncontiguous=nctg_y)
4521
            self.check_single_matmul(x, y)
4522

4523
        filename1 = torch.cuda.tunable.get_filename()
4524
        filename2 = "tunableop_results_tmp1.csv"
4525
        filename3 = "tunableop_results_tmp2.csv"
4526
        ordinal = torch.cuda.current_device()
4527
        assert filename1 == f"tunableop_results{ordinal}.csv"
4528
        assert len(torch.cuda.tunable.get_validators()) > 0
4529
        validators = {}
4530
        for key, value in torch.cuda.tunable.get_validators():
4531
            validators[key] = value
4532
        if torch.version.hip:
4533
            assert "HIPBLASLT_VERSION" in validators
4534
            assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
4535
        assert len(torch.cuda.tunable.get_results()) > 0
4536

4537
        assert torch.cuda.tunable.write_file()  # use default filename
4538
        assert torch.cuda.tunable.write_file(filename2)  # use custom, one-time filename
4539
        torch.cuda.tunable.set_filename(filename3)
4540
        assert torch.cuda.tunable.write_file()  # use previously set filename
4541
        assert torch.cuda.tunable.read_file()  # use previously set filename, will ignore duplicates and return True
4542

4543
        with open(filename1) as file1:
4544
            file1_contents = file1.read()
4545
        with open(filename2) as file2:
4546
            file2_contents = file2.read()
4547
        with open(filename3) as file3:
4548
            file3_contents = file3.read()
4549
        assert file1_contents == file2_contents
4550
        assert file1_contents == file3_contents
4551

4552
        # remove the files created above to avoid the 'Build left local git repository checkout dirty' error; missing files are ignored
4553
        for filename in [filename1, filename2, filename3]:
4554
            try:
4555
                import os
4556
                os.remove(filename)
4557
            except FileNotFoundError:
4558
                pass
4559

4560
        # disables TunableOp
4561
        torch.cuda.tunable.enable(False)
4562
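
    # A minimal TunableOp round-trip sketch (illustrative only; it is not invoked by the
    # tests in this file and assumes a CUDA/ROCm build where torch.cuda.tunable is
    # available). It mirrors the enable, tune, write_file flow exercised above.
    def _tunableop_roundtrip_sketch(self, device="cuda"):
        torch.cuda.tunable.enable(True)
        torch.cuda.tunable.set_max_tuning_iterations(1)  # keep tuning cheap
        a = torch.randn(8, 8, device=device)
        b = torch.randn(8, 8, device=device)
        torch.matmul(a, b)  # the first GEMM of a given shape triggers tuning
        torch.cuda.tunable.write_file()  # persist results to the configured CSV file
        torch.cuda.tunable.enable(False)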

4563
    @onlyCUDA
4564
    @skipCUDAIfNotRocm
4565
    @dtypes(torch.float)
4566
    def test_bmm_tunableop_rocm(self, device, dtype):
4567
        # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault
4568
        set_tunableop_defaults()
4569
        torch.cuda.tunable.enable(True)
4570
        torch.cuda.tunable.set_max_tuning_iterations(10)
4571
        # the following 4 cases cover all previous failure cases and are here to catch regressions
4572
        B = 16
4573
        N = M = K = 256
4574
        dtype = torch.bfloat16
4575
        device = torch.device("cuda:0")
4576
        # case 1
4577
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4578
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4579
        out = torch.bmm(i1, i2)
4580
        # case 2
4581
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4582
        i1 = torch.permute(i1, (1, 2, 0))
4583
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4584
        i2 = torch.permute(i2, (1, 0, 2))
4585
        out = torch.bmm(i1, i2)
4586
        # case 3
4587
        i1 = torch.randn((N, B, M), device=device, dtype=dtype)
4588
        i1 = torch.permute(i1, (1, 0, 2))
4589
        i2 = torch.randn((M, B, K), device=device, dtype=dtype)
4590
        i2 = torch.permute(i2, (1, 2, 0))
4591
        out = torch.bmm(i1, i2)
4592
        # case 4
4593
        input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
4594
        input_tensor = torch.as_strided(
4595
            input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
4596
        )
4597
        batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
4598
        batch1_tensor = torch.as_strided(
4599
            batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
4600
        )
4601
        batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
4602
        batch2_tensor = torch.as_strided(
4603
            batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
4604
        )
4605
        out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
4606
        # clean up, remove any file that was generated
4607
        try:
4608
            import os
4609
            filename = torch.cuda.tunable.get_filename()
4610
            os.remove(filename)
4611
        except FileNotFoundError:
4612
            pass
4613

4614
        # disable TunableOp
4615
        torch.cuda.tunable.enable(False)
4616

4617
    @onlyCUDA
4618
    @skipCUDAIfNotRocm
4619
    @dtypes(torch.float)
4620
    def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
4621
        from torch.testing._internal.common_utils import CudaMemoryLeakCheck
4622
        import os
4623
        # run operator first without tuning to ensure all rocm libs are loaded,
4624
        # otherwise we get a false-positive mem leak
4625
        B = 16
4626
        N = M = K = 256
4627
        dtype = torch.bfloat16
4628
        device = torch.device("cuda:0")
4629
        i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4630
        i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4631
        out = torch.bmm(i1, i2)
4632
        # enable tunableop numeric check via env variable.
4633
        PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
4634
        prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
4635
        try:
4636
            os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
4637
            torch.cuda.tunable.enable(True)
4638
            ordinal = torch.cuda.current_device()
4639
            filename = f"tunableop_results{ordinal}.csv"
4640
            torch.cuda.tunable.set_filename(filename)
4641
            iterations = torch.cuda.tunable.get_max_tuning_iterations()
4642
            torch.cuda.tunable.set_max_tuning_iterations(10)
4643
            with CudaMemoryLeakCheck(self):
4644
                out = torch.bmm(i1, i2)
4645
                torch.cuda.tunable.set_max_tuning_iterations(iterations)
4646
                torch.cuda.tunable.enable(False)
4647
                # clean up, remove any file that was generated
4648
                try:
4649
                    os.remove(filename)
4650
                except FileNotFoundError:
4651
                    pass
4652
        finally:
4653
            if prev_val is None:
4654
                del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
4655
            else:
4656
                os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val
4657

4658
    @onlyCUDA
4659
    @skipCUDAIfNotRocm
4660
    @dtypes(torch.float)
4661
    def test_validator_tunableop_rocm(self, device, dtype):
4662
        # Test that the validator on ROCM has exactly 5 lines
4663
        # Format of the Validator is as follows:
4664
        # Validator,PT_VERSION,X.Y.Z
4665
        # Validator,ROCBLAS_VERSION,X.Y.Z
4666
        # Validator,HIPBLASLT_VERSION,X,Y.Z
4667
        # Validator,ROCM_Version,X.Y.Z
4668
        # Validator,GCN_ARCH_NAME,<architecture name>
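        # (a hypothetical GCN_ARCH_NAME entry might look like
        #  "Validator,GCN_ARCH_NAME,gfx90a:sramecc+:xnack-")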
4669
        validator_num_lines = 5
4670

4671
        # Test in try-finally block to avoid leaking state
4672
        # if test is interrupted.
4673
        try:
4674
            set_tunableop_defaults()
4675
            torch.cuda.tunable.enable()
4676
            # set these to single iterations to keep it short but still exercise the code
4677
            torch.cuda.tunable.set_max_tuning_iterations(1)
4678

4679
            N = M = K = 4
4680
            A = torch.randn(N, K, device=device, dtype=dtype)
4681
            B = torch.randn(K, M, device=device, dtype=dtype)
4682
            C = torch.matmul(A, B)
4683
            self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
4684
        finally:
4685
            # disable TunableOp
4686
            torch.cuda.tunable.enable(False)
4687

4688
            # clean up, remove any file that was generated
4689
            try:
4690
                import os
4691
                filename = torch.cuda.tunable.get_filename()
4692
                os.remove(filename)
4693
            except FileNotFoundError:
4694
                pass
4695

4696
    @onlyCUDA
4697
    @dtypes(torch.half)
4698
    def test_minimum_tuning_iteration_tunableop(self, device, dtype):
4699
        # Make sure that there is at least one tuning iteration under various scenarios
4700

4701
        # Test in try-finally block to avoid leaking state
4702
        # if test is interrupted.
4703
        try:
4704
            set_tunableop_defaults()
4705
            torch.cuda.tunable.enable()
4706
            # set these to single iterations to keep it short but still exercise the code
4707
            torch.cuda.tunable.set_max_tuning_iterations(1)
4708

4709
            # Set tuning duration to zero milliseconds
4710
            # Tune a single GEMM and verify that we get a new tuning result
4711
            import os
4712
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
4713
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4714
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"  # reset to default
4715

4716
            # Reference number of results
4717
            ref_num_results = len(torch.cuda.tunable.get_results())
4718

4719
            N = M = K = 8
4720
            A = torch.randn(N, K, device=device, dtype=dtype)
4721
            B = torch.randn(K, M, device=device, dtype=dtype)
4722
            C = torch.matmul(A, B)
4723

4724
            # This stores the total number of cumulative results
4725
            total_num_results = len(torch.cuda.tunable.get_results())
4726

4727
            # There must be a new tuning result
4728
            self.assertEqual((total_num_results - ref_num_results), 1)
4729

4730
            # Set tuning iterations to zero
4731
            # Tune a single GEMM and verify that we get a new tuning result
4732
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
4733
            self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4734
            os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"  # reset to default
4735

4736
            # Reference number of results
4737
            ref_num_results = total_num_results
4738

4739
            N = M = K = 16
4740
            A = torch.randn(N, K, device=device, dtype=dtype)
4741
            B = torch.randn(K, M, device=device, dtype=dtype)
4742
            C = torch.matmul(A, B)
4743

4744
            # This stores the total number of cumulative results
4745
            total_num_results = len(torch.cuda.tunable.get_results())
4746

4747
            # There must be a new tuning result
4748
            self.assertEqual((total_num_results - ref_num_results), 1)
4749

4750
        finally:
4751
            # disable TunableOp
4752
            torch.cuda.tunable.enable(False)
4753

4754
            # clean up, remove any file that was generated
4755
            try:
4756
                import os
4757
                filename = torch.cuda.tunable.get_filename()
4758
                os.remove(filename)
4759
            except FileNotFoundError:
4760
                pass
4761

4762
    @onlyCUDA
4763
    @dtypes(torch.half)
4764
    def test_matmul_check_entries_tunableop(self, device, dtype):
4765
        # Tune a couple of matrix multiplies
4766
        # Verify we get the correct number of results
4767

4768
        try:
4769
            set_tunableop_defaults()
4770
            torch.cuda.tunable.enable()
4771
            # set these to single iterations to keep it short but still exercise the code
4772
            torch.cuda.tunable.set_max_tuning_iterations(1)
4773

4774
            # Reference number of results
4775
            ref_num_results = len(torch.cuda.tunable.get_results())
4776

4777
            # Execute matrix multiplies. We intentionally repeat the same value in the M list
4778
            # twice. The CSV file should only record unique GEMMs.
4779
            count_matmul = 4
4780
            K = 64
4781
            for M in [32, 64, 32]:
4782
                for N in [32, 64]:
4783
                    A = torch.randn(N, K, device=device, dtype=dtype)
4784
                    B = torch.randn(K, M, device=device, dtype=dtype)
4785
                    C = torch.matmul(A, B)
4786

4787
            # This stores the total number of cumulative results
4788
            total_num_results = len(torch.cuda.tunable.get_results())
4789

4790
            # Take the difference to calculate the number of results from
4791
            # this test and verify that it agrees with the number of
4792
            # GEMMs.
4793
            self.assertEqual((total_num_results - ref_num_results), count_matmul)
4794

4795
        finally:
4796
            # disable TunableOp
4797
            torch.cuda.tunable.enable(False)
4798

4799
            # clean up, remove any file that was generated
4800
            try:
4801
                import os
4802
                filename = torch.cuda.tunable.get_filename()
4803
                os.remove(filename)
4804
            except FileNotFoundError:
4805
                pass
4806

4807
    @dtypes(torch.float, torch.complex64)
4808
    def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
4809
        a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
4810
        b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
4811
        c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
4812

4813
        torch.matmul(a.detach(), b.detach(), out=c)
4814

4815
        with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
4816
            torch.matmul(a, b, out=c)
4817

4818
        with torch.no_grad():
4819
            torch.matmul(a, b, out=c)
4820

4821
    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4822
    @largeTensorTest('16GB', device='cuda')
4823
    def test_large_bmm_mm_backward(self, device):
4824
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4825
        B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
4826
        G = torch.randn([1024, 2, 65536], device="cuda")
4827

4828
        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4829
        (A @ B).backward(G)
4830

4831
    # 4GB should do, but we run tests in parallel in CI, so let's be generous
4832
    @largeTensorTest('16GB', device='cuda')
4833
    def test_large_bmm_backward(self, device):
4834
        A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4835
        B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
4836
        G = torch.randn([1024, 2, 65536], device="cuda")
4837

4838
        # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
4839
        (A @ B).backward(G)
4840

4841
    def test_linear_algebra_scalar_raises(self, device) -> None:
4842
        m = torch.randn(5, 5, device=device)
4843
        v = torch.randn(5, device=device)
4844
        s = torch.tensor(7, device=device)
4845
        self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
4846
        self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
4847

4848
    @dtypes(torch.float32, torch.complex64)
4849
    def test_cross(self, device, dtype):
4850
        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4851
        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4852
        res1 = torch.cross(x, y)
4853
        res2 = torch.tensor((), dtype=dtype, device=device)
4854
        torch.cross(x, y, out=res2)
4855
        self.assertEqual(res1, res2)
4856

4857
    @dtypes(torch.float32, torch.complex64)
4858
    def test_linalg_cross(self, device, dtype):
4859
        x = torch.rand(100, 3, 100, dtype=dtype, device=device)
4860
        y = torch.rand(100, 3, 100, dtype=dtype, device=device)
4861
        res1 = torch.linalg.cross(x, y, dim=1)
4862
        res2 = torch.tensor((), dtype=dtype, device=device)
4863
        torch.linalg.cross(x, y, dim=1, out=res2)
4864
        self.assertEqual(res1, res2)
4865

4866
        # test for broadcastable inputs
4867
        x = torch.rand(1, 3, 2, dtype=dtype, device=device)
4868
        y = torch.rand(4, 3, 1, dtype=dtype, device=device)
4869
        res1 = torch.linalg.cross(x, y, dim=1)
4870
        res2 = torch.tensor((), dtype=dtype, device=device)
4871
        torch.linalg.cross(x, y, dim=1, out=res2)
4872
        self.assertEqual(res1, res2)
4873

4874
    @dtypes(torch.float32, torch.complex64)
4875
    def test_cross_with_and_without_dim(self, device, dtype):
4876
        x = torch.rand(100, 3, dtype=dtype, device=device)
4877
        y = torch.rand(100, 3, dtype=dtype, device=device)
4878
        res1 = torch.cross(x, y, dim=1)
4879
        res2 = torch.cross(x, y, dim=-1)
4880
        res3 = torch.cross(x, y)
4881
        self.assertEqual(res1, res2)
4882
        self.assertEqual(res1, res3)
4883

4884
    @dtypes(torch.float32, torch.complex64)
4885
    def test_linalg_cross_with_and_without_dim(self, device, dtype):
4886
        x = torch.rand(100, 3, dtype=dtype, device=device)
4887
        y = torch.rand(100, 3, dtype=dtype, device=device)
4888
        res1 = torch.linalg.cross(x, y, dim=1)
4889
        res2 = torch.linalg.cross(x, y, dim=-1)
4890
        res3 = torch.linalg.cross(x, y)
4891
        self.assertEqual(res1, res2)
4892
        self.assertEqual(res1, res3)
4893

4894
    def test_renorm(self, device):
4895
        m1 = torch.randn(20, 20, device=device)  # big enough to exercise vectorized path
4896
        res1 = torch.tensor((), device=device)
4897

4898
        def renorm(matrix, value, dim, max_norm):
4899
            m1 = matrix.transpose(dim, 0).contiguous()
4900
            # collapse non-dim dimensions.
4901
            m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
4902
            norms = m2.norm(value, 1, True)
4903
            # clip
4904
            new_norms = norms.clone()
4905
            new_norms[torch.gt(norms, max_norm)] = max_norm
4906
            new_norms.div_(norms.add_(1e-7))
4907
            # renormalize
4908
            m1.mul_(new_norms.expand_as(m1))
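            # net effect: each slice x_i along `dim` is scaled by roughly
            # min(1, max_norm / ||x_i||_p), which is what renorm_ below should match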
4909
            return m1.transpose(dim, 0)
4910

4911
        # note that the axis fed to torch.renorm is different (2~=1)
4912
        maxnorm = m1.norm(2, 1).mean()
4913
        m2 = renorm(m1, 2, 1, maxnorm)
4914
        m1.renorm_(2, 1, maxnorm)
4915
        self.assertEqual(m1, m2, atol=1e-5, rtol=0)
4916
        self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0)
4917

4918
        m1 = torch.randn(3, 4, 5, device=device)
4919
        m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
4920
        maxnorm = m2.norm(2, 0).mean()
4921
        m2 = renorm(m2, 2, 1, maxnorm)
4922
        m1.renorm_(2, 1, maxnorm)
4923
        m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
4924
        self.assertEqual(m3, m2)
4925
        self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
4926

4927
    @skipCPUIfNoLapack
4928
    @skipCUDAIfNoCusolver
4929
    @dtypes(*floating_and_complex_types())
4930
    def test_ormqr(self, device, dtype):
4931

4932
        def run_test(batch, m, n, fortran_contiguous):
4933
            A = make_tensor((*batch, m, n), dtype=dtype, device=device)
4934
            reflectors, tau = torch.geqrf(A)
4935
            if not fortran_contiguous:
4936
                self.assertTrue(reflectors.mT.is_contiguous())
4937
                reflectors = reflectors.contiguous()
4938

4939
            # Q is of size m x m
4940
            Q, _ = torch.linalg.qr(A, mode='complete')
4941
            C_right = make_tensor((*batch, m, n), dtype=dtype, device=device)
4942
            C_left = make_tensor((*batch, n, m), dtype=dtype, device=device)
4943

4944
            expected = Q @ C_right
4945
            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False)
4946
            self.assertEqual(expected, actual)
4947

4948
            expected = C_left @ Q
4949
            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=False)
4950
            self.assertEqual(expected, actual)
4951

4952
            expected = Q.mH @ C_right
4953
            actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True)
4954
            self.assertEqual(expected, actual)
4955

4956
            expected = C_left @ Q.mH
4957
            actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True)
4958
            self.assertEqual(expected, actual)
4959

4960
            # if tau is all zeros then the implicit matrix Q is the identity matrix
4961
            # so the actual result should be C_right in this case
4962
            zero_tau = torch.zeros_like(tau)
4963
            actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False)
4964
            self.assertEqual(C_right, actual)
4965

4966
        batches = [(), (0, ), (2, ), (2, 1)]
4967
        ns = [5, 2, 0]
4968
        for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]):
4969
            run_test(batch, m, n, fortran_contiguous)
4970

4971
    @skipCPUIfNoLapack
4972
    @skipCUDAIfNoCusolver
4973
    @dtypes(*floating_and_complex_types())
4974
    def test_ormqr_errors_and_warnings(self, device, dtype):
4975
        test_cases = [
4976
            # input1 size, input2 size, input3 size, error regex
4977
            ((10,), (2,), (2,), r"input must have at least 2 dimensions"),
4978
            ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"),
4979
            ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"),
4980
            ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"),
4981
            ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"),
4982
            ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"),
4983
        ]
4984
        for a_size, tau_size, c_size, error_regex in test_cases:
4985
            a = make_tensor(a_size, dtype=dtype, device=device)
4986
            tau = make_tensor(tau_size, dtype=dtype, device=device)
4987
            c = make_tensor(c_size, dtype=dtype, device=device)
4988
            with self.assertRaisesRegex(RuntimeError, error_regex):
4989
                torch.ormqr(a, tau, c)
4990

4991
    def test_blas_empty(self, device):
4992
        def fn(torchfn, *args, test_out=False, **kwargs):
4993
            def call_torch_fn(*args, **kwargs):
4994
                return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
4995
                                      for shape in args), **kwargs)
4996
            result = call_torch_fn(*args, **kwargs)
4997
            if not test_out:
4998
                return result
4999
            else:
5000
                out = torch.full_like(result, math.nan)
5001
                out1 = call_torch_fn(*args, **kwargs, out=out)
5002
                return out
5003

5004
        # mm, addmm
5005
        self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
5006
        self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
5007
        self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
5008
        self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
5009
        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))
5010
        self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True))
5011

5012
        self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
5013
        self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape)
5014
        t = torch.randn((5, 6), device=device)
5015
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
5016
        self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
5017

5018
        # mv, addmv
5019
        self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
5020
        self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
5021
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
5022
        self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True))
5023

5024
        self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
5025
        t = torch.randn((3,), device=device)
5026
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
5027
        self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
5028

5029
        # bmm, baddbmm
5030
        self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
5031
        self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
5032
        self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
5033
        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)))
5034
        self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True))
5035

5036
        self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape)
5037
        self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape)
5038
        self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape)
5039
        self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape)
5040
        c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
5041
        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2))  # Issue #33467
5042
        self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True))  # Issue #33467
5043

5044
        # addbmm
5045
        self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
5046
        self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
5047
        t = torch.randn((5, 6), device=device)
5048
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
5049
        self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
5050

5051
        # matmul
5052
        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,)))
5053
        self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True))
5054
        self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
5055
        self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
5056
        self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
5057
        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)))
5058
        self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True))
5059

5060
        # dot
5061
        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
5062
        self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))
5063

5064
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
5065
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5066
    @dtypesIfCUDA(*floating_and_complex_types_and(
5067
                  torch.half,
5068
                  *[torch.bfloat16] if SM53OrLater else []
5069
                  ))
5070
    @dtypes(*all_types_and_complex_and(torch.bfloat16))
5071
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
5072
        # common case
5073
        M = torch.randn(128, device=device).to(dtype)
5074
        m1 = torch.randn(2048, 2400, device=device).to(dtype)
5075
        m2 = torch.randn(128, 2400, device=device).to(dtype)
5076
        torch.nn.functional.linear(m1, m2, M)
5077
        # Ntrans_B has ld >> rows
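        # ("ld" is the BLAS leading dimension, i.e. the stride between consecutive
        #  columns/rows of the underlying buffer; slicing a much larger tensor makes
        #  ld far bigger than the number of rows actually used)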
5078
        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
5079
        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
5080
        M = torch.rand([128]).to(dtype).to(device)
5081
        torch.addmm(M, m2.t(), m1)
5082
        # trans_A has ld >> rows
5083
        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
5084
        m2 = torch.randn(2048, 2400, device=device).to(dtype)
5085
        M = torch.rand([128]).to(dtype).to(device)
5086
        torch.addmm(M, m2, m1)
5087
        # large tensor dim > 65535
5088
        M = torch.randn(16, device=device).to(dtype)
5089
        m1 = torch.randn(32, 131071 , device=device).to(dtype)
5090
        m2 = torch.randn(16, 131071, device=device).to(dtype)
5091
        torch.nn.functional.linear(m1, m2, M)
5092

5093
    @onlyCUDA
5094
    @skipCUDAIfNotRocm
5095
    @dtypes(*floating_types_and(torch.bfloat16, torch.half))
5096
    def test_hipblaslt_corner_cases_rocm(self, device, dtype):
5097
        if dtype == torch.double:
5098
            raise unittest.SkipTest("hipblasLt doesn't support doubles yet")
5099

5100
        # enable hipblaslt path via env variable.
5101
        import os
5102
        DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT"
5103
        prev_val = os.getenv(DISABLE_ADDMM_HIP_LT)
5104
        try:
5105
            os.environ[DISABLE_ADDMM_HIP_LT] = "0"
5106
            # common case
5107
            M = torch.randn(128, device=device, dtype=dtype)
5108
            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5109
            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5110
            out1 = torch.nn.functional.linear(m1, m2, M)
5111
            M_cpu = M.to('cpu')
5112
            m1_cpu = m1.to('cpu')
5113
            m2_cpu = m2.to('cpu')
5114
            out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu)
5115
            self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2))
5116

5117
            # common case without bias
5118
            m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5119
            m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5120
            out2 = torch.nn.functional.linear(m1, m2, bias=None)
5121
            m1_cpu = m1.to('cpu')
5122
            m2_cpu = m2.to('cpu')
5123
            out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None)
5124
            self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2))
5125
        finally:
5126
            if prev_val is None:
5127
                del os.environ[DISABLE_ADDMM_HIP_LT]
5128
            else:
5129
                os.environ[DISABLE_ADDMM_HIP_LT] = prev_val
5130

5131
    @dtypesIfCUDA(*floating_and_complex_types_and(
5132
                  torch.half,
5133
                  *[torch.bfloat16] if SM53OrLater else []
5134
                  ))
5135
    @dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half))
5136
    def test_blas_alpha_beta_empty(self, device, dtype):
5137
        # This test is disabled on CUDA 9 due to:
5138
        # See: https://github.com/pytorch/pytorch/issues/31006
5139
        if dtype is torch.bfloat16 and self.device_type == 'xla':
5140
            # TODO (@zasdfgbnm): this causes the following error on test
5141
            # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16:
5142
            #
5143
            #   RuntimeError: _th_equal not supported on CPUType for BFloat16
5144
            return
5145
        # ensure beta is respected
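        # (with an empty mat/vec the matrix product contributes nothing, so
        #  addmv/addmm should reduce to exactly beta * input)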
5146
        value = 11
5147
        input = torch.full((2,), value, dtype=dtype, device=device)
5148
        mat = torch.ones((2, 0), dtype=dtype, device=device)
5149
        vec = torch.ones((0,), dtype=dtype, device=device)
5150
        out = torch.empty((2,), dtype=dtype, device=device)
5151
        if dtype.is_complex:
5152
            alpha = 6 + 7j
5153
            beta = 3 + 4j
5154
        else:
5155
            alpha = 6
5156
            beta = 3
5157
        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5158
                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta))
5159
        self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
5160
                         torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out))
5161

5162
        # torch.addmm
5163
        input = torch.full((2, 3), value, dtype=dtype, device=device)
5164
        mat2 = torch.ones((0, 3), dtype=dtype, device=device)
5165
        out = torch.empty((2, 3), dtype=dtype, device=device)
5166
        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5167
                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta))
5168
        self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
5169
                         torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
5170

5171
    @dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
5172
    def test_blas_nan_out(self, device, dtype):
5173
        # These functions should work correctly with NaN filled outputs,
5174
        # but need special handling, see [NOTE: cpu_zero]
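        # (the subtlety is that a beta == 0 path must overwrite rather than scale the
        #  preallocated output, since 0 * NaN would otherwise leave NaNs behind)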
5175
        b = 3
5176
        n = 5
5177
        m = 7
5178
        p = 11
5179

5180
        # torch.mv
5181
        nm = torch.randn((m, n), device=device).t()
5182
        _m = torch.randn((), device=device).expand(m)
5183
        _m_out = torch.full((m,), float('nan'), device=device)
5184
        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5185
        self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum())
5186

5187
        # torch.mm
5188
        mp = torch.randn((p, m), device=device).t()
5189
        np_out = torch.full((n, p), float('nan'), device=device)
5190
        self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))
5191

5192
        # torch.bmm
5193
        bnm = torch.randn((b, m, n), device=device).transpose(1, 2)
5194
        bmp = torch.randn((b, p, m), device=device).transpose(1, 2)
5195
        bnp_out = torch.full((b, n, p), float('nan'), device=device)
5196
        self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out))
5197

5198
    @onlyCPU  # not supported by CUBLAS
5199
    def test_blas_mv_large_input(self, device):
5200
        # This would previously fail if the allocated output had NaNs, see:
5201
        # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero]
5202
        n = 3000
5203
        m = 200
5204

5205
        nm = torch.randn((m, n), device=device).t()
5206
        _m = torch.randn((), device=device).expand(m)
5207
        _m_out = torch.full((m,), 0., device=device)
5208

5209
        self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
5210

5211
    @onlyCPU
5212
    def test_renorm_ps(self, device):
5213
        # full reduction
5214
        x = torch.randn(5, 5)
5215
        xn = x.numpy()
5216
        for p in [1, 2, 3, 4, inf]:
5217
            res = x.renorm(p, 1, 1)
5218
            expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
5219
            self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm")
5220

5221
    @skipCPUIfNoLapack
5222
    @skipCUDAIfNoCusolver
5223
    @dtypes(*floating_and_complex_types())
5224
    def test_householder_product(self, device, dtype):
5225
        def generate_reflectors_and_tau(A):
5226
            """
5227
            This function uses numpy.linalg.qr with mode "raw" to extract output of LAPACK's geqrf.
5228
            There is a torch.geqrf function, but it doesn't work with complex-valued input.
5229
            """
5230
            if A.numel() > 0:
5231
                A_cpu = A.cpu()
5232
                flattened_batch_shape = [-1, *A_cpu.shape[-2:]]
5233
                reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape)
5234
                tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]]
5235
                tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1])
5236
                for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau):
5237
                    reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw'))
5238
                    reflectors_i[:] = reflectors_tmp.T
5239
                reflectors = reflectors.view(*A_cpu.shape)
5240
                tau = tau.view(tau_shape)
5241
                return reflectors.to(A.device), tau.to(A.device)
5242

5243
            reflectors = torch.empty_like(A)
5244
            tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device)
5245
            return reflectors, tau
5246

5247
        def run_test(shape):
5248
            A = torch.randn(*shape, dtype=dtype, device=device)
5249
            reflectors, tau = generate_reflectors_and_tau(A)
5250
            expected, _ = torch.linalg.qr(A)
5251
            actual = torch.linalg.householder_product(reflectors, tau)
5252
            # torch.linalg.qr does not work correctly for zero batch dimension tensors
5253
            # see https://github.com/pytorch/pytorch/issues/50576
5254
            if (A.numel() > 0):
5255
                self.assertEqual(expected, actual)
5256
            else:
5257
                self.assertTrue(actual.shape == shape)
5258

5259
            # if tau is empty and A is not, the result should be a matrix with ones on the diagonal
5260
            if (A.numel() > 0):
5261
                tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device)
5262
                identity_mat = torch.zeros_like(reflectors)
5263
                identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1
5264
                actual = torch.linalg.householder_product(reflectors, tau_empty)
5265
                self.assertEqual(actual, identity_mat)
5266

5267
            out = torch.empty_like(A)
5268
            ans = torch.linalg.householder_product(reflectors, tau, out=out)
5269
            self.assertEqual(ans, out)
5270
            if (A.numel() > 0):
5271
                self.assertEqual(expected, out)
5272

5273
        shapes = [(0, 0), (5, 0),  # Empty matrix
5274
                  (5, 5), (5, 3),  # Single matrix
5275
                  (0, 0, 0), (0, 5, 5), (0, 5, 3),  # Zero batch dimension tensors
5276
                  (2, 5, 5), (2, 5, 3),  # 3-dim tensors
5277
                  (2, 1, 5, 5), (2, 1, 5, 3)]  # 4-dim tensors
5278
        for shape in shapes:
5279
            run_test(shape)
5280

5281
    @skipCPUIfNoLapack
5282
    @skipCUDAIfNoCusolver
5283
    def test_householder_product_errors_and_warnings(self, device):
5284
        test_cases = [
5285
            # input1 size, input2 size, error regex
5286
            ((10,), (2,), r"input must have at least 2 dimensions"),
5287
            ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"),
5288
            ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"),
5289
        ]
5290
        for a_size, tau_size, error_regex in test_cases:
5291
            a = torch.rand(*a_size, device=device)
5292
            tau = torch.rand(*tau_size, device=device)
5293
            with self.assertRaisesRegex(RuntimeError, error_regex):
5294
                torch.linalg.householder_product(a, tau)
5295

5296
        # if an out tensor with the wrong shape is passed, a warning is given
5297
        reflectors = torch.randn(3, 3, device=device)
5298
        tau = torch.randn(3, device=device)
5299
        out = torch.empty(2, 3, device=device)
5300
        with warnings.catch_warnings(record=True) as w:
5301
            # Trigger warning
5302
            torch.linalg.householder_product(reflectors, tau, out=out)
5303
            # Check warning occurs
5304
            self.assertEqual(len(w), 1)
5305
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
5306

5307
        # dtypes should be safely castable
5308
        out = torch.empty_like(reflectors).to(torch.int)
5309
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
5310
            torch.linalg.householder_product(reflectors, tau, out=out)
5311

5312
        with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"):
5313
            torch.linalg.householder_product(reflectors, tau.to(torch.int))
5314

5315
        if torch.cuda.is_available():
5316
            # device of out and input should match
5317
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5318
            out = torch.empty_like(reflectors).to(wrong_device)
5319
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5320
                torch.linalg.householder_product(reflectors, tau, out=out)
5321

5322
            # device of tau and input should match
5323
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
5324
            tau = tau.to(wrong_device)
5325
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
5326
                torch.linalg.householder_product(reflectors, tau)
5327

5328
    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5329
    @skipCUDAIfNoMagmaAndNoCusolver
5330
    @skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor")
5331
    @skipCPUIfNoLapack
5332
    @dtypes(*floating_and_complex_types())
5333
    def test_linalg_lu_family(self, device, dtype):
5334
        # Tests torch.lu
5335
        #       torch.linalg.lu_factor
5336
        #       torch.linalg.lu_factor_ex
5337
        #       torch.lu_unpack
5338
        #       torch.linalg.lu_solve
5339
        #       torch.linalg.solve
5340
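        # The common invariant checked below: lu_unpack(lu_factor(A)) yields P, L, U
        # with P @ L @ U == A (or L @ U == A when pivoting is disabled).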
        make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
5341
        make_arg = partial(make_tensor, device=device, dtype=dtype)
5342

5343
        def run_test(A, pivot, singular, fn):
5344
            k = min(A.shape[-2:])
5345
            batch = A.shape[:-2]
5346
            check_errors = (fn == torch.linalg.lu_factor)
5347
            if singular and check_errors:
5348
                # It may or may not throw as the LU decomposition without pivoting
5349
                # may still succeed for singular matrices
5350
                try:
5351
                    LU, pivots = fn(A, pivot=pivot)
5352
                except RuntimeError:
5353
                    return
5354
            else:
5355
                LU, pivots = fn(A, pivot=pivot)[:2]
5356

5357
            self.assertEqual(LU.size(), A.shape)
5358
            self.assertEqual(pivots.size(), batch + (k,))
5359

5360
            if not pivot:
5361
                self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
5362

5363
            P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot)
5364

5365
            self.assertEqual(P @ L @ U if pivot else L @ U, A)
5366

5367
            PLU = torch.linalg.lu(A, pivot=pivot)
5368
            self.assertEqual(P, PLU.P)
5369
            self.assertEqual(L, PLU.L)
5370
            self.assertEqual(U, PLU.U)
5371

5372
            if not singular and A.size(-2) == A.size(-1):
5373
                nrhs = ((), (1,), (3,))
5374
                for left, rhs in product((True, False), nrhs):
5375
                    # Vector case when left = False is not allowed
5376
                    if not left and rhs == ():
5377
                        continue
5378
                    if left:
5379
                        shape_B = A.shape[:-1] + rhs
5380
                    else:
5381
                        shape_B = A.shape[:-2] + rhs + A.shape[-1:]
5382
                    B = make_arg(shape_B)
5383

5384
                    # Test linalg.lu_solve. It does not support vectors as rhs
5385
                    # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913
5386
                    if rhs != ():
5387
                        for adjoint in (True, False):
5388
                            X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
5389
                            A_adj = A.mH if adjoint else A
5390
                            if left:
5391
                                self.assertEqual(B, A_adj @ X)
5392
                            else:
5393
                                self.assertEqual(B, X @ A_adj)
5394

5395
                    # Test linalg.solve
5396
                    X = torch.linalg.solve(A, B, left=left)
5397
                    X_ = X.unsqueeze(-1) if rhs == () else X
5398
                    B_ = B.unsqueeze(-1) if rhs == () else B
5399
                    if left:
5400
                        self.assertEqual(B_, A @ X_)
5401
                    else:
5402
                        self.assertEqual(B_, X_ @ A)
5403

5404

5405
        sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
5406
        batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
5407
        # The non-pivoting variant is only implemented on CUDA
5408
        pivots = (True, False) if self.device_type == "cuda" else (True,)
5409
        fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
5410
        for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
5411
            shape = batch + ms
5412
            A = make_arg(shape) if singular else make_arg_full(*shape)
5413
            # For empty matrices, just run the singular variant
5414
            if A.numel() == 0 and not singular:
5415
                continue
5416
            run_test(A, pivot, singular, fn)
5417

5418
            # Reproducer of a magma bug,
5419
            # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on
5420
            # This is also a bug in cuSOLVER < 11.3
5421
            if (dtype == torch.double
5422
               and singular):
5423
                A = torch.ones(batch + ms, dtype=dtype, device=device)
5424
                run_test(A, pivot, singular, fn)
5425

5426
        # Info should be positive for rank deficient matrices
5427
        A = torch.ones(5, 3, 3, device=device)
5428
        self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())
5429

5430
        if self.device_type == 'cpu':
5431
            # Error checking, no pivoting variant on CPU
5432
            fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu]
5433
            for f in fns:
5434
                with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
5435
                    f(torch.empty(1, 2, 2), pivot=False)
5436

5437

5438
    @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
5439
    @skipCUDAIfNoMagmaAndNoCusolver
5440
    @skipCPUIfNoLapack
5441
    @setLinalgBackendsToDefaultFinally
5442
    @dtypes(*floating_and_complex_types())
5443
    def test_linalg_lu_solve(self, device, dtype):
5444
        make_arg = partial(make_tensor, dtype=dtype, device=device)
5445

5446
        backends = ["default"]
5447

5448
        if torch.device(device).type == 'cuda':
5449
            if torch.cuda.has_magma:
5450
                backends.append("magma")
5451
            if has_cusolver():
5452
                backends.append("cusolver")
5453

5454
        def gen_matrices():
5455
            rhs = 3
5456
            ns = (5, 2, 0)
5457
            batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
5458
            for batch, n in product(batches, ns):
5459
                yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs))
5460
            # Shapes to exercise all the paths
5461
            shapes = ((1, 64), (2, 128), (1025, 2))
5462
            for b, n in shapes:
5463
                yield make_arg((b, n, n)), make_arg((b, n, rhs))
5464

5465

5466
        for A, B in gen_matrices():
5467
            LU, pivots = torch.linalg.lu_factor(A)
5468
            for backend in backends:
5469
                torch.backends.cuda.preferred_linalg_library(backend)
5470

5471
                for left, adjoint in product((True, False), repeat=2):
5472
                    B_left = B if left else B.mT
5473
                    X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, adjoint=adjoint)
5474
                    A_adj = A.mH if adjoint else A
5475
                    if left:
5476
                        self.assertEqual(B_left, A_adj @ X)
5477
                    else:
5478
                        self.assertEqual(B_left, X @ A_adj)
5479

5480

5481
    @onlyCPU
5482
    @dtypes(*floating_and_complex_types())
5483
    def test_linalg_lu_cpu_errors(self, device, dtype):
5484
        # Square tests
5485
        sample = torch.randn(3, 2, 2, device=device, dtype=dtype)
5486
        B = torch.randn(3, 2, 2, device=device, dtype=dtype)
5487
        LU, pivots = torch.linalg.lu_factor(sample)
5488

5489
        # This should run without issues
5490
        torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5491
        torch.lu_unpack(LU, pivots)
5492

5493
        pivots[0] = 0
5494
        with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"):
5495
            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5496
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5497
            torch.lu_unpack(LU, pivots)
5498

5499
        pivots[0] = 3
5500
        with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"):
5501
            torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
5502
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5503
            torch.lu_unpack(LU, pivots)
5504

5505
        # Rectangular tests
5506
        sample = torch.randn(3, 4, 2, device=device, dtype=dtype)
5507
        B = torch.randn(3, 4, 2, device=device, dtype=dtype)
5508
        LU, pivots = torch.linalg.lu_factor(sample)
5509

5510
        # This should run without issues
5511
        torch.lu_unpack(LU, pivots)
5512

5513
        pivots[0] = 0
5514
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5515
            torch.lu_unpack(LU, pivots)
5516

5517
        pivots[0] = 5
5518
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5519
            torch.lu_unpack(LU, pivots)
5520

5521

5522
        # Rectangular tests
5523
        sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
5524
        B = torch.randn(2, 3, 5, device=device, dtype=dtype)
5525
        LU, pivots = torch.linalg.lu_factor(sample)
5526

5527
        # This should run without issues
5528
        torch.lu_unpack(LU, pivots)
5529

5530
        pivots[0] = 0
5531
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5532
            torch.lu_unpack(LU, pivots)
5533

5534
        pivots[0] = 4
5535
        with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
5536
            torch.lu_unpack(LU, pivots)
5537

5538

5539
    @skipCPUIfNoLapack
5540
    @skipCUDAIfNoMagma
5541
    @dtypes(torch.double)
5542
    def test_lu_unpack_check_input(self, device, dtype):
5543
        x = torch.rand(5, 5, 5, device=device, dtype=dtype)
5544
        lu_data, lu_pivots = torch.linalg.lu_factor(x)
5545

5546
        with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
5547
            torch.lu_unpack(lu_data, lu_pivots.long())
5548

5549
        # check that when the unpack flags are unset, empty tensors are returned
5550
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
5551
        self.assertTrue(l.numel() == 0 and u.numel() == 0)
5552
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
5553
        self.assertTrue(p.numel() == 0)
5554
        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
5555
        self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0)
5556

5557
    @skipCUDAIfNoMagma
5558
    @skipCPUIfNoLapack
5559
    @dtypes(torch.double)
5560
    def test_lobpcg_basic(self, device, dtype):
5561
        self._test_lobpcg_method(device, dtype, 'basic')
5562

5563
    @skipCUDAIfNoCusolver
5564
    @skipCPUIfNoLapack
5565
    @dtypes(torch.double)
5566
    def test_lobpcg_ortho(self, device, dtype):
5567
        if torch.version.hip:
5568
            torch.backends.cuda.preferred_linalg_library('magma')
5569
        self._test_lobpcg_method(device, dtype, 'ortho')
5570
        if torch.version.hip:
5571
            torch.backends.cuda.preferred_linalg_library('default')
5572

5573
    def _test_lobpcg_method(self, device, dtype, method):
5574
        from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix
5575
        from torch._linalg_utils import matmul, qform
5576
        from torch._lobpcg import lobpcg
5577

5578
        def test_tracker(worker):
5579
            k = worker.iparams['k']
5580
            nc = worker.ivars['converged_count']
5581
            if k <= nc:
5582
                tol = worker.fparams['tol']
5583
                rerr = worker.tvars['rerr']
5584
                X = worker.X
5585
                E = worker.E
5586
                B = worker.B
5587
                A = worker.A
5588
                dtype = X.dtype
5589
                device = X.device
5590

5591
                # Check convergence
5592
                self.assertLessEqual(rerr[:k].max(), tol)
5593

5594
                # Check B-orthogonality
5595
                I = torch.eye(k, k, dtype=dtype, device=device)
5596
                self.assertEqual(qform(B, X[:, :k]), I)
5597

5598
                # Check block equation
5599
                self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0)
5600

5601
        orig_lobpcg = lobpcg
5602

5603
        def lobpcg(*args, **kwargs):
5604
            kwargs['tracker'] = test_tracker
5605
            kwargs['niter'] = 1000
5606
            kwargs['method'] = method
5607
            kwargs['tol'] = 1e-8
5608
            return orig_lobpcg(*args, **kwargs)
5609
        prec = 5e-4
5610

5611
        # check dense input
5612
        mm = torch.matmul
5613
        for batches in [(), (2,), (2, 3)]:
5614
            for m, n, k in [
5615
                    (9, 3, 1),
5616
                    (9, 3, 2),
5617
                    (9, 2, 2),
5618
                    (100, 15, 5),
5619
            ]:
5620
                # skip tests that are known to fail with the basic
5621
                # LOBPCG method due to calling cholesky on singular
5622
                # input
5623
                if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]:
5624
                    continue
5625
                A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5626
                B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
5627

5628
                # classical eigenvalue problem, smallest eigenvalues
5629
                E, V = lobpcg(A, k=k, n=n, largest=False)
5630
                self.assertEqual(E.shape, batches + (k,))
5631
                self.assertEqual(V.shape, batches + (m, k))
5632
                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5633
                e = torch.linalg.eigvalsh(A)
5634
                e_smallest = e[..., :k]
5635
                self.assertEqual(E, e_smallest)
5636

5637
                # classical eigenvalue problem, largest eigenvalues
5638
                E, V = lobpcg(A, k=k, n=n, largest=True)
5639
                e_largest, _ = torch.sort(e[..., -k:], descending=True)
5640
                self.assertEqual(E, e_largest, atol=prec, rtol=0)
5641
                self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5642

5643
                # generalized eigenvalue problem, smallest eigenvalues
5644
                E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5645
                self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0)
5646

5647
                # generalized eigenvalue problem, largest eigenvalues
5648
                E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5649
                self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5650
                                 atol=prec, rtol=0)
5651

5652
        # check sparse input
5653
        for m, n, k, density in [
5654
                (5, 1, 1, 0.8),
5655
                (9, 3, 2, 0.5),
5656
                (100, 1, 1, 0.1),
5657
                (1000, 7, 3, 0.01),
5658
        ]:
5659
            # skip tests that are known to fail with the basic LOBPCG
5660
            # method due to insufficient accuracy
5661
            if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]:
5662
                continue
5663
            A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
5664
            B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
5665
            A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m
5666
            e_smallest = A_eigenvalues[..., :k]
5667
            e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True)
5668

5669
            # classical eigenvalue problem, smallest eigenvalues
5670
            E, V = lobpcg(A, k=k, n=n, largest=False)
5671
            self.assertEqual(E, e_smallest)
5672
            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5673

5674
            # classical eigenvalue problem, largest eigenvalues
5675
            E, V = lobpcg(A, k=k, n=n, largest=True)
5676
            self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
5677
            self.assertEqual(E, e_largest)
5678

5679
            # generalized eigenvalue problem, smallest eigenvalues
5680
            E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
5681
            self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0)
5682

5683
            # generalized eigenvalue problem, largest eigenvalues
5684
            E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
5685
            self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
5686
                             atol=prec, rtol=0)
5687

5688
    @skipCPUIfNoLapack
5689
    @onlyCPU
5690
    @dtypes(torch.double)
5691
    def test_lobpcg_torchscript(self, device, dtype):
5692
        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5693
        from torch._linalg_utils import matmul as mm
5694

5695
        lobpcg = torch.jit.script(torch.lobpcg)
5696

5697
        m = 500
5698
        k = 5
5699
        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5700
        X1 = torch.randn((m, k), dtype=dtype, device=device)
5701
        E1, V1 = lobpcg(A1, X=X1)
5702
        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5703
        self.assertLess(eq_err, 1e-6)
5704

5705
    @unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1")
5706
    @skipCPUIfNoLapack
5707
    @skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg")
5708
    @onlyCPU
5709
    @dtypes(torch.double)
5710
    def test_lobpcg_scipy(self, device, dtype):
5711
        """Compare torch and scipy.sparse.linalg implementations of lobpcg
5712
        """
5713
        import time
5714
        from torch.testing._internal.common_utils import random_sparse_pd_matrix
5715
        from torch._linalg_utils import matmul as mm
5716
        from scipy.sparse.linalg import lobpcg as scipy_lobpcg
5717
        import scipy.sparse
5718

5719
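        # Convert a torch tensor to its scipy counterpart: sparse COO tensors become
        # scipy.sparse.coo_matrix, dense tensors become numpy arrays.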
        def toscipy(A):
5720
            if A.layout == torch.sparse_coo:
5721
                values = A.coalesce().values().cpu().numpy().copy()
5722
                indices = A.coalesce().indices().cpu().numpy().copy()
5723
                return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape)
5724
            return A.cpu().numpy().copy()
5725

5726
        niter = 1000
5727
        repeat = 10
5728
        m = 500   # size of the square matrix
5729
        k = 7     # the number of requested eigenpairs
5730
        A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5731
        B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
5732
        X1 = torch.randn((m, k), dtype=dtype, device=device)
5733

5734
        A2 = toscipy(A1)
5735
        B2 = toscipy(B1)
5736
        X2 = toscipy(X1)
5737

5738
        lambdas1 = []
5739

5740
        def tracker(worker):
5741
            lambdas1.append(worker.E[:])
5742

5743
        tol = 1e-8
5744
        # tol for scipy lobpcg is chosen so that the number of
5745
        # iterations is equal or very close to that of pytorch lobpcg
5746
        # (that is around 170-180)
5747

5748
        # Standard eigenvalue problem
5749
        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5750
        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol)
5751
        iters1 = len(lambdas1)
5752
        iters2 = len(lambdas2)
5753
        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5754

5755
        E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False)
5756

5757
        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5758
        eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5759
        self.assertLess(eq_err, 1e-6)        # std
5760
        self.assertLess(eq_err_scipy, 1e-6)  # std
5761

5762
        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5763

5764
        # Generalized eigenvalue problem
5765
        lambdas1 = []
5766

5767
        def tracker(worker):
5768
            lambdas1.append(worker.E[:])
5769

5770
        E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5771
        E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol)
5772
        E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False)
5773
        iters1 = len(lambdas1)
5774
        iters2 = len(lambdas2)
5775
        self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
5776

5777
        eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5778
        eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5779
        self.assertLess(eq_err, 1e-6)        # general
5780
        self.assertLess(eq_err_scipy, 1e-6)  # general
5781

5782
        self.assertEqual(E1, torch.from_numpy(E2.copy()))
5783

5784
        # Timings
5785
        elapsed_ortho = 0
5786
        elapsed_ortho_general = 0
5787
        elapsed_scipy = 0
5788
        elapsed_general_scipy = 0
5789
        for i in range(repeat):
5790
            start = time.time()
5791
            torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol)
5792
            end = time.time()
5793
            elapsed_ortho += end - start
5794

5795
            start = time.time()
5796
            torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol)
5797
            end = time.time()
5798
            elapsed_ortho_general += end - start
5799

5800
            start = time.time()
5801
            scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol)
5802
            end = time.time()
5803
            elapsed_scipy += end - start
5804

5805
            start = time.time()
5806
            scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol)
5807
            end = time.time()
5808
            elapsed_general_scipy += end - start
5809

5810
        elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat
5811
        elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat
5812
        elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat
5813
        elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat
5814

5815
        print(f'''
5816
CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5817
-------------------------------------------------------
5818
              | standard    | generalized | method
5819
torch.lobpcg  | {elapsed_ortho_ms:10.2f}  | {elapsed_ortho_general_ms:10.2f}  | ortho
5820
scipy_lobpcg  | {elapsed_scipy_ms:10.2f}  | {elapsed_general_scipy_ms:10.2f}  | N/A
5821
-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)-
5822
        ''')
5823

5824
        # Handling of very small tolerance
5825
        tol = 1e-100
5826

5827
        lambdas1 = []
5828

5829
        def tracker(worker):
5830
            lambdas1.append(worker.E[:])
5831

5832
        E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
5833
        iters1 = len(lambdas1)
5834
        eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
5835

5836
        try:
5837
            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5838
            iters2 = len(lambdas2)
5839
            eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
5840
        except Exception as msg:
5841
            print('Calling scipy_lobpcg failed [standard]:', msg)
5842
            iters2 = -1
5843
            eq_err_scipy = -1
5844

5845
        lambdas1 = []
5846

5847
        def tracker(worker):
5848
            lambdas1.append(worker.E[:])
5849

5850
        E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol)
5851
        iters1_general = len(lambdas1)
5852
        eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
5853

5854
        try:
5855
            E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
5856
            iters2_general = len(lambdas2)
5857
            eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
5858
        except Exception as msg:
5859
            print('Calling scipy_lobpcg failed [generalized]:', msg)
5860
            iters2_general = -1
5861
            eq_err_general_scipy = -1
5862

5863
        print(f'''\
5864
Handling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg
5865
----------------------------------------------------------------------------
5866
              | standard    | generalized |  niter | method
5867
torch.lobpcg  | {eq_err:10.2e}  | {eq_err_general:10.2e}  | {iters1:6} | ortho
5868
scipy_lobpcg  | {eq_err_scipy:10.2e}  | {eq_err_general_scipy:10.2e}  | {iters2:6} | N/A
5869
---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})---
5870
''')
5871

5872
    def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None):
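        # Shared checker for addmm/addmv-style ops: runs f(t, m, v, alpha=alpha, beta=beta),
        # optionally with an activation, both directly and through its out= variant, and
        # compares against a NumPy reference computed in float32 when dtype is half or bfloat16.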
5873
        dtype = t.dtype
5874
        numpy_dtype = dtype
5875
        if dtype in {torch.bfloat16, torch.half}:
5876
            numpy_dtype = torch.float
5877
        if dtype.is_complex:
5878
            alpha = 0.9 + 0.3j if alpha is None else alpha
5879
            beta = 0.5 + 0.6j if beta is None else beta
5880
        else:
5881
            alpha = 1.2 if alpha is None else alpha
5882
            beta = 0.8 if beta is None else beta
5883
        if activation == "gelu":
5884
            res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
5885
        else:
5886
            res1 = f(t, m, v, alpha=alpha, beta=beta)
5887
        res2 = torch.full_like(res1, math.nan)
5888
        if transpose_out:
5889
            res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
5890
        if activation == "gelu":
5891
            f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
5892
        else:
5893
            f(t, m, v, alpha=alpha, beta=beta, out=res2)
5894
        res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
5895
        if beta != 0:
5896
            res3 += (beta * t).to(numpy_dtype).cpu().numpy()
5897
        if activation == "relu":
5898
            res3 = res3 * (res3 > 0)
5899
        elif activation == "gelu":
5900
            res3_t = torch.from_numpy(res3).to(dtype)
5901
            approximate = "tanh" if t.is_cuda else "none"
5902
            res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
5903
            res3 = res3_t.to(numpy_dtype).cpu().numpy()
5904
        else:
5905
            assert activation is None, f"unsupported activation {activation}"
5906
        res3 = torch.from_numpy(res3).to(dtype)
5907
        self.assertEqual(res1, res2)
5908
        self.assertEqual(res1, res3)
5909

5910
    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8,
5911
                        torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5912
    @dtypesIfCUDA(*floating_and_complex_types_and(
5913
                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [],
5914
                  torch.half))
5915
    @dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
5916
    def test_addmv(self, device, dtype):
5917
        if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
5918
            raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
5919
        # have to use torch.randn(...).to(bfloat16) instead of
5920
        # torch.randn(..., dtype=bfloat16). randn does not support
5921
        # bfloat16 yet.
5922
        # "*0.2" to reduce errors for low precision
5923
        ts = [
5924
            0.2 * torch.randn(50, device=device).to(dtype),
5925
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
5926
        ]
5927
        vs = [
5928
            0.2 * torch.randn(100, device=device).to(dtype),
5929
            0.2 * torch.ones(1, device=device).to(dtype).expand(100),  # to reduce errors for low precision
5930
        ]
5931
        ms = [
5932
            # 0d
5933
            0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),  # to reduce errors for low precision
5934
            # 1d
5935
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
5936
            # this initialization reduces errors for low precision for broadcasted matrices
5937
            # by making sure that intermediate and result values are exactly representable
5938
            # in low precision type
5939
            0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
5940
            # 2d
5941
            0.2 * torch.randn((50, 100), device=device).to(dtype),
5942
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
5943
        ]
5944
        for m, v, t in itertools.product(ms, vs, ts):
5945
            self._test_addmm_addmv(torch.addmv, t, m, v)
5946
        # Test beta=0, t=nan
5947
        t = torch.full((50,), math.nan, device=device).to(dtype)
5948
        for m, v in itertools.product(ms, vs):
5949
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
5950

5951
    @dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or
5952
                  SM53OrLater else []))
5953
    @dtypes(torch.float, torch.double)
5954
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
5955
        # tests (o, s)*(s).  o is output size, s is summed size.
5956
        o = 5
5957
        s = 3
5958
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
5959
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
5960
        y_data = torch.ones(o, device=device, dtype=dtype)
5961
        control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype)
5962

5963
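        # Lay the operands out in NaN-padded storage so that the BLAS call sees
        # non-trivial lda, incx and incy; reading outside the logical views would
        # pull NaNs into the result and fail the comparison.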
        def _test(row_major, incx, incy, lda_tail):
5964
            if row_major:
5965
                a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype)
5966
            else:
5967
                a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0)
5968
            a = a_storage[:o, :s].copy_(a_data)
5969

5970
            x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype)
5971
            x = x_storage[:, 0].copy_(x_data)
5972

5973
            y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype)
5974
            y = y_storage[:, 0].copy_(y_data)
5975

5976
            self._test_addmm_addmv(torch.addmv, y, a, x)
5977

5978
        for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)):
5979
            _test(row_major, incx, incy, lda_tail)
5980

5981
    def _test_addmm_impl(self, func, activation, device, dtype):
5982
        M = torch.randn(10, 25, device=device).to(dtype)
5983
        m1 = torch.randn(10, 50, device=device).to(dtype)
5984
        m2 = torch.randn(50, 25, device=device).to(dtype)
5985
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
5986

5987
        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
5988
        V = torch.randn(25, device=device).to(dtype)
5989
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
5990

5991
        # Test 0-strided
5992
        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
5993
        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
5994
        m2 = torch.randn(50, 25, device=device).to(dtype)
5995
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)
5996

5997
        # Test beta=0, M=nan
5998
        M = torch.full((10, 25), math.nan, device=device).to(dtype)
5999
        m1 = torch.randn(10, 50, device=device).to(dtype)
6000
        m2 = torch.randn(50, 25, device=device).to(dtype)
6001
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)
6002

6003
        # Test transpose
6004
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
6005
            def maybe_transpose(cond, m):
6006
                if not cond:
6007
                    return m
6008
                return m.t().clone(memory_format=torch.contiguous_format).t()
6009

6010
            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
6011
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
6012
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
6013
            self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
6014

6015
            if t1:
6016
                # also exercise the vector bias V (CUDA epilogue-fusion path); the result does not depend on t1, so run it only when t1 is True
6017
                self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
6018

6019
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
6020
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6021
    @dtypesIfMPS(torch.float32)
6022
    @dtypesIfCUDA(*floating_and_complex_types_and(
6023
                  *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
6024
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6025
    @tf32_on_and_off(0.05)
6026
    @bf32_on_and_off(0.05)
6027
    def test_addmm(self, device, dtype):
6028
        self._test_addmm_impl(torch.addmm, None, device, dtype)
6029

6030
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6031
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6032
    @dtypesIfCUDA(*floating_types_and(
6033
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6034
    @dtypes(*floating_types_and(torch.bfloat16))
6035
    @tf32_on_and_off(0.05)
6036
    @bf32_on_and_off(0.05)
6037
    def test_addmm_relu(self, device, dtype):
6038
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6039

6040
    @onlyCUDA
6041
    @skipCUDAIfNotRocm
6042
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6043
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6044
    @dtypesIfCUDA(*floating_types_and(
6045
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6046
    @dtypes(*floating_types_and(torch.bfloat16))
6047
    @tf32_on_and_off(0.05)
6048
    @bf32_on_and_off(0.05)
6049
    def test_addmm_relu_tunableop_rocm(self, device, dtype):
6050
        torch.cuda.tunable.enable(True)
6051
        ordinal = torch.cuda.current_device()
6052
        filename = f"tunableop_results{ordinal}.csv"
6053
        torch.cuda.tunable.set_filename(filename)
6054
        iterations = torch.cuda.tunable.get_max_tuning_iterations()
6055
        torch.cuda.tunable.set_max_tuning_iterations(10)
6056
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
6057
        # clean up, remove any file that was generated
6058
        try:
6059
            import os
6060
            os.remove(filename)
6061
        except FileNotFoundError:
6062
            pass
6063
        # reset back to prior settings
6064
        torch.cuda.tunable.set_max_tuning_iterations(iterations)
6065
        torch.cuda.tunable.enable(False)
6066

6067
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
6068
                        torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
6069
    @dtypesIfCUDA(*floating_types_and(
6070
                  *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
6071
    @dtypes(*floating_types_and(torch.bfloat16))
6072
    @tf32_on_and_off(0.05)
6073
    @bf32_on_and_off(0.05)
6074
    def test_addmm_gelu(self, device, dtype):
6075
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
6076

6077
    @dtypes(torch.float, torch.double)
6078
    @dtypesIfCUDA(*floating_and_complex_types())
6079
    @tf32_on_and_off(0.005)
6080
    @bf32_on_and_off(0.005)
6081
    def test_addmm_sizes(self, device, dtype):
6082
        for m in [0, 1, 25]:
6083
            for n in [0, 1, 10]:
6084
                for k in [0, 1, 8]:
6085
                    M = torch.randn(n, m, device=device).to(dtype)
6086
                    m1 = torch.randn(n, k, device=device).to(dtype)
6087
                    m2 = torch.randn(k, m, device=device).to(dtype)
6088
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)
6089

6090
                    m1 = torch.randn(n, k + 1, device=device).to(dtype)
6091
                    m2 = torch.randn(k, m, device=device).to(dtype)
6092
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
6093
                    self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))
6094

6095
    @dtypes(torch.half)
6096
    @onlyCUDA
6097
    def test_addmm_baddbmm_overflow(self, device, dtype):
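        # With fp16 reduced-precision reductions disabled, each output element is
        # alpha * (100 * 100 * 1000) = 10000, which fits in fp16, so the result must be
        # exactly 10000 (on ROCm only the absence of inf is checked).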
6098
        orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
6099
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
6100
        inp = torch.zeros(128, 128, dtype=torch.half, device=device)
6101
        mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
6102
        mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
6103
        out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6104
        # just check for no overflow on ROCM
6105
        if TEST_WITH_ROCM:
6106
            self.assertFalse(out.isinf().any())
6107
        else:
6108
            self.assertTrue((out == 10000.).all())
6109
        inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
6110
        mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
6111
        mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
6112
        out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
6113
        if TEST_WITH_ROCM:
6114
            self.assertFalse(out.isinf().any())
6115
        else:
6116
            self.assertTrue((out == 10000.).all())
6117
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
6118

6119
    @dtypes(torch.float)
6120
    def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
6121
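        # With beta=0.0 the input must be ignored entirely, so the result has to match
        # torch.bmm(mat1, mat2) even when the input (or the out= tensor) is filled with NaN.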
        for shape in [[3, 2, 2], [2, 20, 20]]:
6122
            mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2))
6123
            inputs = [torch.randn(shape, dtype=dtype, device=device),
6124
                      torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6125
            outs = [None, torch.randn(shape, dtype=dtype, device=device),
6126
                    torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
6127
            options = itertools.product(inputs, outs)
6128
            for input, out in options:
6129
                y_ref = torch.bmm(mat1, mat2)
6130
                y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
6131
                self.assertEqual(y_ref, y)
6132

6133
    @dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64)
6134
    def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
6135
        batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6136
        batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
6137
        input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
6138
        if dtype != torch.float32:
6139
            with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
6140
                y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
6141
        else:
6142
            out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
6143
            y_ref = torch.bmm(batch1, batch2)
6144
            y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
6145
            self.assertEqual(out, y_ref)
6146

6147

6148
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6149
    @onlyCUDA
6150
    def test_matmul_45724(self, device):
6151
        # https://github.com/pytorch/pytorch/issues/45724
6152
        a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
6153
        b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
6154
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
6155
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
6156
        torch.matmul(a, b, out=c)
6157
        self.assertEqual(c, cpu_result)
6158

6159
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6160
    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
6161
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6162
    @onlyCUDA
6163
    @parametrize("k", [16, 32])
6164
    @parametrize("n", [16, 32])
6165
    @parametrize("use_transpose_a", [True, False])
6166
    @parametrize("use_transpose_b", [True, False])
6167
    def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
6168
        def genf_int_float(x, y, use_transpose):
6169
            if use_transpose:
6170
                x, y = y, x
6171
            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6172
            x_float = x_int8.to(torch.float32)
6173
            if use_transpose:
6174
                return x_int8.t(), x_float.t()
6175
            return x_int8, x_float
6176

6177
        def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
6178
            a_int8, a_float = genf_int_float(m, k, transpose_a)
6179
            b_int8, b_float = genf_int_float(k, n, transpose_b)
6180
            c_int32 = torch._int_mm(a_int8, b_int8)
6181
            self.assertTrue(c_int32.dtype is torch.int32)
6182
            self.assertEqual(c_int32.device, torch.device(device))
6183
            if test_equal:
6184
                self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6185
            else:
6186
                self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
6187
            c_int32_result = c_int32.new_empty(c_int32.size())
6188
            # Check the out= variant
6189
            torch._int_mm(a_int8, b_int8, out=c_int32_result)
6190
            if test_equal:
6191
                self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6192
            else:
6193
                self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6194

6195
        # NOTE: We're just exercising terrible failures here.
6196
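        # Whether _int_mm hits a working cublasLt path depends on the CUDA version, the
        # GPU architecture and the transpose layout; unsupported combinations are expected
        # to raise CUBLAS_STATUS_NOT_SUPPORTED.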
        version = _get_torch_cuda_version()
6197
        SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
6198
        SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
6199
        SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
6200

6201
        if TEST_WITH_ROCM:
6202
            _test(17, k, n, use_transpose_a, use_transpose_b, True)
6203
        elif version >= (11, 7):
6204
            if not use_transpose_a and use_transpose_b:
6205
                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6206
                    _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
6207
                else:
6208
                    with self.assertRaisesRegex(RuntimeError,
6209
                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6210
                        _test(17, k, n, use_transpose_a, use_transpose_b)
6211

6212
            if use_transpose_a and not use_transpose_b:
6213
                with self.assertRaisesRegex(RuntimeError,
6214
                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6215
                    _test(17, k, n, use_transpose_a, use_transpose_b)
6216

6217
            if use_transpose_a and use_transpose_b:
6218
                with self.assertRaisesRegex(RuntimeError,
6219
                                            "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6220
                    _test(17, k, n, use_transpose_a, use_transpose_b)
6221

6222
            if not use_transpose_a and not use_transpose_b:
6223
                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
6224
                    _test(17, k, n, use_transpose_a, use_transpose_b)
6225
                else:
6226
                    with self.assertRaisesRegex(RuntimeError,
6227
                                                "CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
6228
                        _test(17, k, n, use_transpose_a, use_transpose_b)
6229
        else:
6230
            with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
6231
                _test(17, k, n, use_transpose_a, use_transpose_b, False)
6232

6233
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6234
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6235
    @onlyCUDA
6236
    def test__int_mm_errors(self, device):
6237
        if TEST_WITH_ROCM:
6238
            self.skipTest("_int_mm not compiled for ROCM")
6239

6240
        version = _get_torch_cuda_version()
6241
        if version < (11, 7):
6242
            self.skipTest("_int_mm only compiled for CUDA 11.7")
6243

6244
        def genf_int(x, y):
6245
            return torch.empty((x, y), dtype=torch.int8, device=device)
6246

6247
        def _gen_pair(m, k, n):
6248
            return genf_int(m, k), genf_int(k, n)
6249

6250
        self.assertRaisesRegex(RuntimeError,
6251
                               r"self.size\(0\) needs to be greater than 16, but got 16",
6252
                               lambda: torch._int_mm(*_gen_pair(16, 8, 32)))
6253
        self.assertRaisesRegex(RuntimeError,
6254
                               r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7",
6255
                               lambda: torch._int_mm(*_gen_pair(17, 7, 32)))
6256
        self.assertRaisesRegex(RuntimeError,
6257
                               r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7",
6258
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32)))
6259
        self.assertRaisesRegex(RuntimeError,
6260
                               r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31",
6261
                               lambda: torch._int_mm(*_gen_pair(17, 8, 31)))
6262
        self.assertRaisesRegex(RuntimeError,
6263
                               r"expected scalar type Char but found Float",
6264
                               lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32)))
6265
        self.assertRaisesRegex(RuntimeError,
6266
                               r"expected scalar type Char but found Float",
6267
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float()))
6268
        self.assertRaisesRegex(RuntimeError,
6269
                               r"Expected result dtype to be of type kInt but got float",
6270
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float()))
6271
        self.assertRaisesRegex(RuntimeError,
6272
                               r"Expected result.size\(0\) to be 17 but got 15",
6273
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
6274
        self.assertRaisesRegex(RuntimeError,
6275
                               r"Expected result.size\(0\) to be 17 but got 16",
6276
                               lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
6277

6278
    @onlyCPU
6279
    @parametrize("m", [0, 8, 17])
6280
    @parametrize("k", [0, 16, 32])
6281
    @parametrize("n", [16, 32])
6282
    @parametrize("use_transpose_a", [True, False])
6283
    @parametrize("use_transpose_b", [True, False])
6284
    @parametrize("non_contig_type", [0, 1, 2])
6285
    def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
6286
        # non_contig_type:
6287
        # 0: the whole data buffer is contiguous (can be transposed)
6288
        # 1: stride of one dimension is 1, but the whole buffer is not contiguous
6289
        # 2: Neither stride is 1
6290

6291
        def genf_int_float(x, y, use_transpose, non_contig_type):
6292
            if use_transpose:
6293
                x, y = y, x
6294
            if non_contig_type != 0:
6295
                y = y * 2
6296
            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
6297
            x_float = x_int8.to(torch.float32)
6298
            if non_contig_type == 1:
6299
                x_int8 = x_int8[:, : y // 2]
6300
                x_float = x_float[:, : y // 2]
6301
            elif non_contig_type == 2:
6302
                x_int8 = x_int8[:, ::2]
6303
                x_float = x_float[:, ::2]
6304
            if use_transpose:
6305
                return x_int8.t(), x_float.t()
6306
            return x_int8, x_float
6307

6308
        if non_contig_type != 0 and (m == 0 or k == 0):
6309
            return
6310
        a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
6311
        b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
6312
        c_int32 = torch._int_mm(a_int8, b_int8)
6313
        self.assertTrue(c_int32.dtype is torch.int32)
6314
        self.assertEqual(c_int32.device, torch.device(device))
6315
        self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
6316
        c_int32_result = c_int32.new_empty(c_int32.size())
6317
        # Check the out= variant
6318
        torch._int_mm(a_int8, b_int8, out=c_int32_result)
6319
        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
6320

6321
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6322
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6323
    @onlyNativeDeviceTypes
6324
    def test__convert_weight_to_int4pack(self, device):
6325
        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
6326
        test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
6327
        if self.device_type == 'cuda' and not SM80OrLater:
6328
            self.skipTest("requires SM80 or later")
6329

6330
        if TEST_WITH_ROCM:
6331
            if not CDNA2OrLater():
6332
                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6333

6334
        torch.manual_seed(1)
6335
        for shape, innerKTiles in test_list:
6336
            b = torch.rand(shape, dtype=torch.bfloat16, device=device)
6337
            b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
6338
            b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
6339
            b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
6340
            self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
6341

6342
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6343
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6344
    @onlyNativeDeviceTypes
6345
    @parametrize("m", [32, 64])
6346
    @parametrize("k", [32, 64])
6347
    @parametrize("n", [48, 64])
6348
    def test__int4_mm(self, device, m, k, n):
6349
        if self.device_type == 'cuda' and not SM80OrLater:
6350
            self.skipTest("requires SM80 or later")
6351

6352
        if TEST_WITH_ROCM:
6353
            if not CDNA2OrLater():
6354
                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6355

6356
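        # Quantize b to 4 bits with group size 32, pack it for the int4 kernel, and check
        # that _weight_int4pack_mm stays within 5% mean relative error of a full-precision
        # torch.mm for each supported activation dtype.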
        q_group = 32
6357
        inner_k_tiles = 2
6358

6359
        torch.manual_seed(1)
6360
        a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6361
        b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6362

6363
        def convert_weight_to_int4pack(b):
6364
            b_uint8, b_scales_and_zeros = _group_quantize_tensor(
6365
                b, n_bit=4, q_group_size=q_group
6366
            )
6367
            b_int4pack = torch._convert_weight_to_int4pack(
6368
                b_uint8, inner_k_tiles
6369
            )
6370

6371
            return b_int4pack, b_scales_and_zeros
6372

6373
        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
6374
            return torch._weight_int4pack_mm(
6375
                a, b_int4pack, q_group, b_scales_and_zeros
6376
            )
6377

6378
        b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
6379

6380
        for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
6381
            a = a_bf16.to(dtype=dtype)
6382
            b = b_bf16.to(dtype=dtype)
6383
            b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
6384
            ref = torch.mm(a, b)
6385
            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
6386

6387
            mean_err = ((res - ref).abs() / ref).mean()
6388
            self.assertTrue(mean_err < 0.05)
6389

6390

6391
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
6392
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
6393
    @onlyNativeDeviceTypes
6394
    @parametrize("m", [32, 64])
6395
    @parametrize("k", [32, 64])
6396
    @parametrize("n", [48, 64])
6397
    def test_compile_int4_mm(self, device, m, k, n):
6398
        if self.device_type == 'cuda' and not SM80OrLater:
6399
            self.skipTest("requires SM80 or later")
6400

6401
        if TEST_WITH_ROCM:
6402
            if not CDNA2OrLater():
6403
                self.skipTest("_int4_mm is supported only for CDNA2 or later")
6404

6405
        q_group = 32
6406
        inner_k_tiles = 2
6407

6408
        torch.manual_seed(1)
6409
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6410
        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
6411

6412
        b_int32, b_scales_and_zeros = _group_quantize_tensor(
6413
            b, n_bit=4, q_group_size=q_group
6414
        )
6415

6416
        @torch.compile
6417
        def int4_mm(a, b_int32, b_scales_and_zeros):
6418
            b_int4pack = torch._convert_weight_to_int4pack(
6419
                b_int32, inner_k_tiles
6420
            )
6421
            return torch._weight_int4pack_mm(
6422
                a, b_int4pack, q_group, b_scales_and_zeros
6423
            )
6424

6425
        res = int4_mm(a, b_int32, b_scales_and_zeros)
6426
        ref = torch.mm(a, b)
6427

6428
        mean_err = ((res - ref).abs() / ref).mean()
6429
        self.assertTrue(mean_err < 0.05)
6430

6431
    @onlyCPU
6432
    @parametrize("m", [32, 64])
6433
    @parametrize("k", [32, 64])
6434
    @parametrize("n", [48, 64])
6435
    def test__int8_mm(self, device, m, k, n):
6436
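        # Quantize b per channel to int8 and check that _weight_int8pack_mm(a, b_int8, scales)
        # stays within 5% mean relative error of torch.mm(a, b.transpose(0, 1)) in bfloat16.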
        torch.manual_seed(1)
6437
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6438
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6439

6440
        def convert_weight_to_int8pack(b):
6441
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6442
                b, -128, 127, torch.int8
6443
            )
6444
            return b_int8pack, b_scales
6445

6446
        def weight_int8pack_mm(a, b_int8pack, b_scales):
6447
            return torch._weight_int8pack_mm(
6448
                a, b_int8pack, b_scales
6449
            )
6450

6451
        b_int8pack, b_scales = convert_weight_to_int8pack(b)
6452
        res = weight_int8pack_mm(a, b_int8pack, b_scales)
6453
        ref = torch.mm(a, b.transpose(0, 1))
6454

6455
        mean_err = ((res - ref).abs() / ref).mean()
6456
        self.assertTrue(mean_err < 0.05)
6457

6458
    @onlyCPU
6459
    @parametrize("m", [32, 64])
6460
    @parametrize("k", [32, 64])
6461
    @parametrize("n", [48, 64])
6462
    def test_compile_int8_mm(self, device, m, k, n):
6463
        torch.manual_seed(1)
6464
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
6465
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
6466

6467
        b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
6468
            b, -128, 127, torch.int8
6469
        )
6470

6471
        @torch.compile
6472
        def int8_mm(a, b_int8pack, b_scales):
6473
            return torch._weight_int8pack_mm(
6474
                a, b_int8pack, b_scales
6475
            )
6476

6477
        res = int8_mm(a, b_int8pack, b_scales)
6478
        ref = torch.mm(a, b.transpose(0, 1))
6479

6480
        mean_err = ((res - ref).abs() / ref).mean()
6481
        self.assertTrue(mean_err < 0.05)
6482

6483
    @onlyCPU
6484
    @parametrize("m", [32, 35, 36, 40, 64])
6485
    @parametrize("k", [32, 35, 36, 40, 64])
6486
    # NOTE: This is intended to cover fp16_gemv_trans in
6487
    # BlasKernel.cpp. Currently, bounds being divisible by 32, 8-but-not-32, and 4-but-not-8
6488
    # all matter.
6489
    def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
6490
        torch.manual_seed(1)
6491
        a = torch.rand((m, k), dtype=torch.half, device=device)
6492
        b = torch.rand((1, k), dtype=torch.half, device=device)
6493

6494
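        # Compute a reference with fp16 reduced-precision reductions disabled, then enable
        # them (skipping the test if the platform rejects the toggle) and check that the
        # fast path stays within 1e-2 of the reference.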
        prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction()
6495
        try:
6496
            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False)
6497
            ref = torch.mm(a, b.t())
6498
            try:
6499
                torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True)
6500
            except RuntimeError as e:
6501
                raise unittest.SkipTest from e
6502
            res = torch.mm(a, b.t())
6503
            torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
6504
        finally:
6505
            torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)
6506

6507
    @slowTest
6508
    @onlyNativeDeviceTypes
6509
    # bfloat16 doesn't have sufficient precision to pass this test
6510
    @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
6511
    @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
6512
    @tf32_on_and_off(0.01)
6513
    @bf32_on_and_off(0.01)
6514
    def test_mm(self, device, dtype):
6515
        def _test_mm(n, m, p, dtype, genf):
6516
            # reference implementation: a naive elementwise matrix multiply used to check torch.mm
6517
            def matrixmultiply(mat1, mat2):
6518
                n = mat1.size(0)
6519
                m = mat1.size(1)
6520
                p = mat2.size(1)
6521
                dtype_ = torch.float if dtype == torch.half else dtype
6522
                if dtype == torch.half:
6523
                    mat1 = mat1.float()
6524
                    mat2 = mat2.float()
6525
                res = torch.zeros(n, p, dtype=dtype_, device=device)
6526
                for i, j in iter_indices(res):
6527
                    res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
6528
                return res.half() if dtype == torch.half else res
6529

6530
            # contiguous case
6531
            mat1 = genf(n, m)
6532
            mat2 = genf(m, p)
6533
            res = torch.mm(mat1, mat2)
6534

6535
            res2 = matrixmultiply(mat1, mat2)
6536
            self.assertEqual(res, res2)
6537

6538
            # non contiguous case 1
6539
            mat1 = genf(n, m)
6540
            mat2 = genf(p, m).t()
6541
            res = torch.mm(mat1, mat2)
6542

6543
            res2 = matrixmultiply(mat1, mat2)
6544
            self.assertEqual(res, res2)
6545

6546
            # non contiguous case 2
6547
            mat1 = genf(m, n).t()
6548
            mat2 = genf(m, p)
6549
            res = torch.mm(mat1, mat2)
6550

6551
            res2 = matrixmultiply(mat1, mat2)
6552
            self.assertEqual(res, res2)
6553

6554
            # non contiguous case 3
6555
            mat1 = genf(m, n).t()
6556
            mat2 = genf(p, m).t()
6557
            res = torch.mm(mat1, mat2)
6558

6559
            res2 = matrixmultiply(mat1, mat2)
6560
            self.assertEqual(res, res2)
6561

6562
            # test with zero stride
6563
            mat1 = genf(n, m)
6564
            mat2 = genf(m, 1).expand(m, p)
6565
            res = torch.mm(mat1, mat2)
6566

6567
            res2 = matrixmultiply(mat1, mat2)
6568
            self.assertEqual(res, res2)
6569

6570
            # explicitly exercise the _out variant in torch.mm().
6571
            # contiguous case
6572
            mat1 = genf(n, m)
6573
            mat2 = genf(m, p)
6574
            res = genf(n, p)
6575
            torch.mm(mat1, mat2, out=res)
6576

6577
            res2 = matrixmultiply(mat1, mat2)
6578
            self.assertEqual(res, res2)
6579

6580
            # explicitly exercise the _out variant in torch.mm().
6581
            # non contiguous case 3
6582
            mat1 = genf(m, n).t()
6583
            mat2 = genf(p, m).t()
6584
            res = genf(n, p)
6585
            torch.mm(mat1, mat2, out=res)
6586

6587
            res2 = matrixmultiply(mat1, mat2)
6588
            self.assertEqual(res, res2)
6589

6590
        def genf_int(x, y):
6591
            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
6592

6593
        def genf_bfloat(x, y):
6594
            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
6595

6596
        def genf_float(x, y):
6597
            return torch.randn(x, y, dtype=dtype, device=device)
6598

6599
        def genf_Half(x, y):
6600
            return torch.randn(x, y, dtype=dtype, device=device)
6601

6602
        for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
6603
            if (dtype == torch.int32) or (dtype == torch.int64):
6604
                genf = genf_int
6605
            elif (dtype == torch.bfloat16):
6606
                genf = genf_bfloat
6607
            elif (dtype == torch.half):
6608
                genf = genf_Half
6609
            else:
6610
                genf = genf_float
6611

6612
            _test_mm(n, m, p, dtype, genf)
6613

6614
    @onlyNativeDeviceTypes
6615
    def test_mm_bmm_non_memory_dense(self, device):
        def _slice(tensor, fn):
            return fn(tensor)[..., ::2]
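        # Slicing with step 2 makes the operands non-memory-dense; torch.conj only
        # sets the conjugate bit on a view, so its results must match the
        # materialized torch.conj_physical outputs in the mm/bmm calls below.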
        A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
        out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
        out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
        A_conj = _slice(A, torch.conj)
        A_conj_physical = _slice(A, torch.conj_physical)

        self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
        self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))

        Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
        Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
        Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
        out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT

        Ab_conj = _slice(Ab, torch.conj)
        Ab_conj_physical = _slice(Ab, torch.conj_physical)

        def t_b(tensor):
            return tensor.mT

        self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))

        # test broadcasting
        self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
        self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))

    @onlyNativeDeviceTypes
    def test_mm_conjtranspose(self, device):
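        # A.t().conj() is a lazily conjugated, transposed view; each case below is
        # checked against conj_physical(), which materializes the conjugation.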
        A = torch.randn(3, 3, dtype=torch.cfloat, device=device)
        B = torch.randn(3, 3, dtype=torch.cfloat, device=device)

        # A conjtranspose
        out1 = torch.mm(A.t().conj(), B)
        out1_ref = torch.mm(A.t().conj_physical(), B)
        self.assertEqual(out1, out1_ref)

        # B conjtranspose
        out1 = torch.mm(A, B.t().conj())
        out1_ref = torch.mm(A, B.t().conj_physical())
        self.assertEqual(out1, out1_ref)

        # A&B conjtranspose
        out1 = torch.mm(A.t().conj(), B.t().conj())
        out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical())
        self.assertEqual(out1, out1_ref)

    @onlyNativeDeviceTypes
    def test_mm_empty_inputs_mixed_dtype_errors(self, device):
        a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
        b = torch.randn(10, 20, dtype=torch.float32, device=device)
        with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
            torch.mm(a, b)

    @onlyNativeDeviceTypes
    @dtypes(torch.float32, torch.float64)
    def test_strided_mm_bmm(self, device, dtype):
        # Tests strided view case with stride smaller than corresponding dimension size
        x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
        new_shape = [2, 2, 2]
        new_stride = [3, 1, 1]
        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
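        # With stride [3, 1, 1] the two rows of each 2x2 matrix overlap in memory
        # (both trailing strides are 1), exercising mm/bmm on a genuinely strided view.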

        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx)

        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
        self.compare_with_numpy(torch_fn, np_fn, sx[0])

    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @onlyNativeDeviceTypes
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
    @tf32_on_and_off(0.05)
    @bf32_on_and_off(0.05)
    def test_bmm(self, device, dtype):
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
            # cuBLAS does not guarantee BFloat16 support on SM < 53,
            # so PyTorch treats BFloat16 on SM < 53 as undefined behavior.
6699
            return
6700

6701
        batch_sizes = [1, 10]
6702
        M, N, O = 23, 15, 12
6703
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
6704

6705
        is_supported = True
6706
        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6707
            is_supported = TEST_WITH_ROCM or SM53OrLater
6708

6709
        if not is_supported:
6710
            for num_batches in batch_sizes:
6711
                b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
6712
                b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
6713
                self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6714
                                       lambda: torch.bmm(b1, b2))
6715
            return
6716

6717
        def invert_perm(p):
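            # Inverse of a permutation of (0, 1, 2): permute(perm).contiguous().permute(invert_perm(perm))
            # restores the original shape while leaving the underlying memory layout permuted.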
6718
            d = {x: i for i, x in enumerate(p)}
6719
            return (d[0], d[1], d[2])
6720

6721
        def generate_inputs(num_batches):
6722
            # transposed tensors
6723
            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6724
                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1)
6725
                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1)
6726
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6727
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6728
                yield b1, b2
6729
            # broadcasting tensors
6730
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
6731
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
6732
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
6733
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N)
6734
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O)
6735
                yield b1, b2
6736
            # zero-sized tensors
6737
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6738
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6739
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6740
                b1 = torch.randn(shape1, dtype=dtype, device=device)
6741
                b2 = torch.randn(shape2, dtype=dtype, device=device)
6742
                yield b1, b2
6743

6744
        for num_batches in batch_sizes:
6745
            for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
6746
                res1 = torch.bmm(b1, b2)
6747
                res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
6748
                    .permute(perm3).contiguous().permute(invert_perm(perm3))
6749
                torch.bmm(b1, b2, out=res2)
6750
                expect = torch.from_numpy(
6751
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6752
                self.assertEqual(expect, res1)
6753
                self.assertEqual(expect, res2)
6754

6755
                if self.device_type == 'cuda':
6756
                    # check that mixed-device arguments are rejected
6757
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
6758
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
6759
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))
6760

6761
    def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
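        # Shared checks for addbmm/baddbmm: the in-place method, the deprecated
        # positional beta/alpha overloads (which must emit a deprecation warning),
        # the functional form with explicit beta/alpha, and beta=0, where NaNs in
        # the input tensor must be ignored rather than propagated.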
6762
        getattr(out_tensor, func + "_")(b1, b2)
6763
        self.assertEqual(out_tensor, ref)
6764
        res3 = out_tensor.clone()
6765

6766
        with self.assertWarnsOnceRegex(
6767
                UserWarning, f"This overload of {func}_ is deprecated"):
6768
            getattr(out_tensor, func + "_")(1, b1, b2)
6769
        self.assertEqual(out_tensor, ref * 2)
6770
        getattr(res3, func + "_")(b1, b2, beta=1)
6771
        self.assertEqual(out_tensor, res3)
6772

6773
        with self.assertWarnsOnceRegex(
6774
                UserWarning, f"This overload of {func}_ is deprecated"):
6775
            getattr(out_tensor, func + "_")(1., .5, b1, b2)
6776
        self.assertEqual(out_tensor, ref * 2.5)
6777
        getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5)
6778
        self.assertEqual(out_tensor, res3)
6779

6780
        with self.assertWarnsOnceRegex(
6781
                UserWarning, f"This overload of {func} is deprecated"):
6782
            self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
6783

6784
        res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
6785
        self.assertEqual(res4, ref * 3)
6786

6787
        nan = torch.full_like(out_tensor, math.nan)
6788
        res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
6789
        self.assertEqual(res5, ref)
6790

6791
        if b1.is_complex():
6792
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j)
6793
            self.assertEqual(res6, out_tensor * .1j + .5j * ref)
6794
        else:
6795
            res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5)
6796
            self.assertEqual(res6, out_tensor * .1 + .5 * ref)
6797

6798
        res7 = torch.full_like(out_tensor, math.nan)
6799
        getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
6800
        self.assertEqual(res7, ref)
6801

6802
    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
6803
    @onlyNativeDeviceTypes
6804
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6805
    @tf32_on_and_off(0.05)
6806
    @bf32_on_and_off(0.05)
6807
    def test_addbmm(self, device, dtype):
6808
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6809
            # cuBLAS does not guarantee BFloat16 support on SM < 53,
            # so PyTorch treats BFloat16 on SM < 53 as undefined behavior.
6812
            return
6813

6814
        num_batches = 2
6815
        M, N, O = 16, 17, 18
6816

6817
        is_supported = True
6818
        if dtype == torch.bfloat16:
6819
            if self.device_type == 'cpu':
6820
                self.precision = 1  # 43 vs 43.75
6821
            else:
6822
                is_supported = TEST_WITH_ROCM or SM53OrLater
6823

6824
        if not is_supported:
6825
            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6826
            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6827
            t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
6828
            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6829
                                   lambda: torch.addbmm(t, b1, b2))
6830
            return
6831

6832
        def invert_perm(p):
6833
            d = {x: i for i, x in enumerate(p)}
6834
            return (d[0], d[1], d[2])
6835

6836
        def generate_tensor():
6837
            numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
6838
            # transposed tensors
6839
            for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
6840
                for perm3 in itertools.permutations((0, 1)):
6841
                    b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1
6842
                    b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1
6843
                    b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6844
                    b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6845
                    ref = torch.from_numpy(
6846
                        b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6847
                    ).to(device=device, dtype=dtype).sum(0)
6848
                    out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
6849
                    yield b1, b2, ref, out_tensor
6850
            # broadcasting tensors
6851
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
6852
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
6853
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
6854
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1
6855
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1
6856
                ref = torch.from_numpy(
6857
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6858
                ).to(device=device, dtype=dtype).sum(0)
6859
                out_tensor = torch.zeros_like(ref)
6860
                yield b1, b2, ref, out_tensor
6861
            # zero-sized tensors
6862
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6863
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6864
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6865
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1
6866
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1
6867
                ref = torch.from_numpy(
6868
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6869
                ).to(device=device, dtype=dtype).sum(0)
6870
                out_tensor = torch.zeros_like(ref)
6871
                yield b1, b2, ref, out_tensor
6872

6873
        for b1, b2, ref, out_tensor in generate_tensor():
6874
            self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
6875

6876
    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
6877
    @onlyNativeDeviceTypes
6878
    @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
6879
    @tf32_on_and_off(0.05)
6880
    @bf32_on_and_off(0.05)
6881
    def test_baddbmm(self, device, dtype):
6882
        if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
6883
            # cuBLAS does not guarantee BFloat16 support on SM < 53,
            # so PyTorch treats BFloat16 on SM < 53 as undefined behavior.
6886
            return
6887

6888
        num_batches = 10
6889
        M, N, O = 12, 8, 50
6890

6891
        is_supported = True
6892
        if dtype == torch.bfloat16 and self.device_type == 'cuda':
6893
            is_supported = TEST_WITH_ROCM or SM53OrLater
6894

6895
        if not is_supported:
6896
            b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6897
            b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6898
            t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1)
6899
            self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
6900
                                   lambda: torch.baddbmm(t, b1, b2))
6901
            return
6902

6903
        def invert_perm(p):
6904
            d = {x: i for i, x in enumerate(p)}
6905
            return (d[0], d[1], d[2])
6906

6907
        def generate_tensor():
6908
            numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
6909
            # transposed tensors
6910
            for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3):
6911
                b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
6912
                b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
6913
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
6914
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
6915
                ref = torch.from_numpy(
6916
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6917
                out_tensor = torch.zeros_like(ref)
6918
                out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
6919
                yield b1, b2, ref, out_tensor
6920
            # broadcasting tensors
6921
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
6922
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
6923
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
6924
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
6925
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
6926
                ref = torch.from_numpy(
6927
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6928
                out_tensor = torch.zeros_like(ref)
6929
                yield b1, b2, ref, out_tensor
6930
            # zero-sized tensors
6931
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6932
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6933
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6934
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
6935
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
6936
                ref = torch.from_numpy(
6937
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
6938
                out_tensor = torch.zeros_like(ref)
6939
                yield b1, b2, ref, out_tensor
6940

6941
        for b1, b2, ref, out_tensor in generate_tensor():
6942
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
6943

6944
    @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
6945
    @skipCUDAIfNoMagma
6946
    @skipCPUIfNoLapack
6947
    @dtypes(*floating_and_complex_types())
6948
    def test_pinverse(self, device, dtype):
6949
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
6950
        make_arg = partial(make_fullrank, device=device, dtype=dtype)
6951

6952
        def run_test(M):
6953
            # Testing against definition for pseudo-inverses
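            # i.e. the four Moore-Penrose conditions: M @ MPI @ M == M, MPI @ M @ MPI == MPI,
            # and both M @ MPI and MPI @ M are Hermitian.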
6954
            MPI = torch.pinverse(M)
6955
            MPI_ = MPI.cpu().numpy()
6956
            M_ = M.cpu().numpy()
6957
            if M.numel() > 0:
6958
                self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_))
6959
                self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_))
6960
                self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj())
6961
                self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj())
6962
            else:
6963
                self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
6964
        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5),  # square matrices
6965
                      (3, 2), (5, 3, 2), (7, 5, 3, 2),  # fat matrices
6966
                      (2, 3), (5, 2, 3), (7, 5, 2, 3),  # thin matrices
6967
                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
6968
            M = torch.randn(*sizes, dtype=dtype, device=device)
6969
            run_test(M)
6970

6971
        # Test inverse and pseudo-inverse for invertible matrix
6972
        for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
6973
            matsize = sizes[-1]
6974
            batchdims = sizes[:-2]
6975
            M = make_arg(*batchdims, matsize, matsize)
6976
            self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
6977
                             atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
6978

6979
    @skipCPUIfNoLapack
6980
    @skipCUDAIfNoMagmaAndNoCusolver
6981
    @dtypes(torch.double, torch.cdouble)
6982
    def test_matrix_power_non_negative(self, device, dtype):
6983
        def check(*size):
6984
            t = make_tensor(size, dtype=dtype, device=device)
6985
            for n in range(8):
6986
                res = torch.linalg.matrix_power(t, n)
6987
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
6988
                self.assertEqual(res.cpu(), torch.from_numpy(ref))
6989

6990
        check(0, 0)
6991
        check(1, 1)
6992
        check(5, 5)
6993
        check(0, 3, 3)
6994
        check(2, 3, 3)
6995

6996
    @skipCPUIfNoLapack
6997
    @skipCUDAIfNoMagmaAndNoCusolver
6998
    @dtypes(torch.double, torch.cdouble)
6999
    def test_matrix_power_negative(self, device, dtype):
7000
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
7001
        make_arg = partial(make_fullrank, device=device, dtype=dtype)
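        # Negative matrix powers require inverting the input, so test matrices are
        # generated full-rank with well-separated singular values.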
7002

7003
        def check(*size):
7004
            t = make_arg(*size)
7005
            for n in range(-7, 0):
7006
                res = torch.linalg.matrix_power(t, n)
7007
                ref = np.linalg.matrix_power(t.cpu().numpy(), n)
7008
                self.assertEqual(res.cpu(), torch.from_numpy(ref))
7009

7010
        check(0, 0)
7011
        check(5, 5)
7012
        check(2, 0, 0)
7013
        check(0, 3, 3)
7014
        check(2, 3, 3)
7015
        check(2, 3, 5, 5)
7016

7017
    @skipCUDAIfNoMagma
7018
    @skipCPUIfNoLapack
7019
    @dtypes(torch.float, torch.complex64)
7020
    def test_linalg_matrix_exp_utils(self, device, dtype):
7021
        # test linear combination
7022
        def run_test(coeff_shape, data_shape):
7023
            coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float)
7024
            x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype)
7025

7026
            res1 = torch._compute_linear_combination(x, coeffs)
7027
            res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1)
7028
            self.assertEqual(res1, res2, atol=1e-5, rtol=0.0)
7029

7030
            # check `out=` version
7031
            res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7032
            torch._compute_linear_combination(x, coeffs, out=res3)
7033
            self.assertEqual(res1, res3, atol=1e-5, rtol=0.0)
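            # The two checks below pre-fill `out` with ones; the expected values
            # (res4 - 1.0 and res5 - res5_clone) reflect that this out= variant
            # accumulates into the provided buffer rather than overwriting it.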
7034

7035
            res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7036
            torch._compute_linear_combination(x, coeffs, out=res4)
7037
            self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0)
7038

7039
            res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
7040
            res5_clone = res5.clone()
7041
            torch._compute_linear_combination(x, coeffs, out=res5)
7042
            self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0)
7043

7044
        run_test([1, 3], [2, 2])
7045
        run_test([3, 1], [2, 2])
7046
        run_test([1, 10], [10, 10])
7047
        run_test([10, 1], [10, 10])
7048
        run_test([5, 3], [2, 2])
7049
        run_test([5, 3], [100, 100])
7050
        run_test([3, 4], [3, 3, 3])
7051
        run_test([3, 4], [3, 3, 3, 3])
7052

7053
        # Regression test for https://github.com/pytorch/pytorch/issues/94124
7054
        with self.assertRaises(RuntimeError):
7055
            x = torch.rand([], device=device, dtype=dtype)
7056
            coeffs = torch.rand([2, 2], device=device, dtype=dtype)
7057
            res = torch._compute_linear_combination(x, coeffs)
7058

7059
    @onlyCPU
7060
    @skipCPUIfNoLapack
7061
    @dtypes(torch.complex64)
7062
    def test_linalg_matrix_exp_no_warnings(self, device, dtype):
7063
        # this tests https://github.com/pytorch/pytorch/issues/80948
7064
        with freeze_rng_state():
7065
            torch.manual_seed(42)
7066
            tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device)
7067
            tens = (0.5 * (tens.transpose(-1, -2) + tens))
7068
            with warnings.catch_warnings(record=True) as w:
7069
                tens.imag = torch.matrix_exp(tens.imag)
7070
                self.assertFalse(len(w))
7071

7072
    @skipCUDAIfNoMagma
7073
    @skipCPUIfNoLapack
7074
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7075
    def test_linalg_matrix_exp_boundary_cases(self, device, dtype):
7076
        expm = torch.linalg.matrix_exp
7077

7078
        with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"):
7079
            expm(torch.randn(3, 3).type(torch.int))
7080

7081
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
7082
            expm(torch.randn(3))
7083

7084
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
7085
            expm(torch.randn(3, 2, 1))
7086

7087
        # check 1x1 matrices
7088
        x = torch.randn(3, 3, 1, 1)
7089
        self.assertEqual(expm(x), x.exp())
7090

7091
    @skipCUDAIfNoMagma
7092
    @skipCPUIfNoLapack
7093
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
7094
    def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype):
7095
        expm = torch.linalg.matrix_exp
7096

7097
        def with_nan(x):
7098
            x[0, 0, 0] = torch.nan
7099
            return x
7100

7101
        # Check small batches
7102
        x = with_nan(torch.randn(1, 1, 1))
7103
        self.assertTrue(torch.isnan(expm(x)).any())
7104
        x = with_nan(torch.randn(1, 2, 2))
7105
        for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]:
7106
            self.assertTrue(torch.isnan(expm(x / v)).any())
7107

7108
        # Check large batches
7109
        x = with_nan(torch.randn(2, 2, 2))
7110
        self.assertTrue(torch.isnan(expm(x)).any())
7111
        x = with_nan(torch.randn(4096, 2, 2))
7112
        self.assertTrue(torch.isnan(expm(x)).any())
7113

7114
    @slowTest
7115
    @skipCUDAIfNoMagma
7116
    @skipCPUIfNoLapack
7117
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7118
    def test_linalg_matrix_exp_analytic(self, device, dtype):
7119
        expm = torch.linalg.matrix_exp
7120
        # check zero matrix
7121
        x = torch.zeros(20, 20, dtype=dtype, device=device)
7122
        self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item())
7123

7124
        def normalize_to_1_operator_norm(sample, desired_norm):
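            # abs().sum(-2).max(-1) is the induced (operator) 1-norm, i.e. the maximum
            # absolute column sum; rescale so the sample has exactly `desired_norm`.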
7125
            sample_norm, _ = sample.abs().sum(-2).max(-1)
7126
            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7127
            return sample_to_1_norm * desired_norm
7128

7129
        def gen_good_cond_number_matrices(*n):
7130
            """
7131
            Generates a diagonally-dominant matrix
7132
            with the eigenvalues centered at 1
7133
            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
7134
            """
7135
            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7136
            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7137
            x = (x - x * identity) + identity
7138
            return x
7139

7140
        def run_test(*n):
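            # The thetas below appear to mirror the degree-switching thresholds of the
            # scaling-and-squaring matrix_exp implementation (Taylor approximants of
            # degree 1/2/4/8/12/18); sample norms derived from them are chosen so that
            # every degree branch gets exercised.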
7141
            if dtype == torch.float:
7142
                thetas = [
7143
                    1.192092800768788e-07,  # deg 1
7144
                    5.978858893805233e-04,  # deg 2
7145
                    5.116619363445086e-02,  # deg 4
7146
                    5.800524627688768e-01,  # deg 8
7147
                    1.461661507209034e+00,  # deg 12
7148
                    3.010066362817634e+00   # deg 18
7149
                ]
7150
            else:  # if torch.double
7151
                thetas = [
7152
                    2.220446049250313e-16,  # deg 1
7153
                    2.580956802971767e-08,  # deg 2
7154
                    3.397168839976962e-04,  # deg 4
7155
                    4.991228871115323e-02,  # deg 8
7156
                    2.996158913811580e-01,  # deg 12
7157
                    1.090863719290036e+00   # deg 18
7158
                ]
7159

7160
            # generate input
7161
            q = gen_good_cond_number_matrices(*n)
7162
            q_ = q.cpu().numpy()
7163
            qinv = torch.inverse(q)
7164
            qinv_ = qinv.cpu().numpy()
7165
            d = torch.randn(n[:-1], dtype=dtype, device=device)
7166
            x = torch.from_numpy(
7167
                np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
7168
            x_norm, _ = x.abs().sum(-2).max(-1)
7169

7170
            # test against the analytic solution, whatever the norm of the generated input
7171
            mexp = expm(x)
7172
            mexp_analytic = np.matmul(
7173
                q_,
7174
                np.matmul(
7175
                    torch.diag_embed(d.exp()).cpu().numpy(),
7176
                    qinv_
7177
                )
7178
            )
7179
            self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7180

7181
            # generate norms to test different degree expansions
7182
            sample_norms = []
7183
            for i in range(len(thetas) - 1):
7184
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7185
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7186

7187
            # rescale matrices to each sample norm
7188
            for sample_norm in sample_norms:
7189
                x_normalized = normalize_to_1_operator_norm(x, sample_norm)
7190

7191
                mexp = expm(x_normalized)
7192
                mexp_analytic = np.matmul(
7193
                    q_,
7194
                    np.matmul(
7195
                        torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
7196
                        qinv_
7197
                    )
7198
                )
7199
                self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
7200

7201
        # single matrix
7202
        run_test(2, 2)
7203
        run_test(3, 3)
7204
        run_test(4, 4)
7205
        run_test(5, 5)
7206
        run_test(100, 100)
7207
        run_test(200, 200)
7208

7209
        # small batch of matrices
7210
        run_test(3, 2, 2)
7211
        run_test(3, 3, 3)
7212
        run_test(3, 4, 4)
7213
        run_test(3, 5, 5)
7214
        run_test(3, 100, 100)
7215
        run_test(3, 200, 200)
7216

7217
        # large batch of matrices
7218
        run_test(3, 3, 2, 2)
7219
        run_test(3, 3, 3, 3)
7220
        run_test(3, 3, 4, 4)
7221
        run_test(3, 3, 5, 5)
7222
        run_test(3, 3, 100, 100)
7223
        run_test(3, 3, 200, 200)
7224

7225
    @skipCUDAIfNoMagma
7226
    @skipCPUIfNoLapack
7227
    @dtypes(torch.float, torch.double)
7228
    def test_linalg_matrix_exp_batch(self, device, dtype):
7229

7230
        def run_test(*n):
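            # Fill a batch with independent random matrices and check that the batched
            # matrix_exp matches the per-matrix results.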
7231
            tensors_batch = torch.zeros(n, dtype=dtype, device=device)
7232
            tensors_batch = tensors_batch.view(-1, n[-2], n[-1])
7233

7234
            num_matrices = tensors_batch.size(0)
7235
            tensors_list = []
7236
            for i in range(num_matrices):
7237
                tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device))
7238

7239
            for i in range(num_matrices):
7240
                tensors_batch[i, ...] = tensors_list[i]
7241

7242
            tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list)
7243
            tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch)
7244

7245
            for i, tensor_exp in enumerate(tensors_exp_map):
7246
                self.assertEqual(tensors_exp_batch[i, ...], tensor_exp)
7247

7248
        # small batch of matrices
7249
        run_test(3, 2, 2)
7250
        run_test(3, 3, 3)
7251
        run_test(3, 4, 4)
7252
        run_test(3, 5, 5)
7253

7254
        # large batch of matrices
7255
        run_test(3, 3, 2, 2)
7256
        run_test(3, 3, 3, 3)
7257
        run_test(3, 3, 4, 4)
7258
        run_test(3, 3, 5, 5)
7259

7260
    @skipCUDAIfNoMagma
7261
    @skipCPUIfNoLapack
7262
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
7263
    def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):
7264

7265
        def normalize_to_1_operator_norm(sample, desired_norm):
7266
            sample_norm, _ = sample.abs().sum(-2).max(-1)
7267
            sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
7268
            return sample_to_1_norm * desired_norm
7269

7270
        def gen_good_cond_number_matrices(*n):
7271
            """
7272
            Generates a diagonally-dominant matrix
7273
            with the eigenvalues centered at 1
7274
            and the radii at most (n[-1] - 1) / (n[-2] ** 2)
7275
            """
7276
            identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
7277
            x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
7278
            x = (x - x * identity) + identity
7279
            return x
7280

7281
        def get_taylor_approximation(a, deg):
7282
            a_ = a.cpu().numpy()
7283
            identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a)
7284
            res = identity.cpu().numpy()
7285
            taylor_term = identity.cpu().numpy()
7286

7287
            for i in range(1, deg + 1):
7288
                taylor_term = np.matmul(a_, taylor_term) / i
7289
                res = res + taylor_term
7290

7291
            return res
7292

7293
        def scale_square(a, deg):
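            # Reference via scaling and squaring: for small Frobenius norm use a
            # degree-12 Taylor approximation directly; otherwise scale by 2**-s,
            # approximate with degree 18, then square s times, using
            # exp(A) = exp(A / 2**s) ** (2**s). Note that `deg` itself is unused here.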
7294
            if a.abs().pow(2).sum().sqrt() < 1.0:
7295
                return get_taylor_approximation(a, 12)
7296
            else:
7297
                s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item())
7298
                b = a / (2 ** s)
7299
                b = get_taylor_approximation(b, 18)
7300
                for _ in range(s):
7301
                    b = np.matmul(b, b)
7302
                return torch.from_numpy(b).to(a.device)
7303

7304
        def run_test(*n):
7305
            degs = [1, 2, 4, 8, 12, 18]
7306
            if dtype == torch.float:
7307
                thetas = [
7308
                    1.192092800768788e-07,  # deg 1
7309
                    5.978858893805233e-04,  # deg 2
7310
                    5.116619363445086e-02,  # deg 4
7311
                    5.800524627688768e-01,  # deg 8
7312
                    1.461661507209034e+00,  # deg 12
7313
                    3.010066362817634e+00   # deg 18
7314
                ]
7315
            else:  # if torch.double
7316
                thetas = [
7317
                    2.220446049250313e-16,  # deg 1
7318
                    2.580956802971767e-08,  # deg 2
7319
                    3.397168839976962e-04,  # deg 4
7320
                    4.991228871115323e-02,  # deg 8
7321
                    2.996158913811580e-01,  # deg 12
7322
                    1.090863719290036e+00   # deg 18
7323
                ]
7324

7325
            # generate norms to test different degree expansions
7326
            sample_norms = []
7327
            for i in range(len(thetas) - 1):
7328
                sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
7329
            sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
7330
            degs = [degs[0]] + degs
7331

7332
            for sample_norm, deg in zip(sample_norms, degs):
7333
                x = gen_good_cond_number_matrices(*n)
7334
                x = normalize_to_1_operator_norm(x, sample_norm)
7335

7336
                mexp = torch.linalg.matrix_exp(x)
7337
                mexp_taylor = scale_square(x, deg)
7338

7339
                self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)
7340

7341
        # single matrix
7342
        run_test(2, 2)
7343
        run_test(3, 3)
7344
        run_test(4, 4)
7345
        run_test(5, 5)
7346

7347
        # small batch of matrices
7348
        run_test(3, 2, 2)
7349
        run_test(3, 3, 3)
7350
        run_test(3, 4, 4)
7351
        run_test(3, 5, 5)
7352

7353
        # large batch of matrices
7354
        run_test(3, 3, 2, 2)
7355
        run_test(3, 3, 3, 3)
7356
        run_test(3, 3, 4, 4)
7357
        run_test(3, 3, 5, 5)
7358

7359
    @skipCUDAIfNoMagma
7360
    @skipCPUIfNoLapack
7361
    @dtypes(*floating_and_complex_types())
7362
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
7363
                        torch.float64: 1e-8, torch.complex128: 1e-8})
7364
    def test_slogdet(self, device, dtype):
7365
        from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix,
7366
                                                          random_hermitian_pd_matrix, random_square_matrix_of_rank)
7367

7368
        # mat_chars denotes matrix characteristics
7369
        # possible values are: hermitian, hermitian_psd, hermitian_pd, singular, non_singular
7370
        def run_test(matsize, batchdims, mat_chars):
7371
            num_matrices = np.prod(batchdims)
7372
            list_of_matrices = []
7373
            if num_matrices != 0:
7374
                for idx in range(num_matrices):
7375
                    mat_type = idx % len(mat_chars)
7376
                    if mat_chars[mat_type] == 'hermitian':
7377
                        list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device))
7378
                    elif mat_chars[mat_type] == 'hermitian_psd':
7379
                        list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device))
7380
                    elif mat_chars[mat_type] == 'hermitian_pd':
7381
                        list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device))
7382
                    elif mat_chars[mat_type] == 'singular':
7383
                        list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7384
                    elif mat_chars[mat_type] == 'non_singular':
7385
                        list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7386
                full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7387
            else:
7388
                full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device)
7389

7390
            actual_value = torch.linalg.slogdet(full_tensor)
7391
            expected_value = np.linalg.slogdet(full_tensor.cpu().numpy())
7392
            self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision)
7393
            self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision)
7394

7395
            # test the out= variant
7396
            sign_out = torch.empty_like(actual_value[0])
7397
            logabsdet_out = torch.empty_like(actual_value[1])
7398
            ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out))
7399
            self.assertEqual(ans[0], sign_out)
7400
            self.assertEqual(ans[1], logabsdet_out)
7401
            self.assertEqual(sign_out, actual_value[0])
7402
            self.assertEqual(logabsdet_out, actual_value[1])
7403

7404
        for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]):
7405
            run_test(matsize, batchdims, mat_chars=['hermitian_pd'])
7406
            run_test(matsize, batchdims, mat_chars=['singular'])
7407
            run_test(matsize, batchdims, mat_chars=['non_singular'])
7408
            run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd'])
7409
            run_test(matsize, batchdims, mat_chars=['singular', 'non_singular'])
7410

7411
    @skipCUDAIfNoMagma
7412
    @skipCPUIfNoLapack
7413
    @dtypes(*floating_and_complex_types())
7414
    def test_slogdet_errors_and_warnings(self, device, dtype):
7415
        # slogdet requires the input to be a square matrix or batch of square matrices
7416
        a = torch.randn(2, 3, device=device, dtype=dtype)
7417
        with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
7418
            torch.linalg.slogdet(a)
7419

7420
        # slogdet requires the input to be at least 2 dimensional tensor
7421
        a = torch.randn(2, device=device, dtype=dtype)
7422
        with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
7423
            torch.linalg.slogdet(a)
7424

7425
        a = torch.randn(2, 2, device=device, dtype=torch.bfloat16)
7426
        with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'):
7427
            torch.linalg.slogdet(a)
7428

7429
        # if non-empty out tensor with wrong shape is passed a warning is given
7430
        a = torch.randn(2, 3, 3, device=device, dtype=dtype)
7431
        sign_out = torch.empty(1, device=device, dtype=dtype)
7432
        real_dtype = a.real.dtype if dtype.is_complex else dtype
7433
        logabsdet_out = torch.empty(1, device=device, dtype=real_dtype)
7434
        with warnings.catch_warnings(record=True) as w:
7435
            # Trigger warning
7436
            torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7437
            # Check warning occurs
7438
            self.assertEqual(len(w), 1)
7439
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
7440

7441
        # device should match
7442
        if torch.cuda.is_available():
7443
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
7444
            sign_out = torch.empty(0, device=wrong_device, dtype=dtype)
7445
            logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype)
7446
            with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
7447
                torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
7448

7449
    # FIXME One of the backends of lu_factor fails on Windows. I haven't investigated which one or why
7450
    # https://github.com/pytorch/pytorch/issues/75225
7451
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
7452
    @skipCUDAIfNoCusolver
7453
    @skipCPUIfNoLapack
7454
    @dtypes(torch.double)
7455
    def test_det_logdet_slogdet(self, device, dtype):
7456
        def reference_slogdet(M):
7457
            sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
7458
            return M.new_tensor(sdet), M.new_tensor(logabsdet)
7459

7460
        def test_single_det(M, target, desc):
7461
            target_sdet, target_logabsdet = target
7462

7463
            det = M.det()
7464
            logdet = M.logdet()
7465
            sdet, logabsdet = M.slogdet()
7466
            linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M)
7467

7468
            # Test det
7469
            self.assertEqual(det, target_sdet * target_logabsdet.exp(),
7470
                             atol=1e-6, rtol=0, msg=f'{desc} (det)')
7471

7472
            # Test slogdet
7473
            # Compare the overall value rather than individual parts because of
7474
            # precision issues when det is near zero.
7475
            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7476
                             atol=1e-6, rtol=0, msg=f'{desc} (slogdet)')
7477
            self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(),
7478
                             atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)')
7479

7480
            # Test logdet
7481
            # Compare logdet against our own pytorch slogdet because they should
7482
            # be consistent, while it may behave slightly differently with other
7483
            # slogdet implementations when det is near zero due to precision
7484
            # issues.
7485
            if sdet.item() < 0:
7486
                self.assertTrue(logdet.item() != logdet.item(), f'{desc} (logdet negative case)')
7487
            else:
7488
                self.assertEqual(logdet.exp(), target_logabsdet.exp(),
7489
                                 atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)')
7490

7491
        eye = torch.eye(5, dtype=dtype, device=device)
7492
        test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
7493
        # Testing bug in #34061 (https://github.com/pytorch/pytorch/issues/34061)
7494
        for n in range(250, 551, 100):
7495
            mat = torch.randn(n, n, dtype=dtype, device=device)
7496
            q, _ = torch.linalg.qr(mat)  # torch.qr is deprecated in favor of torch.linalg.qr
7497
            ref_det, ref_logabsdet = reference_slogdet(q)
7498
            test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal')
7499

7500
        def test(M):
7501
            assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
7502
            M = M.to(device)
7503

7504
            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
7505

7506
            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
7507
            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
7508
                M_inv = M.inverse()
7509
                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
7510

7511
            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')
7512

7513
            for x in [0, 2, 4]:
7514
                for scale in [-2, -0.1, 0, 10]:
7515
                    if scale > 0:
7516
                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
7517
                    elif scale == 0:
7518
                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7519
                    else:
7520
                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
7521

7522
                    # dim 0
7523
                    M_clone = M.clone()
7524
                    M_clone[:, x] *= scale
7525
                    test_single_det(M_clone, target, 'scale a row')
7526
                    # dim 1
7527
                    M_clone = M.clone()
7528
                    M_clone[x, :] *= scale
7529
                    test_single_det(M_clone, target, 'scale a column')
7530

7531
            for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
7532
                assert x1 != x2, 'x1 and x2 need to be different for this test'
7533
                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7534
                # dim 0
7535
                M_clone = M.clone()
7536
                M_clone[:, x2] = M_clone[:, x1]
7537
                test_single_det(M_clone, target, 'two rows are same')
7538
                # dim 1
7539
                M_clone = M.clone()
7540
                M_clone[x2, :] = M_clone[x1, :]
7541
                test_single_det(M_clone, target, 'two columns are same')
7542

7543
                for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
7544
                    det_scale = scale1 * scale2 * -1
7545
                    if det_scale > 0:
7546
                        target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
7547
                    elif det_scale == 0:
7548
                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
7549
                    else:
7550
                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
7551

7552
                    # dim 0
7553
                    M_clone = M.clone()
7554
                    t = M_clone[:, x1] * scale1
7555
                    M_clone[:, x1] += M_clone[:, x2] * scale2
7556
                    M_clone[:, x2] = t
7557
                    test_single_det(M_clone, target, 'exchanging rows')
7558
                    # dim 1
7559
                    M_clone = M.clone()
7560
                    t = M_clone[x1, :] * scale1
7561
                    M_clone[x1, :] += M_clone[x2, :] * scale2
7562
                    M_clone[x2, :] = t
7563
                    test_single_det(M_clone, target, 'exchanging columns')
7564

7565
        def get_random_mat_scale(n):
7566
            # For matrices with values i.i.d. with 0 mean, unit variance, and
7567
            # subexponential tail, we have:
7568
            #   E[log det(A^2)] \approx log((n-1)!)
7569
            #
7570
            # Notice:
7571
            #   log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)]
7572
            #
7573
            # So:
7574
            #   stddev[det(A)] >= sqrt( (n-1)! )
7575
            #
7576
            # We use this as an intuitive guideline to scale random generated
7577
            # matrices so our closeness tests can work more robustly:
7578
            #   scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
7579
            #
7580
            # source: https://arxiv.org/pdf/1112.0752.pdf
7581

7582
            # TODO: technically we need subexponential distn for this to hold,
7583
            #       but we mostly use gaussian entries below. Consider switching
7584
            #       to Chi-sq if this turns out not stable enough, since Chi-sq
7585
            #       is easy enough to sample from.
7586
            return math.factorial(n - 1) ** (-1.0 / (2 * n))
7587

7588
        for n in [5, 10, 25]:
7589
            scale = get_random_mat_scale(n)
7590
            test(torch.randn(n, n, dtype=dtype, device=device) * scale)
7591
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7592
            # symmetric psd
7593
            test(r.mm(r.t()))
7594
            # symmetric pd
7595
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7596
            test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6)
7597
            # symmetric
7598
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7599
            for i in range(n):
7600
                for j in range(i):
7601
                    r[i, j] = r[j, i]
7602
            test(r)
7603
            # non-contiguous
7604
            test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:])
7605
            # det = 0
7606
            r = torch.randn(n, n, dtype=dtype, device=device) * scale
7607
            u, s, v = r.svd()
7608
            if reference_slogdet(u)[0] < 0:
7609
                u = -u
7610
            if reference_slogdet(v)[0] < 0:
7611
                v = -v
7612
            s[0] *= -1
7613
            s[-1] = 0
7614
            test(u.mm(s.diag()).mm(v))
7615

7616
        # Small values to test numerical stability. Note that we don't scale
7617
        # this matrix.
7618
        r = torch.randn(512, 512, dtype=dtype, device=device)
7619
        u, s, v = r.svd()
7620
        s.fill_(1. / (100 * s.numel()))
7621
        test(u.mm(s.diag()).mm(v))
7622

7623
    @skipCUDAIfNoMagma
7624
    @skipCPUIfNoLapack
7625
    @dtypes(torch.double)
7626
    def test_det_logdet_slogdet_batched(self, device, dtype):
7627
        from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
7628
                                                          random_symmetric_pd_matrix, random_square_matrix_of_rank)
7629

7630
        # mat_chars denotes matrix characteristics
7631
        # possible values are: sym, sym_psd, sym_pd, sing, non_sing
7632
        def run_test(matsize, batchdims, mat_chars):
7633
            num_matrices = reduce(operator.mul, batchdims, 1)
7634
            list_of_matrices = []
7635

7636
            for idx in range(num_matrices):
7637
                mat_type = idx % len(mat_chars)
7638
                if mat_chars[mat_type] == 'sym':
7639
                    list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device))
7640
                elif mat_chars[mat_type] == 'sym_psd':
7641
                    list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device))
7642
                elif mat_chars[mat_type] == 'sym_pd':
7643
                    list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device))
7644
                elif mat_chars[mat_type] == 'sing':
7645
                    list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
7646
                elif mat_chars[mat_type] == 'non_sing':
7647
                    list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
7648
            full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
7649
            # Scaling adapted from `get_random_mat_scale` in test_det_logdet_slogdet
7650
            full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize)))
7651

7652
            for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
7653
                expected_value = []
7654
                actual_value = fn(full_tensor)
7655
                for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
7656
                    expected_value.append(fn(full_tensor[full_idx]))
7657

7658
                if fn == torch.slogdet or fn == torch.linalg.slogdet:
7659
                    sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
7660
                    expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
7661
                    self.assertEqual(sign_value, actual_value[0])
7662
                    self.assertEqual(expected_value, actual_value[1])
7663
                else:
7664
                    expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
7665
                    self.assertEqual(actual_value, expected_value)
7666

7667
        for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]):
7668
            run_test(matsize, batchdims, mat_chars=['sym_pd'])
7669
            run_test(matsize, batchdims, mat_chars=['sing'])
7670
            run_test(matsize, batchdims, mat_chars=['non_sing'])
7671
            run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
7672
            run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])
7673

7674
    @skipCUDAIfNoMagma
7675
    @skipCPUIfNoLapack
7676
    @dtypes(*floating_and_complex_types())
7677
    def test_cholesky_inverse(self, device, dtype):
7678
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
7679

7680
        def run_test(shape, batch, upper, contiguous):
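            # cholesky_inverse reconstructs A^-1 from a Cholesky factor, so the result is
            # compared against torch.inverse(A); `upper` selects whether the lower factor L
            # or its conjugate transpose (the upper factor) is passed in.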
7681
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
7682
            if A.numel() > 0 and not contiguous:
7683
                A = A.mT
7684
                self.assertFalse(A.is_contiguous())
7685
            L = torch.linalg.cholesky(A)
7686
            expected_inverse = torch.inverse(A)
7687
            L = L.mH if upper else L
7688
            actual_inverse = torch.cholesky_inverse(L, upper)
7689
            self.assertEqual(actual_inverse, expected_inverse)
7690

7691
        shapes = (0, 3, 5)
7692
        batches = ((), (0,), (3, ), (2, 2))
7693
        for shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))):
7694
            run_test(shape, batch, upper, contiguous)
7695

7696
        # check the out= variant
7697
        A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
7698
        L = torch.linalg.cholesky(A)
7699

7700
        # There are two code paths currently for the out= variant
7701
        # 1. When 'out' tensor is in Fortran (column-major) memory format
7702
        # then the fast route is taken and the storage is reused directly in the computations
7703
        # 2. When 'out' tensor is not in Fortran format then a temporary tensor is allocated internally
7704
        # and the result is copied from the temporary tensor to 'out' tensor
7705

7706
        # This test checks the first code path
7707
        out = torch.empty_like(A)
7708
        out_t = out.mT.clone(memory_format=torch.contiguous_format)
7709
        out = out_t.mT
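        # (The .mT / .clone(contiguous) / .mT sequence above should leave `out` with
        # column-major strides over out_t's storage, which is what exercises the
        # Fortran fast path described in the comment above.)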
        ans = torch.cholesky_inverse(L, out=out)
        self.assertEqual(ans, out)
        expected = torch.inverse(A)
        self.assertEqual(expected, out)

        # This test checks the second code path
        out = torch.empty_like(A)
        ans = torch.cholesky_inverse(L, out=out)
        self.assertEqual(ans, out)
        expected = torch.inverse(A)
        self.assertEqual(expected, out)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
        # cholesky_inverse requires the input to be at least 2 dimensional tensor
        a = torch.randn(2, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
            torch.cholesky_inverse(a)

        # cholesky_inverse requires a square matrix
        a = torch.randn(2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
            torch.cholesky_inverse(a)

        # if non-empty out tensor with wrong shape is passed a warning is given
        a = torch.randn(3, 3, device=device, dtype=dtype)
        out = torch.empty(2, 3, device=device, dtype=dtype)
        with warnings.catch_warnings(record=True) as w:
            # Trigger warning
            torch.cholesky_inverse(a, out=out)
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

        # dtypes should be safely castable
        out = torch.empty(*a.shape, dtype=torch.int, device=device)
        with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
            torch.cholesky_inverse(a, out=out)

        # device should match
        if torch.cuda.is_available():
            wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
            out = torch.empty(0, device=wrong_device, dtype=dtype)
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.cholesky_inverse(a, out=out)

        # cholesky_inverse raises an error for invalid inputs on CPU
        # for example if at least one diagonal element is zero
        a = torch.randn(3, 3, device=device, dtype=dtype)
        a[1, 1] = 0
        if self.device_type == 'cpu':
            with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
                torch.cholesky_inverse(a)
        # cholesky_inverse on GPU does not raise an error for this case
        elif self.device_type == 'cuda':
            out = torch.cholesky_inverse(a)
            self.assertTrue(out.isinf().any() or out.isnan().any())

    def _select_broadcastable_dims(self, dims_full=None):
        # select full dimensionality
        if dims_full is None:
            dims_full = []
            ndims = random.randint(1, 4)
            dims_full = [random.randint(1, 8) for _ in range(ndims)]
        else:
            ndims = len(dims_full)

        # select actual dimensions for ops:
        # larger: full ndims, individual sizes may be reduced
        # smaller: possibly reduced ndims, sizes may be reduced
        smaller_ndims = random.randint(1, ndims)
        dims_small = []
        dims_large = []
        for i in range(ndims - 1, -1, -1):
            j = random.randint(1, 3)
            if j == 1:  # no reduced singleton dimension
                ds = dims_full[i]
                dl = dims_full[i]
            elif j == 2:  # larger may have reduced singleton dimension
                ds = dims_full[i]
                dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
            elif j == 3:  # smaller may have reduced singleton dimension
                ds = 1
                dl = dims_full[i]
            dims_large = [dl] + dims_large
            if len(dims_small) < smaller_ndims:
                dims_small = [ds] + dims_small
        return (dims_small, dims_large, dims_full)
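        # One possible draw for dims_full=[2, 3, 4] (the output is random): smaller_ndims=2,
        # dims_small=[3, 1], dims_large=[2, 3, 4] -- dims_small keeps only trailing dims and
        # may turn some of them into broadcastable singletons, while dims_large keeps full rank.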

    def test_broadcast_fused_matmul(self, device):
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                if fn == "baddbmm":
                    return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addbmm":
                    return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)

    @tf32_on_and_off(0.001)
    @bf32_on_and_off(0.001)
    def test_broadcast_batched_matmul(self, device):
        n_dim = random.randint(1, 8)
        m_dim = random.randint(1, 8)
        p_dim = random.randint(1, 8)
        full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
        (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)

        def verify_batched_matmul(full_lhs, one_dimensional):
            if not one_dimensional:
                lhs_dims = [n_dim, m_dim]
                rhs_dims = [m_dim, p_dim]
                result_dims = [n_dim, p_dim]
            else:
                lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
                rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
                result_dims = [n_dim] if full_lhs else [p_dim]

            lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
            rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
            full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
            dim0_dims = rhs_dims if full_lhs else lhs_dims
            small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)

            small = torch.randn(*(small_dims), device=device).float()
            dim0 = torch.randn(*(dim0_dims), device=device).float()
            full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float()
            if not one_dimensional:
                (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
            else:
                (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))

            def maybe_squeeze_result(l, r, result):
                if len(lhs_dims) == 1 and l.dim() != 1:
                    return result.squeeze(-2)
                elif len(rhs_dims) == 1 and r.dim() != 1:
                    return result.squeeze(-1)
                else:
                    return result

            for lhs in lhsTensors:
                lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
                lhs_expanded_matmul_fn = lhs_expanded.matmul
                for rhs in rhsTensors:
                    rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
                                    expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
                    truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
                    for l in (lhs, lhs_expanded):
                        for r in (rhs, rhs_expanded):
                            l_matmul_fn = l.matmul
                            result = maybe_squeeze_result(l, r, l_matmul_fn(r))
                            self.assertEqual(truth, result)
                            # test torch.matmul function as well
                            torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
                            self.assertEqual(truth, torch_result)
                            # test torch.matmul with out
                            out = torch.zeros_like(torch_result)
                            torch.matmul(l, r, out=out)
                            self.assertEqual(truth, maybe_squeeze_result(l, r, out))

                # compare to bmm
                bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
                                        rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
                self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))

        for indices in itertools.product((True, False), repeat=2):
            verify_batched_matmul(*indices)

    def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        b = torch.randn(*b_dims, dtype=dtype, device=device)
        A = make_A(*A_dims)
        LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
        self.assertEqual(info, torch.zeros_like(info))
        return b, A, LU_data, LU_pivots
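        # LU_data/LU_pivots are the packed factorization produced by torch.linalg.lu_factor_ex;
        # info == 0 everywhere means the factorization succeeded, and the tests below feed the
        # pair straight into torch.lu_solve.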

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_lu_solve(self, device, dtype):
        def sub_test(pivot):
            for k, n in zip([2, 3, 5], [3, 5, 7]):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
                x = torch.lu_solve(b, LU_data, LU_pivots)
                self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
                        torch.float64: 1e-8, torch.complex128: 1e-8})
    def test_lu_solve_batched(self, device, dtype):
        def sub_test(pivot):
            def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
                b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
                x_exp_list = []
                for i in range(b_dims[0]):
                    x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
                x_exp = torch.stack(x_exp_list)  # Stacked output
                x_act = torch.lu_solve(b, LU_data, LU_pivots)  # Actual output
                self.assertEqual(x_exp, x_act)  # Equality check
                Ax = np.matmul(A.cpu(), x_act.cpu())
                self.assertEqual(b, Ax)

            for batchsize in [1, 3, 4]:
                lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)

        # Tests tensors with 0 elements
        b = torch.randn(3, 0, 3, dtype=dtype, device=device)
        A = torch.randn(3, 0, 0, dtype=dtype, device=device)
        LU_data, LU_pivots = torch.linalg.lu_factor(A)
        self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))

        sub_test(True)
        if self.device_type == 'cuda':
            sub_test(False)

    @slowTest
    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    def test_lu_solve_batched_many_batches(self, device, dtype):
        def run_test(A_dims, b_dims):
            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            Ax = torch.matmul(A, x)
            self.assertEqual(Ax, b.expand_as(Ax))

        run_test((65536, 5, 5), (65536, 5, 10))
        run_test((262144, 5, 5), (262144, 5, 10))

    @skipCPUIfNoLapack
    @skipCUDAIfNoMagmaAndNoCusolver
    @dtypes(*floating_and_complex_types())
    def test_lu_solve_batched_broadcasting(self, device, dtype):
        make_fullrank = make_fullrank_matrices_with_distinct_singular_values
        make_A = partial(make_fullrank, device=device, dtype=dtype)

        def run_test(A_dims, b_dims, pivot=True):
            A_matrix_size = A_dims[-1]
            A_batch_dims = A_dims[:-2]
            A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
            b = make_tensor(b_dims, dtype=dtype, device=device)
            x_exp = np.linalg.solve(A.cpu(), b.cpu())
            LU_data, LU_pivots = torch.linalg.lu_factor(A)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            self.assertEqual(x, x_exp)

        # test against numpy.linalg.solve
        run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))  # no broadcasting
        run_test((2, 1, 3, 4, 4), (4, 6))  # broadcasting b
        run_test((4, 4), (2, 1, 3, 4, 2))  # broadcasting A
        run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))  # broadcasting A & b

    @onlyCUDA
    @skipCUDAIfNoMagma
    @dtypes(*floating_and_complex_types())
    # this tests https://github.com/pytorch/pytorch/issues/36921
    def test_lu_solve_large_matrices(self, device, dtype):
        def run_test(A_dims, b_dims):
            b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
            x = torch.lu_solve(b, LU_data, LU_pivots)
            Ax = torch.matmul(A, x)
            self.assertEqual(Ax, b.expand_as(Ax))

        run_test((1, 1), (1, 1, 1025))

    @skipCUDAIfNoCusolver
    @skipCPUIfNoLapack
    def test_pca_lowrank(self, device):
        from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix

        dtype = torch.double

        def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options):
            density = options.pop('density', 1)
            use_svd_lowrank = options.pop('use_svd_lowrank', False)
            if isinstance(matrix_size, int):
                rows = columns = matrix_size
            else:
                rows, columns = matrix_size
            if density == 1:
                a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
                a = a_input
            else:
                a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
                a = a_input.to_dense()

            if use_svd_lowrank:
                m = a_input.mean(dim=-2, keepdim=True)
                u, s, v = pca(a_input, q=guess_rank, M=m, **options)
            else:
                u, s, v = pca(a_input, q=guess_rank, **options)

            self.assertEqual(s.shape[-1], guess_rank)
            self.assertEqual(u.shape[-2], rows)
            self.assertEqual(u.shape[-1], guess_rank)
            self.assertEqual(v.shape[-1], guess_rank)
            self.assertEqual(v.shape[-2], columns)

            A1 = u.matmul(s.diag_embed()).matmul(v.mT)
            ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device)
            c = a.sum(axis=-2) / rows
            c = c.reshape(batches + (1, columns))
            A2 = a - ones_m1.matmul(c)
            self.assertEqual(A1, A2)

            if density == 1:
                # actual rank is known only for dense input
                detect_rank = (s.abs() > 1e-5).sum(axis=-1)
                self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
                S = torch.linalg.svdvals(A2)
                self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])

        all_batches = [(), (1,), (3,), (2, 3)]
        for actual_rank, size, all_batches in [  # noqa: B020
                (2, (17, 4), all_batches),
                (2, (100, 4), all_batches),
                (6, (100, 40), all_batches),
                (12, (1000, 1000), [()]),
        ]:
            for batches in all_batches:
                for guess_rank in [
                        actual_rank,
                        actual_rank + 2,
                        actual_rank + 6,
                ]:
                    if guess_rank <= min(*size):
                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank)
                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank)
                        run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True)
                        run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True)

        # sparse input
        for guess_rank, size in [
                (4, (17, 4)), (4, (4, 17)), (16, (17, 17)),
                (21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]:
            for density in [0.005, 0.1]:
                run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density)

        # jitting support
        jitted = torch.jit.script(torch.pca_lowrank)
        guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
        run_subtest(guess_rank, actual_rank, size, batches, device, jitted)

    # Ensure that nuclear_norm's out variant gives the same result as the non-out
    @onlyNativeDeviceTypes
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float32, torch.float64)
    def test_nuclear_norm_out(self, device, dtype):
        test_cases = [
            # input size, dim
            ((25, 25), None),
            ((25, 25), (0, 1)),
            ((25, 25), (1, 0)),
            ((25, 25, 25), (2, 0)),
            ((25, 25, 25), (0, 1)),
        ]
        for keepdim in [False, True]:
            for input_size, dim in test_cases:
                msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}'
                x = torch.randn(*input_size, device=device, dtype=dtype)
                result_out = torch.empty(0, device=device, dtype=dtype)
                if dim is None:
                    result = torch.nuclear_norm(x, keepdim=keepdim)
                    torch.nuclear_norm(x, keepdim=keepdim, out=result_out)
                else:
                    result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim)
                    torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out)
                self.assertEqual(result, result_out, msg=msg)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(*floating_and_complex_types())
    def test_geqrf(self, device, dtype):

        def run_test(shape):
            # numpy.linalg.qr with mode = 'raw' computes the same operation as torch.geqrf
            # so this test compares against that function
            A = make_tensor(shape, dtype=dtype, device=device)

            # numpy.linalg.qr doesn't work with batched input
            m, n = A.shape[-2:]
            tau_size = "n" if m > n else "m"
            np_dtype = A.cpu().numpy().dtype
            ot = [np_dtype, np_dtype]
            numpy_geqrf_batched = np.vectorize(
                lambda x: np.linalg.qr(x, mode='raw'),
                otypes=ot,
                signature=f'(m,n)->(n,m),({tau_size})')
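            # (np.vectorize with a gufunc-style signature loops the single-matrix
            # np.linalg.qr(mode='raw') call over any leading batch dims; mode='raw'
            # returns the LAPACK geqrf layout, i.e. the transposed packed matrix plus tau.)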

            expected = numpy_geqrf_batched(A.cpu())
            actual = torch.geqrf(A)

            # numpy.linalg.qr returns transposed result
            self.assertEqual(expected[0].swapaxes(-2, -1), actual[0])
            self.assertEqual(expected[1], actual[1])

        batches = [(), (0, ), (2, ), (2, 1)]
        ns = [5, 2, 0]
        for batch, (m, n) in product(batches, product(ns, ns)):
            run_test((*batch, m, n))

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_lapack_empty(self, device):
        # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
        # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although
        # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing"
        # (e.g. lu).  We often name our functions identically to the lapack function, so it will take work
        # to name / migrate-to better wrappers.
        def fn(torchfn, *args):
            return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
                                  for shape in args))

        # inverse, pinverse
        self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
        self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
        self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
        self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)

        # det, logdet, slogdet
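        # (the determinant of a 0x0 matrix is the empty product, i.e. 1, so logdet is 0
        # and slogdet returns sign 1 with logabsdet 0, as asserted below)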
        self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
        self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
                         fn(torch.slogdet, (0, 0)))

    @tf32_on_and_off(0.005)
    @bf32_on_and_off(0.005)
    def test_tensordot(self, device):
        a = torch.arange(60., device=device).reshape(3, 4, 5)
        b = torch.arange(24., device=device).reshape(4, 3, 2)
        c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                           axes=([1, 0], [0, 1])))
        self.assertEqual(c, cn)

        cout = torch.zeros((5, 2), device=device)
        torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
        self.assertEqual(c, cout)

        a = torch.randn(2, 3, 4, 5, device=device)
        b = torch.randn(4, 5, 6, 7, device=device)
        c = torch.tensordot(a, b, dims=2).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
                                           axes=2))

        with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
            torch.tensordot(a, b, dims=-1)

        self.assertEqual(c, cn)
        c = torch.tensordot(a, b).cpu()
        cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
        self.assertEqual(c, cn)

        a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
        an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
        self.assertEqual(a, an)

    @skipCUDAIfNoCusolver
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipIfTorchDynamo("flaky, needs investigation")
    @dtypes(*floating_and_complex_types())
    def test_ldl_factor(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, hermitian):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
            actual_L = torch.tril(actual_factors, diagonal=-1)
            actual_L.diagonal(0, -2, -1).fill_(1.0)

            # This test is designed only for inputs with 1x1 block diagonal matrix D.
            # That is for positive definite input matrices, the pivots tensor is always > 0.
            # If negative pivots are encountered, it means that the input matrix is not positive definite.
            # And matrix D is a 2x2 block diagonal matrix.
            self.assertTrue((actual_pivots > 0).all())

            # Construct a 1x1 block diagonal matrix D from factors.
            actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1))

            def T(x):
                return x.mH if hermitian else x.mT
            A_reconstructed = actual_L @ actual_D @ T(actual_L)

            def symmetric(A):
                return A.tril() + A.tril(-1).mT

            self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed)

            # Now test against SciPy implementation
            if TEST_SCIPY:
                from scipy.linalg import ldl as scipy_ldl
                A_np = A.cpu().numpy()
                np_dtype = A_np.dtype
                scipy_ldl_batched = np.vectorize(
                    lambda x: scipy_ldl(x, hermitian=hermitian, lower=True),
                    otypes=[np_dtype, np_dtype, np.dtype('int64')],
                    signature='(m,m)->(m,m),(m,m),(m)')

                expected = scipy_ldl_batched(A_np)
                expected_L, expected_D, expected_pivots = expected

                if expected_pivots.ndim > 1:
                    permuted_expected_L = np.stack(
                        [expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])]
                    )
                else:
                    permuted_expected_L = expected_L[expected_pivots, :]
                self.assertEqual(actual_L, permuted_expected_L)
                self.assertEqual(actual_D, expected_D)
            else:
                self.assertEqual(actual_factors.shape, A.shape)
                self.assertEqual(actual_pivots.shape, A.shape[:-1])
                self.assertEqual(info.shape, A.shape[:-2])

        # hermitian=True for complex inputs on CUDA is supported only with MAGMA 2.5.4+
        magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4)
        hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,)

        shapes = (5,)
        batches = ((), (4,),)
        for shape, batch, hermitian in itertools.product(shapes, batches, hermitians):
            run_test(shape, batch, hermitian)

    @skipCUDAIfNoCusolver
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @skipCUDAIfRocm
    @skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1")
    @dtypes(*floating_and_complex_types())
    def test_ldl_solve(self, device, dtype):
        from torch.testing._internal.common_utils import random_hermitian_pd_matrix

        def run_test(shape, batch, nrhs, hermitian):
            A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
            B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device)
            factors, pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
            X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian)

            def symmetric(A):
                return A.tril() + A.tril(-1).mT

            # verify A @ X == B
            expected_B = symmetric(A) @ X if not hermitian else A @ X
            self.assertEqual(B, expected_B)

        # hermitian=True is not supported on CUDA yet
        hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,)

        shapes = (5,)
        batches = ((), (4,), (2, 2))
        nrhss = (1, 7)
        for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians):
            run_test(shape, batch, nrhs, hermitian)

    @onlyCUDA
    @skipCUDAIfNoMagma
    @skipCUDAIfNoCusolver
    @setLinalgBackendsToDefaultFinally
    def test_preferred_linalg_library(self):
        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
        x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double)

        torch.backends.cuda.preferred_linalg_library('cusolver')
        out1 = torch.linalg.inv(x)

        torch.backends.cuda.preferred_linalg_library('magma')
        out2 = torch.linalg.inv(x)

        torch.backends.cuda.preferred_linalg_library('default')
        # Although linalg preferred flags doesn't affect CPU currently,
        # we set this to make sure the flag can switch back to default normally.
        out_ref = torch.linalg.inv(x.cpu())

        self.assertEqual(out_ref, out1.cpu())
        self.assertEqual(out1, out2)

    @onlyCUDA
    @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
    @setBlasBackendsToDefaultFinally
    def test_preferred_blas_library(self):
        # The main purpose of this test is to make sure these "backend" calls work normally without raising exceptions.
        m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
        m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)

        torch.backends.cuda.preferred_blas_library('cublaslt')
        out1 = torch.nn.functional.linear(m1, m2)

        torch.backends.cuda.preferred_blas_library('cublas')
        out2 = torch.nn.functional.linear(m1, m2)

        # Although blas preferred flags doesn't affect CPU currently,
        # we set this to make sure the flag can switch back to default normally.
        out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())

        self.assertEqual(out1, out2)
        self.assertEqual(out_ref, out2.cpu())

    def test_permute_matmul(self):
        a = torch.ones([2, 5, 24, 24])
        b = torch.ones([3, 2, 5, 24, 24])
        c = a.permute(0, 1, 3, 2).matmul(b)
        self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
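        # The expected values above follow from the all-ones inputs: each 24x24 ones-matrix
        # product has every entry equal to 24, and the broadcast result has shape
        # (3, 2, 5, 24, 24), so the sum is 24 * 3*2*5*24*24 = 414720.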

    def test_lower_precision_accumulation_with_ref_path(self):
        # fix https://github.com/pytorch/pytorch/issues/95125
        # and https://github.com/pytorch/pytorch/issues/83863
        # for bf16 accumulation in gemm ref path
        def check_correctness(fn, dtype, *args):
            expected = fn(*args).to(dtype=dtype)
            with torch.backends.mkldnn.flags(enabled=False):
                def test():
                    lower_args = (arg.to(dtype=dtype) for arg in args)
                    tmp_result = fn(*lower_args)
                    return tmp_result
                c = test()
                assert (torch.all(c == expected)), "Incorrect result with\n" \
                                                   f"expected: {expected}\n" \
                                                   f"got: {c}\n"
        # test matmul
        for dtype in [torch.bfloat16, torch.half]:
            for transa in [True, False]:
                for transb in [True, False]:
                    a = torch.ones(300, 300)
                    b = torch.ones(300, 300)
                    if transa:
                        a = a.transpose(0, 1).contiguous().transpose(0, 1)
                    if transb:
                        b = b.transpose(0, 1).contiguous().transpose(0, 1)
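                    # (the transpose/contiguous/transpose dance should leave the tensor with
                    # swapped, column-major strides, so the transposed-input GEMM paths are
                    # exercised as well)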
                    check_correctness(torch.matmul, dtype, a, b)
        # test bmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        check_correctness(torch.bmm, torch.bfloat16, a, b)
        check_correctness(torch.bmm, torch.half, a, b)
        # test baddbmm
        a = torch.ones(1, 1, 300)
        b = torch.ones(1, 300, 1)
        c = torch.ones(1, 1, 1)
        check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
        check_correctness(torch.baddbmm, torch.half, c, a, b)
        # test mv/addmv
        for dtype in [torch.bfloat16, torch.half]:
            for trans in [True, False]:
                c = torch.ones(300) * -300
                a = torch.ones(300, 300)
                if trans:
                    a = a.transpose(0, 1).contiguous().transpose(0, 1)
                b = torch.ones(300)
                check_correctness(torch.mv, dtype, a, b)
                check_correctness(torch.addmv, dtype, c, a, b)
        # test dot
        a = torch.ones(300)
        b = torch.ones(300)
        check_correctness(torch.dot, torch.bfloat16, a, b)
        check_correctness(torch.dot, torch.half, a, b)

    @dtypes(torch.float, torch.double)
    @precisionOverride({torch.float32: 1e-4})
    def test_1_sized_with_0_strided(self, device, dtype):
        a = make_tensor((8, 1, 64), dtype=dtype, device=device)
        a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
        b = make_tensor((8, 64, 512), dtype=dtype, device=device)
        b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
        res = torch.bmm(a_strided, b_strided)
        expect = torch.from_numpy(
            a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
        self.assertEqual(expect, res)

instantiate_device_type_tests(TestLinalg, globals())

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()
