import itertools
import random
import unittest
import warnings

import torch
import numpy as np

from math import inf, nan, isnan
from random import randrange
from itertools import product
from functools import reduce, partial

from torch.testing._internal.common_utils import \
    (TestCase, run_tests, TEST_SCIPY, IS_MACOS, IS_WINDOWS, slowTest,
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
     onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
     skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA,
     onlyCUDA, skipCUDAVersionIn, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm,
     dtypesIfMPS, largeTensorTest)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
    floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
    _get_torch_cuda_version, CDNA2OrLater
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
assert torch.get_default_dtype() is torch.float32


def blaslt_supported_device():
    if torch.cuda.is_available():
        # On ROCm, BLASLt support is limited to these GPU architectures.
        for arch in ['gfx90a', 'gfx94']:
            if arch in torch.cuda.get_device_properties(0).gcnArchName:
                return True
    return False


def set_tunableop_defaults():
    if not torch.cuda.is_available():
        # TunableOp requires CUDA/ROCm; nothing to reset otherwise.
        return

    # Reset any TunableOp state that a previous test may have left behind.
    ordinal = torch.cuda.current_device()
    filename = f"tunableop_results{ordinal}.csv"
    torch.cuda.tunable.enable(False)
    torch.cuda.tunable.tuning_enable(True)
    torch.cuda.tunable.set_filename(filename)
    torch.cuda.tunable.set_max_tuning_duration(30)
    torch.cuda.tunable.set_max_tuning_iterations(100)
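# A typical TunableOp test is expected to call set_tunableop_defaults() first so that
# tuning state (results file name, duration and iteration limits) left by a previous
# test cannot leak into it.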
class TestLinalg(TestCase):
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    @dtypes(torch.float, torch.cfloat)
    @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06})
    @tf32_on_and_off(5e-3)
    @bf32_on_and_off(5e-3)
    def test_inner(self, device, dtype):
        def check(a_sizes_, b_sizes_):
            for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)):
                a = torch.randn(a_sizes, dtype=dtype, device=device)
                b = torch.randn(b_sizes, dtype=dtype, device=device)
                res = torch.inner(a, b)
                ref = np.inner(a.cpu().numpy(), b.cpu().numpy())
                self.assertEqual(res.cpu(), torch.from_numpy(np.array(ref)))
                out = torch.zeros_like(res)
                torch.inner(a, b, out=out)
                self.assertEqual(res, out)

        check([2], [3, 1, 2])
        check([2], [3, 0, 2])
        check([1, 2], [3, 2])
        check([1, 2], [3, 4, 2])
        check([2, 1, 3, 2], [1, 3, 2, 2])

        with self.assertRaisesRegex(RuntimeError,
                                    r"inner\(\) the last dimension must match on both "
                                    r"input tensors but got shapes \[2, 3\] and \[2, 2\]"):
            torch.randn(2, 3, device=device, dtype=dtype).inner(torch.randn(2, 2, device=device, dtype=dtype))
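    # Reminder: torch.inner contracts only the last dimension of each operand, e.g.
    # inner(a, b) with a.shape == (1, 2) and b.shape == (3, 4, 2) has shape (1, 3, 4).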
    @precisionOverride({torch.bfloat16: 1e-1})
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_outer(self, device, dtype):
        def run_test_case(a, b):
            if dtype == torch.bfloat16:
                # NumPy has no bfloat16, so compare against a double-precision reference.
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()

            expected = np.outer(a_np, b_np)

            self.assertEqual(torch.outer(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.outer(a, b), expected, exact_dtype=False)

            self.assertEqual(torch.ger(a, b), expected, exact_dtype=False)
            self.assertEqual(torch.Tensor.ger(a, b), expected, exact_dtype=False)

            # test out variant
            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.outer(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

            out = torch.empty(a.size(0), b.size(0), device=device, dtype=dtype)
            torch.ger(a, b, out=out)
            self.assertEqual(out, expected, exact_dtype=False)

        a = torch.randn(50).to(device=device, dtype=dtype)
        b = torch.randn(50).to(device=device, dtype=dtype)
        run_test_case(a, b)

        # test 0-strided tensors
        zero_strided = torch.randn(1).to(device=device, dtype=dtype).expand(50)
        run_test_case(zero_strided, b)
        run_test_case(a, zero_strided)
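    # The bfloat16 cases above are compared against a float64 NumPy reference
    # (NumPy has no bfloat16), which is why exact_dtype=False is used throughout.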
    def test_matrix_rank_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.matrix_rank(a)

    def test_solve_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        b = make_tensor(5, 1, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.solve(b, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            b.solve(a)

    def test_eig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.eig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.eig()

    def test_symeig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.symeig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.symeig()

    def test_lstsq_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.lstsq(a, a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.lstsq(a)
197
@skipIfTorchDynamo("flaky, needs investigation")
198
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
199
def test_linalg_lstsq(self, device, dtype):
200
from torch.testing._internal.common_utils import random_well_conditioned_matrix
201
if self.device_type == 'cpu':
202
drivers = ('gels', 'gelsy', 'gelsd', 'gelss', None)
204
drivers = ('gels', None)
206
def check_solution_correctness(a, b, sol):
207
sol2 = a.pinverse() @ b
208
self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
210
def check_correctness_ref(a, b, res, ref, driver="default"):
211
def apply_if_not_empty(t, f):
217
def select_if_not_empty(t, i):
218
selected = apply_if_not_empty(t, lambda x: x.select(0, i))
224
batch_size = int(np.prod(a.shape[:-2]))
227
a_3d = a.view(batch_size, m, n)
228
b_3d = b.view(batch_size, m, nrhs)
230
solution_3d = res.solution.view(batch_size, n, nrhs)
231
residuals_2d = apply_if_not_empty(res.residuals, lambda t: t.view(-1, nrhs))
232
rank_1d = apply_if_not_empty(res.rank, lambda t: t.view(-1))
233
singular_values_2d = res.singular_values.view(batch_size, res.singular_values.shape[-1])
236
for i in range(batch_size):
237
sol, residuals, rank, singular_values = ref(
238
a_3d.select(0, i).numpy(),
239
b_3d.select(0, i).numpy()
242
if singular_values is None:
244
self.assertEqual(sol, solution_3d.select(0, i), atol=1e-5, rtol=1e-5)
245
self.assertEqual(rank, select_if_not_empty(rank_1d, i), atol=1e-5, rtol=1e-5)
246
self.assertEqual(singular_values, singular_values_2d.select(0, i), atol=1e-5, rtol=1e-5)
255
if torch.all(rank_1d == n):
257
residuals, select_if_not_empty(residuals_2d, i), atol=1e-5, rtol=1e-5, exact_dtype=False
260
self.assertTrue(residuals_2d.numel() == 0)
263
self.assertEqual(res.solution.shape, (*a.shape[:-2], n, nrhs))
264
self.assertEqual(res.rank.shape, a.shape[:-2])
267
if m > n and driver != "gelsy":
268
self.assertEqual(res.residuals.shape, (*a.shape[:-2], 0))
270
self.assertEqual(res.residuals.shape, (0, ))
273
if driver == "default" or driver == "gelsd" or driver == "gelss":
274
self.assertEqual(res.singular_values.shape, (*a.shape[:-2], min(m, n)))
276
self.assertEqual(res.singular_values.shape, (0, ))
278
        def check_correctness_scipy(a, b, res, driver, cond):
            # SciPy exposes the 'gelsd', 'gelss' and 'gelsy' LAPACK drivers
            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                import scipy.linalg

                def scipy_ref(a, b):
                    return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                check_correctness_ref(a, b, res, scipy_ref, driver=driver)
287
        def check_correctness_numpy(a, b, res, driver, rcond):
            # NumPy only exposes the 'gelsd' driver
            if driver == 'gelsd':
                def numpy_ref(a, b):
                    return np.linalg.lstsq(a, b, rcond=rcond)
                check_correctness_ref(a, b, res, numpy_ref)
295
ms = [2 ** i for i in range(5)]
296
m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
298
m_l_n_sizes = [(m // 2, m) for m in ms]
299
include_m_l_n_case = (has_cusolver() or device == 'cpu')
300
matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if include_m_l_n_case else [])
301
batches = [(), (2,), (2, 2), (2, 2, 2)]
308
rconds = (None, True, -1)
310
for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
312
if rcond and rcond != -1:
313
if driver in ('gelss', 'gelsd'):
323
if driver == 'gels' and rcond is not None:
326
shape = batch + matrix_size
327
a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
328
b = torch.rand(*shape, dtype=dtype, device=device)
332
res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
336
check_correctness_scipy(a, b, res, driver, rcond)
339
check_correctness_numpy(a, b, res, driver, rcond)
343
if driver == 'gels' and rcond is None:
344
                check_solution_correctness(a, b, res.solution)
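    # Shape contract for torch.linalg.lstsq: for A of shape (*, m, n) and B of shape
    # (*, m, k) the solution has shape (*, n, k); residuals are only populated for
    # overdetermined (m > n) full-rank systems.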
348
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
349
def test_linalg_lstsq_batch_broadcasting(self, device, dtype):
350
from torch.testing._internal.common_utils import random_well_conditioned_matrix
352
def check_correctness(a, b):
353
sol = torch.linalg.lstsq(a, b).solution
354
sol2 = a.pinverse() @ b
355
self.assertEqual(sol, sol2, rtol=1e-5, atol=1e-5)
357
ms = [2 ** i for i in range(5)]
358
batches = [(), (0,), (2,), (2, 2), (2, 2, 2)]
360
for m, batch in itertools.product(ms, batches):
361
a = random_well_conditioned_matrix(m, m, dtype=dtype, device=device).view(*([1] * len(batch)), m, m)
362
b = torch.rand(*(batch + (m, m)), dtype=dtype, device=device)
363
check_correctness(a, b)
367
a = random_well_conditioned_matrix(1, 3, 1, 3, m, m, dtype=dtype, device=device)
368
b = torch.rand(3, 1, 3, 1, m, m // 2, dtype=dtype, device=device)
369
check_correctness(a, b)
372
b = torch.rand(3, 1, 3, 1, m, dtype=dtype, device=device)
375
check_correctness(a, b.unsqueeze(-1))
377
a = random_well_conditioned_matrix(3, 1, 3, 1, m, m, dtype=dtype, device=device)
378
b = torch.rand(1, 3, 1, 3, m, m // 2, dtype=dtype, device=device)
379
check_correctness(a, b)
382
b = torch.rand(1, 3, 1, 3, m, dtype=dtype, device=device)
383
check_correctness(a, b.unsqueeze(-1))
387
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
388
def test_linalg_lstsq_input_checks(self, device, dtype):
391
a = torch.rand(0, 0, 3, 3, dtype=dtype, device=device)
392
b = torch.rand(0, 0, 3, 2, dtype=dtype, device=device)
394
torch.linalg.lstsq(a, b)[0],
395
torch.zeros(0, 0, 3, 2, dtype=dtype, device=device)
398
a = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
399
b = torch.rand(2, 2, 0, 0, dtype=dtype, device=device)
401
torch.linalg.lstsq(a, b)[0],
402
torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
405
a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
406
b = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
408
torch.linalg.lstsq(a, b)[0],
409
torch.zeros(2, 2, 0, 0, dtype=dtype, device=device)
412
a = torch.rand(2, 2, 3, 0, dtype=dtype, device=device)
413
b = torch.rand(2, 2, 3, 2, dtype=dtype, device=device)
415
torch.linalg.lstsq(a, b)[0],
416
torch.zeros(2, 2, 0, 2, dtype=dtype, device=device)
420
if torch.device(device).type == 'cpu':
422
a = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
423
b = torch.rand(2, 2, 0, 3, dtype=dtype, device=device)
425
torch.linalg.lstsq(a, b)[0],
426
torch.zeros(2, 2, 3, 3, dtype=dtype, device=device)
429
a = torch.rand(2, 3, dtype=dtype, device=device)
430
b = torch.rand(3, dtype=dtype, device=device)
432
with self.assertRaisesRegex(RuntimeError, 'input must have at least 2 dimensions'):
433
torch.linalg.lstsq(b, b)
435
with self.assertRaisesRegex(RuntimeError, 'other must have at least 1 dimension'):
436
torch.linalg.lstsq(a, torch.tensor(1, dtype=dtype, device=device))
438
with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-1\)'):
439
torch.linalg.lstsq(a, b)
441
with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
442
torch.linalg.lstsq(a, b.unsqueeze(-1))
444
a = torch.randn(1, 1, 1, dtype=dtype, device=device)
445
b = torch.randn(3, 1, dtype=dtype, device=device)
447
with self.assertRaisesRegex(RuntimeError, r'input.size\(-2\) should match other.size\(-2\)'):
448
torch.linalg.lstsq(a, b)
450
        def complement_device(device):
            if device == 'cpu' and torch.cuda.is_available():
                return 'cuda'
            return 'cpu'
456
a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
457
b = torch.rand(2, 2, 2, dtype=dtype, device=complement_device(device))
458
if a.device != b.device:
459
with self.assertRaisesRegex(RuntimeError, 'be on the same device'):
460
torch.linalg.lstsq(a, b)
462
b = (torch.rand(2, 2, 2, dtype=dtype, device=device) * 100).long()
463
with self.assertRaisesRegex(RuntimeError, 'the same dtype'):
464
torch.linalg.lstsq(a, b)
466
a = torch.rand(2, 2, 2, 2, dtype=dtype, device=device)
467
b = torch.rand(2, 2, 2, dtype=dtype, device=device)
470
with self.assertRaisesRegex(RuntimeError, '`driver` other than `gels` is not supported on CUDA'):
471
torch.linalg.lstsq(a, b, driver='fictitious_driver')
474
with self.assertRaisesRegex(RuntimeError, r'parameter `driver` should be one of \(gels, gelsy, gelsd, gelss\)'):
475
torch.linalg.lstsq(a, b, driver='fictitious_driver')
478
version = torch.testing._internal.common_cuda._get_torch_cuda_version()
479
cusolver_not_available = (version < (10, 1))
481
if device != 'cpu' and cusolver_not_available:
482
a = torch.rand(2, 3, dtype=dtype, device=device)
483
b = torch.rand(2, 1, dtype=dtype, device=device)
484
with self.assertRaisesRegex(RuntimeError, r'only overdetermined systems'):
485
torch.linalg.lstsq(a, b)
489
@dtypes(*floating_and_complex_types())
490
def test_cholesky(self, device, dtype):
491
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
493
def run_test(shape, batch, contiguous):
494
A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
495
if A.numel() > 0 and not contiguous:
497
self.assertFalse(A.is_contiguous())
498
expected_L = np.linalg.cholesky(A.cpu().numpy())
499
actual_L = torch.linalg.cholesky(A)
503
if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
505
expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
506
actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
508
self.assertEqual(actual_norm, expected_norm)
510
self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
512
self.assertEqual(actual_L, expected_L)
515
batches = ((), (3, ), (2, 2))
516
larger_input_case = [(100, (5, ), True)]
517
for shape, batch, contiguous in list(itertools.product(shapes, batches, (True, False))) + larger_input_case:
518
run_test(shape, batch, contiguous)
521
A = random_hermitian_pd_matrix(3, 3, dtype=dtype, device=device)
522
out = torch.empty_like(A)
523
ans = torch.linalg.cholesky(A, out=out)
524
self.assertEqual(ans, out)
525
expected = torch.linalg.cholesky(A)
526
self.assertEqual(expected, out)
529
expected = torch.linalg.cholesky(A).mH
530
actual = torch.linalg.cholesky(A, upper=True)
531
self.assertEqual(expected, actual)
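        # For Hermitian positive-definite A, torch.linalg.cholesky returns the lower
        # triangular L with A = L @ L.mH; upper=True returns L.mH instead.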
535
@dtypes(*floating_and_complex_types())
536
def test_cholesky_errors_and_warnings(self, device, dtype):
537
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
540
A = torch.randn(2, 3, device=device, dtype=dtype)
541
with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
542
torch.linalg.cholesky(A)
543
A = torch.randn(2, 2, 3, device=device, dtype=dtype)
544
with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
545
torch.linalg.cholesky(A)
546
with self.assertRaisesRegex(np.linalg.LinAlgError, r'Last 2 dimensions of the array must be square'):
547
np.linalg.cholesky(A.cpu().numpy())
550
A = torch.randn(2, device=device, dtype=dtype)
551
with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
552
torch.linalg.cholesky(A)
553
with self.assertRaisesRegex(np.linalg.LinAlgError,
554
r'1-dimensional array given\. Array must be at least two-dimensional'):
555
np.linalg.cholesky(A.cpu().numpy())
558
A = torch.eye(3, 3, dtype=dtype, device=device)
560
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
561
torch.linalg.cholesky(A)
562
with self.assertRaisesRegex(np.linalg.LinAlgError, r'Matrix is not positive definite'):
563
np.linalg.cholesky(A.cpu().numpy())
566
A = torch.eye(3, 3, dtype=dtype, device=device)
567
A = A.reshape((1, 3, 3))
568
A = A.repeat(5, 1, 1)
570
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 4\): The factorization could not be completed'):
571
torch.linalg.cholesky(A)
574
A = random_hermitian_pd_matrix(3, dtype=dtype, device=device)
575
out = torch.empty(2, 3, dtype=dtype, device=device)
576
with warnings.catch_warnings(record=True) as w:
578
torch.linalg.cholesky(A, out=out)
580
self.assertEqual(len(w), 1)
581
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
584
out = torch.empty(*A.shape, dtype=torch.int, device=device)
585
with self.assertRaisesRegex(RuntimeError, "but got int instead"):
586
torch.linalg.cholesky(A, out=out)
589
if torch.cuda.is_available():
590
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
591
out = torch.empty(0, device=wrong_device, dtype=dtype)
592
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
593
torch.linalg.cholesky(A, out=out)
599
@dtypes(torch.double)
600
def test_old_cholesky_batched_many_batches(self, device, dtype):
601
from torch.testing._internal.common_utils import random_symmetric_pd_matrix
603
def cholesky_test_helper(n, batchsize, device, upper):
604
A = random_symmetric_pd_matrix(n, batchsize, dtype=dtype, device=device)
605
chol_fact = torch.cholesky(A, upper=upper)
608
self.assertEqual(A, chol_fact.mT.matmul(chol_fact))
610
self.assertEqual(chol_fact, chol_fact.triu())
613
self.assertEqual(A, chol_fact.matmul(chol_fact.mT))
615
self.assertEqual(chol_fact, chol_fact.tril())
617
for upper, batchsize in itertools.product([True, False], [262144, 524288]):
618
cholesky_test_helper(2, batchsize, device, upper)
620
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
623
@dtypes(*floating_and_complex_types())
624
def test_old_cholesky_batched(self, device, dtype):
625
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
627
def cholesky_test_helper(n, batch_dims, upper):
628
A = random_hermitian_pd_matrix(n, *batch_dims, dtype=dtype, device=device)
629
cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
630
cholesky_exp = cholesky_exp.reshape_as(A)
631
self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
633
for upper, batchsize in itertools.product([True, False], [(3,), (3, 4), (2, 3, 4)]):
634
cholesky_test_helper(3, batchsize, upper)
636
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
639
@dtypes(*floating_and_complex_types())
640
@tf32_on_and_off(0.01)
641
@bf32_on_and_off(0.01)
642
def test_old_cholesky(self, device, dtype):
643
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
645
A = random_hermitian_pd_matrix(10, dtype=dtype, device=device)
648
C = torch.cholesky(A)
649
B = torch.mm(C, C.t().conj())
650
self.assertEqual(A, B, atol=1e-14, rtol=0)
653
U = torch.cholesky(A, True)
654
B = torch.mm(U.t().conj(), U)
655
self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (upper) did not allow rebuilding the original matrix')
658
L = torch.cholesky(A, False)
659
B = torch.mm(L, L.t().conj())
660
self.assertEqual(A, B, atol=1e-14, rtol=0, msg='cholesky (lower) did not allow rebuilding the original matrix')
664
@dtypes(*floating_and_complex_types())
665
    def test_old_cholesky_empty(self, device, dtype):
        def run_test(upper):
            A = torch.empty(0, 0, dtype=dtype, device=device)
            chol = torch.cholesky(A, upper)
            chol_A = torch.matmul(chol, chol.t().conj())
            self.assertEqual(A, chol_A)
        for upper in [True, False]:
            run_test(upper)
680
@dtypes(*floating_and_complex_types())
681
def test_old_cholesky_batched_upper(self, device, dtype):
682
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
685
A = random_hermitian_pd_matrix(3, batchsize, dtype=dtype, device=device)
688
U = torch.cholesky(A_triu, upper=True)
690
reconstruct_A = U.mH @ U
691
self.assertEqual(A, reconstruct_A)
693
@skipCUDAIfNoMagmaAndNoCusolver
695
@dtypes(*floating_and_complex_types())
696
def test_cholesky_ex(self, device, dtype):
697
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
699
def run_test(n, batch):
700
A = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
701
expected_L = np.linalg.cholesky(A.cpu().numpy())
702
expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
703
actual_L, actual_info = torch.linalg.cholesky_ex(A)
707
if A.numel() > 0 and dtype in [torch.float32, torch.complex64]:
709
expected_norm = np.linalg.norm(expected_L, ord=1, axis=(-2, -1))
710
actual_norm = torch.linalg.norm(actual_L, ord=1, axis=(-2, -1))
712
self.assertEqual(actual_norm, expected_norm)
714
self.assertEqual(actual_L, expected_L, atol=1e-2, rtol=1e-5)
716
self.assertEqual(actual_L, expected_L)
717
self.assertEqual(actual_info, expected_info)
720
batches = ((), (2, ), (2, 1))
721
for n, batch in itertools.product(ns, batches):
724
@skipCUDAIfNoMagmaAndNoCusolver
726
@dtypes(*floating_and_complex_types())
727
def test_cholesky_ex_non_pd(self, device, dtype):
729
A = torch.eye(3, 3, dtype=dtype, device=device)
731
_, info = torch.linalg.cholesky_ex(A)
732
self.assertEqual(info, 3)
733
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'minor of order 3 is not positive-definite'):
734
torch.linalg.cholesky_ex(A, check_errors=True)
738
A = torch.eye(3, 3, dtype=dtype, device=device)
739
A = A.reshape((1, 3, 3))
740
A = A.repeat(5, 1, 1)
742
_, info = torch.linalg.cholesky_ex(A)
744
expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
746
self.assertEqual(info, expected_info)
747
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The factorization could not be completed'):
748
torch.linalg.cholesky_ex(A, check_errors=True)
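    # cholesky_ex reports failures through `info`: a positive value is the order of the
    # leading minor that is not positive-definite, and check_errors=True turns that into
    # a LinAlgError naming the offending batch element.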
750
def _test_addr_vs_numpy(self, device, dtype, beta=1, alpha=1):
751
def check(m, a, b, beta, alpha):
752
            if dtype == torch.bfloat16:
                a_np = a.to(torch.double).cpu().numpy()
                b_np = b.to(torch.double).cpu().numpy()
                m_np = m.to(torch.double).cpu().numpy()
                exact_dtype = False
            else:
                a_np = a.cpu().numpy()
                b_np = b.cpu().numpy()
                m_np = m.cpu().numpy()
                exact_dtype = True

            if beta == 0:
                expected = alpha * np.outer(a_np, b_np)
            else:
                expected = beta * m_np + alpha * np.outer(a_np, b_np)
767
res = torch.addr(m, a, b, beta=beta, alpha=alpha)
768
self.assertEqual(res, expected, exact_dtype=exact_dtype)
771
out = torch.empty_like(res)
772
torch.addr(m, a, b, beta=beta, alpha=alpha, out=out)
773
self.assertEqual(out, expected, exact_dtype=exact_dtype)
775
m = make_tensor((50, 50), device=device, dtype=dtype, low=-2, high=2)
776
a = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
777
b = make_tensor((50,), device=device, dtype=dtype, low=-2, high=2)
779
check(m, a, b, beta, alpha)
782
m_transpose = torch.transpose(m, 0, 1)
783
check(m_transpose, a, b, beta, alpha)
786
zero_strided = make_tensor((1,), device=device, dtype=dtype, low=-2, high=2).expand(50)
787
check(m, zero_strided, b, beta, alpha)
790
m_scalar = torch.tensor(1, device=device, dtype=dtype)
791
check(m_scalar, a, b, beta, alpha)
794
float_and_complex_dtypes = floating_and_complex_types_and(torch.half, torch.bfloat16)
795
if beta == 0 and dtype in float_and_complex_dtypes:
796
m[0][10] = m[10][10] = m[20][20] = float('inf')
797
m[1][10] = m[11][10] = m[21][20] = float('nan')
798
check(m, a, b, 0, alpha)
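    # torch.addr(M, a, b, beta=beta, alpha=alpha) computes beta * M + alpha * torch.outer(a, b);
    # the tests below exercise it across dtypes against the NumPy reference above.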
801
def test_addr_bool(self, device, dtype):
802
self._test_addr_vs_numpy(device, dtype, beta=True, alpha=False)
803
self._test_addr_vs_numpy(device, dtype, beta=False, alpha=True)
804
self._test_addr_vs_numpy(device, dtype, beta=False, alpha=False)
805
self._test_addr_vs_numpy(device, dtype, beta=True, alpha=True)
807
@dtypes(*integral_types())
808
def test_addr_integral(self, device, dtype):
809
with self.assertRaisesRegex(RuntimeError,
810
'argument beta must not be a floating point number.'):
811
self._test_addr_vs_numpy(device, dtype, beta=2., alpha=1)
812
with self.assertRaisesRegex(RuntimeError,
813
'argument alpha must not be a floating point number.'):
814
self._test_addr_vs_numpy(device, dtype, beta=2, alpha=1.)
815
with self.assertRaisesRegex(RuntimeError,
816
'Boolean beta only supported for Boolean results.'):
817
self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
818
with self.assertRaisesRegex(RuntimeError,
819
'Boolean alpha only supported for Boolean results.'):
820
self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
823
self._test_addr_vs_numpy(device, dtype, beta=0, alpha=2)
825
self._test_addr_vs_numpy(device, dtype, beta=2, alpha=2)
827
@precisionOverride({torch.bfloat16: 1e-1})
828
@dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
829
def test_addr_float_and_complex(self, device, dtype):
830
with self.assertRaisesRegex(RuntimeError,
831
'Boolean beta only supported for Boolean results.'):
832
self._test_addr_vs_numpy(device, dtype, beta=True, alpha=1)
833
with self.assertRaisesRegex(RuntimeError,
834
'Boolean alpha only supported for Boolean results.'):
835
self._test_addr_vs_numpy(device, dtype, beta=2, alpha=True)
838
self._test_addr_vs_numpy(device, dtype, beta=0., alpha=2)
840
self._test_addr_vs_numpy(device, dtype, beta=0.5, alpha=2)
841
if dtype in complex_types():
842
self._test_addr_vs_numpy(device, dtype, beta=(0 + 0.1j), alpha=(0.2 - 0.2j))
844
@dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
845
all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
846
def test_outer_type_promotion(self, device, dtypes):
847
a = torch.randn(5).to(device=device, dtype=dtypes[0])
848
b = torch.randn(5).to(device=device, dtype=dtypes[1])
849
        for op in (torch.outer, torch.Tensor.outer, torch.ger, torch.Tensor.ger):
            result = op(a, b)
            self.assertEqual(result.dtype, torch.result_type(a, b))
854
def test_addr_type_promotion(self, device):
855
for dtypes0, dtypes1, dtypes2 in product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), repeat=3):
856
a = make_tensor((5,), device=device, dtype=dtypes0, low=-2, high=2)
857
b = make_tensor((5,), device=device, dtype=dtypes1, low=-2, high=2)
858
m = make_tensor((5, 5), device=device, dtype=dtypes2, low=-2, high=2)
860
            desired_dtype = torch.promote_types(torch.promote_types(dtypes0, dtypes1),
                                                dtypes2)
            for op in (torch.addr, torch.Tensor.addr):
                result = op(m, a, b)
                self.assertEqual(result.dtype, desired_dtype)
869
def test_outer_ger_addr_legacy_tests(self, device):
870
for size in ((0, 0), (0, 5), (5, 0)):
871
a = torch.rand(size[0], device=device)
872
b = torch.rand(size[1], device=device)
874
self.assertEqual(torch.outer(a, b).shape, size)
875
self.assertEqual(torch.ger(a, b).shape, size)
877
m = torch.empty(size, device=device)
878
self.assertEqual(torch.addr(m, a, b).shape, size)
880
m = torch.randn(5, 6, device=device)
881
a = torch.randn(5, device=device)
882
b = torch.tensor(6, device=device)
883
self.assertRaises(RuntimeError, lambda: torch.outer(a, b))
884
self.assertRaises(RuntimeError, lambda: torch.outer(b, a))
885
self.assertRaises(RuntimeError, lambda: torch.ger(a, b))
886
self.assertRaises(RuntimeError, lambda: torch.ger(b, a))
887
self.assertRaises(RuntimeError, lambda: torch.addr(m, a, b))
888
self.assertRaises(RuntimeError, lambda: torch.addr(m, b, a))
893
@dtypes(torch.double, torch.cdouble)
894
    def test_det(self, device, dtype):
        tensors = (
            torch.randn((2, 2), device=device, dtype=dtype),
            torch.randn((129, 129), device=device, dtype=dtype),
            torch.randn((3, 52, 52), device=device, dtype=dtype),
            torch.randn((4, 2, 26, 26), device=device, dtype=dtype))

        ops = (torch.det, torch.Tensor.det,
               torch.linalg.det)
        for t in tensors:
            expected = np.linalg.det(t.cpu().numpy())
            for op in ops:
                actual = op(t)
                self.assertEqual(actual, expected)
                self.compare_with_numpy(op, np.linalg.det, t)

        # det requires at least a 2D tensor, so a 1D input raises
        t = torch.randn(1, device=device, dtype=dtype)
        with self.assertRaises(RuntimeError):
            op(t)
918
@dtypes(*floating_and_complex_types())
919
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
920
def test_eigh(self, device, dtype):
921
from torch.testing._internal.common_utils import random_hermitian_matrix
923
def run_test(shape, batch, uplo):
924
matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
925
expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
926
actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
927
self.assertEqual(actual_w, expected_w)
929
self.assertEqual(abs(actual_v), abs(expected_v))
933
if matrix.numel() > 0:
934
phase = torch.from_numpy(expected_v[..., 0, :]).to(device=device).div(actual_v[..., 0, :])
935
actual_v_rotated = actual_v * phase.unsqueeze(-2).expand_as(actual_v)
936
self.assertEqual(actual_v_rotated, expected_v)
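            # Eigenvectors are only defined up to a per-column phase (a sign flip for real
            # dtypes), hence the abs() comparison above and the explicit phase rotation
            # before comparing against the NumPy eigenvectors.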
939
out_w = torch.empty_like(actual_w)
940
out_v = torch.empty_like(actual_v)
941
ans_w, ans_v = torch.linalg.eigh(matrix, UPLO=uplo, out=(out_w, out_v))
942
self.assertEqual(ans_w, out_w)
943
self.assertEqual(ans_v, out_v)
944
self.assertEqual(ans_w, actual_w)
945
self.assertEqual(abs(ans_v), abs(actual_v))
948
batches = ((), (3, ), (2, 2))
950
for shape, batch, uplo in itertools.product(shapes, batches, uplos):
951
run_test(shape, batch, uplo)
955
@dtypes(*floating_and_complex_types())
956
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
957
def test_eigh_lower_uplo(self, device, dtype):
958
def run_test(shape, batch, uplo):
961
matrix = torch.randn(shape, shape, *batch, dtype=dtype, device=device)
962
expected_w, expected_v = np.linalg.eigh(matrix.cpu().numpy(), UPLO=uplo)
963
actual_w, actual_v = torch.linalg.eigh(matrix, UPLO=uplo)
964
self.assertEqual(actual_w, expected_w)
965
self.assertEqual(abs(actual_v), abs(expected_v))
969
run_test(3, (2, 2), uplo)
973
@dtypes(*floating_and_complex_types())
974
def test_eigh_errors_and_warnings(self, device, dtype):
975
from torch.testing._internal.common_utils import random_hermitian_matrix
978
t = torch.randn(2, 3, device=device, dtype=dtype)
979
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
983
t = torch.randn(3, 3, device=device, dtype=dtype)
984
for uplo in ["a", "wrong"]:
985
with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
986
torch.linalg.eigh(t, UPLO=uplo)
987
with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
988
np.linalg.eigh(t.cpu().numpy(), UPLO=uplo)
991
a = random_hermitian_matrix(3, dtype=dtype, device=device)
992
real_dtype = a.real.dtype if dtype.is_complex else dtype
993
out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
994
out_v = torch.empty(7, 7, dtype=dtype, device=device)
995
with warnings.catch_warnings(record=True) as w:
997
torch.linalg.eigh(a, out=(out_w, out_v))
999
self.assertEqual(len(w), 2)
1000
self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
1001
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1004
out_w = torch.empty(0, dtype=real_dtype, device=device)
1005
out_v = torch.empty(0, dtype=torch.int, device=device)
1006
with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1007
torch.linalg.eigh(a, out=(out_w, out_v))
1009
out_w = torch.empty(0, dtype=torch.int, device=device)
1010
out_v = torch.empty(0, dtype=dtype, device=device)
1011
with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1012
torch.linalg.eigh(a, out=(out_w, out_v))
1015
if torch.cuda.is_available():
1016
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1017
out_w = torch.empty(0, device=wrong_device, dtype=dtype)
1018
out_v = torch.empty(0, device=device, dtype=dtype)
1019
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1020
torch.linalg.eigh(a, out=(out_w, out_v))
1021
out_w = torch.empty(0, device=device, dtype=dtype)
1022
out_v = torch.empty(0, device=wrong_device, dtype=dtype)
1023
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1024
torch.linalg.eigh(a, out=(out_w, out_v))
1027
@dtypes(torch.float, torch.double)
1028
@unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
1029
def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
1033
a = torch.ones(512, 512, dtype=dtype, device=device)
1037
eigh_out = torch.linalg.eigh(a)
1038
svd_out = torch.linalg.svd(a)
1044
self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1045
self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
1049
@dtypes(*floating_and_complex_types())
1050
@precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4})
1051
def test_eigvalsh(self, device, dtype):
1052
from torch.testing._internal.common_utils import random_hermitian_matrix
1054
def run_test(shape, batch, uplo):
1055
matrix = random_hermitian_matrix(shape, *batch, dtype=dtype, device=device)
1056
expected_w = np.linalg.eigvalsh(matrix.cpu().numpy(), UPLO=uplo)
1057
actual_w = torch.linalg.eigvalsh(matrix, UPLO=uplo)
1058
self.assertEqual(actual_w, expected_w)
1061
out = torch.empty_like(actual_w)
1062
ans = torch.linalg.eigvalsh(matrix, UPLO=uplo, out=out)
1063
self.assertEqual(ans, out)
1064
self.assertEqual(ans, actual_w)
1067
batches = ((), (3, ), (2, 2))
1069
for shape, batch, uplo in itertools.product(shapes, batches, uplos):
1070
run_test(shape, batch, uplo)
1074
@dtypes(*floating_and_complex_types())
1075
def test_eigvalsh_errors_and_warnings(self, device, dtype):
1077
t = torch.randn(2, 3, device=device, dtype=dtype)
1078
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1079
torch.linalg.eigvalsh(t)
1082
t = torch.randn(3, 3, device=device, dtype=dtype)
1083
for uplo in ["a", "wrong"]:
1084
with self.assertRaisesRegex(RuntimeError, "be 'L' or 'U'"):
1085
torch.linalg.eigvalsh(t, UPLO=uplo)
1086
with self.assertRaisesRegex(ValueError, "be 'L' or 'U'"):
1087
np.linalg.eigvalsh(t.cpu().numpy(), UPLO=uplo)
1090
real_dtype = t.real.dtype if dtype.is_complex else dtype
1091
out = torch.empty_like(t).to(real_dtype)
1092
with warnings.catch_warnings(record=True) as w:
1094
torch.linalg.eigvalsh(t, out=out)
1096
self.assertEqual(len(w), 1)
1097
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1100
out = torch.empty(0, dtype=torch.int, device=device)
1101
with self.assertRaisesRegex(RuntimeError, "but got int instead"):
1102
torch.linalg.eigvalsh(t, out=out)
1105
if torch.cuda.is_available():
1106
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1107
out = torch.empty(0, device=wrong_device, dtype=dtype)
1108
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1109
torch.linalg.eigvalsh(t, out=out)
1111
@dtypes(*floating_and_complex_types())
1112
def test_kron(self, device, dtype):
1114
def run_test_case(a_shape, b_shape):
1115
a = torch.rand(a_shape, dtype=dtype, device=device)
1116
b = torch.rand(b_shape, dtype=dtype, device=device)
1118
expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1119
result = torch.kron(a, b)
1120
self.assertEqual(result, expected)
1123
out = torch.empty_like(result)
1124
ans = torch.kron(a, b, out=out)
1125
self.assertEqual(ans, out)
1126
self.assertEqual(ans, result)
1128
shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
1129
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
1130
run_test_case(a_shape, b_shape)
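        # np.kron / torch.kron pad the lower-dimensional operand's shape with leading
        # ones, so the result shape is the elementwise product of the padded shapes.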
1132
@dtypes(*floating_and_complex_types())
1133
def test_kron_empty(self, device, dtype):
1135
def run_test_case(empty_shape):
1136
a = torch.eye(3, dtype=dtype, device=device)
1137
b = torch.empty(empty_shape, dtype=dtype, device=device)
1138
result = torch.kron(a, b)
1139
expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
1140
self.assertEqual(result, expected)
1143
result = torch.kron(b, a)
1144
self.assertEqual(result.shape, expected.shape)
1146
empty_shapes = [(0,), (2, 0), (1, 0, 3)]
1147
for empty_shape in empty_shapes:
1148
run_test_case(empty_shape)
1150
@dtypes(*floating_and_complex_types())
1151
def test_kron_errors_and_warnings(self, device, dtype):
1153
a = torch.eye(3, dtype=dtype, device=device)
1154
b = torch.ones((2, 2), dtype=dtype, device=device)
1155
out = torch.empty_like(a)
1156
with warnings.catch_warnings(record=True) as w:
1158
torch.kron(a, b, out=out)
1160
self.assertEqual(len(w), 1)
1161
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1164
out = torch.empty_like(a).to(torch.int)
1165
with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"):
1166
torch.kron(a, b, out=out)
1170
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1171
def test_norm_dtype(self, device, dtype):
1172
make_arg = partial(make_tensor, dtype=dtype, device=device)
1174
def run_test_case(input_size, ord, keepdim, to_dtype):
1176
f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1177
f'dtype={dtype}, to_dtype={to_dtype}')
1178
input = make_arg(input_size)
1179
result = torch.linalg.norm(input, ord, keepdim=keepdim)
1180
self.assertEqual(result.dtype, input.real.dtype, msg=msg)
1182
result_out = torch.empty((0), dtype=result.dtype, device=device)
1183
torch.linalg.norm(input, ord, keepdim=keepdim, out=result_out)
1184
self.assertEqual(result, result_out, msg=msg)
1186
result = torch.linalg.norm(input.to(to_dtype), ord, keepdim=keepdim)
1187
result_with_dtype = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype)
1188
self.assertEqual(result, result_with_dtype, msg=msg)
1190
result_out_with_dtype = torch.empty_like(result_with_dtype)
1191
torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_with_dtype)
1192
self.assertEqual(result_with_dtype, result_out_with_dtype, msg=msg)
1194
ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1198
if dtype != torch.float16 and dtype != torch.bfloat16:
1199
ord_vector.extend([0.1, -0.1])
1200
ord_matrix = ['fro', 'nuc', 1, -1, 2, -2, inf, -inf, None]
1203
if dtype == torch.cfloat:
1204
norm_dtypes = (torch.cfloat, torch.cdouble)
1205
elif dtype == torch.cdouble:
1206
norm_dtypes = (torch.cdouble,)
1207
elif dtype in (torch.float16, torch.bfloat16, torch.float):
1208
norm_dtypes = (torch.float, torch.double)
1209
elif dtype == torch.double:
1210
norm_dtypes = (torch.double,)
1212
raise RuntimeError("Unsupported dtype")
1214
for ord, keepdim, norm_dtype in product(ord_vector, (True, False), norm_dtypes):
1215
run_test_case((S,) , ord, keepdim, norm_dtype)
1217
for ord, keepdim, norm_dtype in product(ord_matrix, (True, False), norm_dtypes):
1218
if ord in [2, -2, 'nuc']:
1220
if dtype == torch.float16 or dtype == torch.bfloat16:
1224
if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1225
(torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1227
run_test_case((S, S) , ord, keepdim, norm_dtype)
1230
@dtypes(torch.bfloat16, torch.float16)
1231
def test_norm_bfloat16_and_half(self, device, dtype):
1232
make_arg = partial(make_tensor, dtype=dtype, device=device)
1234
def run_test_case(input_size, ord, keepdim):
1236
f'input_size={input_size}, ord={ord}, keepdim={keepdim}, '
1238
input = make_arg(input_size).fill_(1)
1239
result_ref = torch.linalg.norm(input.float(), ord, keepdim=keepdim).to(dtype=dtype)
1240
result = torch.linalg.norm(input, ord, keepdim=keepdim)
1241
self.assertEqual(result_ref, result, msg=msg)
1243
ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None]
1244
for S, ord, keepdim in product((10, 2049), ord_vector, (True, False)):
1245
run_test_case((S,) , ord, keepdim, )
1247
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble, torch.bfloat16, torch.float16)
1248
def test_vector_norm(self, device, dtype):
1249
if IS_ARM64 and device == 'cpu' and dtype in [torch.float16, torch.bfloat16, torch.float32]:
1250
raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
1254
ord_vector = [0, 0.9, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1266
        def vector_norm_reference(input, ord, dim=None, keepdim=False, dtype=None):
            # torch.linalg.norm only accepts 1-D/2-D inputs when ord is given, so flatten
            # to emulate vector_norm's default behaviour when dim is None.
            if dim is None:
                input_maybe_flat = input.flatten(0, -1)
            else:
                input_maybe_flat = input

            result = torch.linalg.norm(input_maybe_flat, ord, dim=dim, keepdim=keepdim, dtype=dtype)
            if keepdim and dim is None:
                result = result.reshape([1] * input.dim())
            return result
1277
def run_test_case(input, ord, dim, keepdim, norm_dtype):
1278
if (input.numel() == 0 and
1279
(ord < 0. or ord == inf) and
1280
(dim is None or input.shape[dim] == 0)):
1282
error_msg = "linalg.vector_norm cannot compute"
1283
with self.assertRaisesRegex(RuntimeError, error_msg):
1284
torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim)
1286
msg = (f'input.size()={input.size()}, ord={ord}, dim={dim}, '
1287
f'keepdim={keepdim}, dtype={dtype}, norm_dtype={norm_dtype}')
1288
result_dtype_reference = vector_norm_reference(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1289
result_dtype = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1290
if dtype.is_complex:
1291
result_dtype_reference = result_dtype_reference.real
1292
self.assertEqual(result_dtype, result_dtype_reference, msg=msg)
1294
if norm_dtype is not None:
1295
ref = torch.linalg.vector_norm(input.to(norm_dtype), ord, dim=dim, keepdim=keepdim)
1296
actual = torch.linalg.vector_norm(input, ord, dim=dim, keepdim=keepdim, dtype=norm_dtype)
1297
self.assertEqual(ref, actual, msg=msg)
1299
if dtype == torch.cfloat:
1300
norm_dtypes = (None, torch.cfloat, torch.cdouble)
1301
elif dtype == torch.cdouble:
1302
norm_dtypes = (None, torch.cdouble)
1303
elif dtype in (torch.float16, torch.bfloat16, torch.float):
1304
norm_dtypes = (None, torch.float, torch.double)
1305
elif dtype == torch.double:
1306
norm_dtypes = (None, torch.double)
1308
raise RuntimeError("Unsupported dtype")
1310
for amp in [False, True]:
1311
with torch.autocast(device_type=device, enabled=amp):
1312
for input_size, ord, keepdim, norm_dtype in product(input_sizes, ord_vector, [True, False], norm_dtypes):
1313
input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1314
for dim in [None, random.randint(0, len(input_size) - 1)]:
1322
    def test_vector_norm_dim_tuple_arg(self, device):
        test_cases = [
            # input size, dim, expected error, error message
            ((4, ), (0, ), None, None),
            ((4, ), (1, ), IndexError, r'Dimension out of range'),
            ((4, ), (-2, ), IndexError, r'Dimension out of range'),
            ((4, 3), (0, -1), None, None),
            ((4, 3), (0, 0), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, -2), RuntimeError, r'dim 0 appears multiple times in the list of dims'),
            ((4, 3), (0, 1.0), TypeError, r"argument 'dim' must be tuple of ints"),
            ((4, 3), (None, ), TypeError, r"argument 'dim' must be tuple of ints"),
        ]
        for input_size, dim_tuple, error, error_msg in test_cases:
            input = torch.randn(input_size, device=device)
            for dim in [dim_tuple, list(dim_tuple)]:
                if error is None:
                    torch.linalg.vector_norm(input, dim=dim)
                else:
                    with self.assertRaises(error):
                        torch.linalg.vector_norm(input, dim=dim)
1346
@dtypes(torch.float, torch.double)
1347
def test_norm_vector(self, device, dtype):
1348
def run_test_case(input, p, dim, keepdim):
1349
result = torch.linalg.norm(input, ord, dim, keepdim)
1350
input_numpy = input.cpu().numpy()
1351
result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1353
msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1354
self.assertEqual(result, result_numpy, msg=msg)
1356
result_out = torch.empty_like(result)
1357
torch.linalg.norm(input, ord, dim, keepdim, out=result_out)
1358
self.assertEqual(result, result_out, msg=msg)
1360
ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf]
1364
        test_cases = [
            # input size, ord settings, dim
            ((S, ), ord_vector, None),
            ((S, ), ord_vector, 0),
            ((S, S, S), ord_vector, 0),
            ((S, S, S), ord_vector, 1),
            ((S, S, S), ord_vector, 2),
            ((S, S, S), ord_vector, -1),
            ((S, S, S), ord_vector, -2),
        ]
1373
if dtype == torch.double:
1374
test_cases.append(((L, ), ord_vector, None))
1375
for keepdim in [True, False]:
1376
for input_size, ord_settings, dim in test_cases:
1377
input = torch.randn(*input_size, dtype=dtype, device=device)
1378
for ord in ord_settings:
1379
run_test_case(input, ord, dim, keepdim)
1385
@dtypes(torch.float, torch.double)
1386
@precisionOverride({torch.float32: 2e-4})
1387
def test_norm_matrix(self, device, dtype):
1388
make_arg = partial(make_tensor, dtype=dtype, device=device)
1390
def run_test_case(input, ord, dim, keepdim):
1391
msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1392
result = torch.linalg.norm(input, ord, dim, keepdim)
1393
input_numpy = input.cpu().numpy()
1394
result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1396
result = torch.linalg.norm(input, ord, dim, keepdim)
1397
self.assertEqual(result, result_numpy, msg=msg)
1398
if ord is not None and dim is not None:
1399
result = torch.linalg.matrix_norm(input, ord, dim, keepdim)
1400
self.assertEqual(result, result_numpy, msg=msg)
1402
ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro']
1409
        test_cases = [
            # input size, dim
            ((S, S, S, S), (2, 0)),
            ((S, S, S, S), (-1, -2)),
            ((S, S, S, S), (-1, -3)),
            ((S, S, S, S), (-3, 2)),
        ]
1415
for (shape, dim), keepdim, ord in product(test_cases, [True, False], ord_matrix):
1416
if ord in [2, -2, 'nuc']:
1418
if dtype == torch.float16 or dtype == torch.bfloat16:
1421
if ((torch.device(device).type == 'cuda' and not torch.cuda.has_magma and not has_cusolver()) or
1422
(torch.device(device).type == 'cpu' and not torch._C.has_lapack)):
1424
run_test_case(make_arg(shape), ord, dim, keepdim)
1428
@dtypes(torch.bfloat16, torch.float16)
1429
def test_norm_fused_type_promotion(self, device, dtype):
1430
x = torch.randn(10, device=device, dtype=dtype)
1432
def profile_and_check(fn, x, kwargs):
1433
with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p:
1434
fn(x, **kwargs, dtype=torch.float)
1436
self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events()))
1438
self.assertFalse("aten::to" in (e.name for e in p.events()))
1440
        for f, kwargs in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})):
1441
profile_and_check(f, x, kwargs)
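        # The profiler check above asserts that the dtype argument is handled inside the
        # fused aten::linalg_vector_norm kernel rather than via a separate aten::to call.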
1446
@dtypes(*floating_and_complex_types())
1447
@precisionOverride({torch.float32: 1e-3})
1448
def test_cond(self, device, dtype):
1449
def run_test_case(input, p):
1450
result = torch.linalg.cond(input, p)
1451
result_numpy = np.linalg.cond(input.cpu().numpy(), p)
1452
self.assertEqual(result, result_numpy, rtol=1e-2, atol=self.precision, exact_dtype=False)
1453
self.assertEqual(result.shape, result_numpy.shape)
1456
out = torch.empty_like(result)
1457
ans = torch.linalg.cond(input, p, out=out)
1458
self.assertEqual(ans, out)
1459
self.assertEqual(ans, result)
1461
norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1462
input_sizes = [(32, 32), (2, 3, 3, 3)]
1463
for input_size in input_sizes:
1464
input = torch.randn(*input_size, dtype=dtype, device=device)
1465
for p in norm_types:
1466
run_test_case(input, p)
1469
input_sizes = [(0, 3, 3), (0, 2, 5, 5)]
1470
for input_size in input_sizes:
1471
input = torch.randn(*input_size, dtype=dtype, device=device)
1472
for p in norm_types:
1473
run_test_case(input, p)
1476
input_sizes = [(16, 32), (32, 16), (2, 3, 5, 3), (2, 3, 3, 5)]
1477
for input_size in input_sizes:
1478
input = torch.randn(*input_size, dtype=dtype, device=device)
1479
for p in [2, -2, None]:
1480
run_test_case(input, p)
1483
a = torch.eye(3, dtype=dtype, device=device)
1485
for p in norm_types:
1488
except np.linalg.LinAlgError:
1494
input_sizes = [(0, 0), (2, 5, 0, 0)]
1495
for input_size in input_sizes:
1496
input = torch.randn(*input_size, dtype=dtype, device=device)
1497
for p in ['fro', 2]:
1498
expected_dtype = a.real.dtype if dtype.is_complex else dtype
1499
expected = torch.zeros(input_size[:-2], dtype=expected_dtype, device=device)
1500
actual = torch.linalg.cond(input, p)
1501
self.assertEqual(actual, expected)
1506
@dtypes(*floating_and_complex_types())
1507
@precisionOverride({torch.float32: 1e-3})
1508
def test_cond_errors_and_warnings(self, device, dtype):
1509
norm_types = [1, -1, 2, -2, inf, -inf, 'fro', 'nuc', None]
1512
a = torch.ones(3, dtype=dtype, device=device)
1513
for p in norm_types:
1514
with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
1515
torch.linalg.cond(a, p)
1518
a = torch.ones(3, 2, dtype=dtype, device=device)
1519
norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
1520
for p in norm_types:
1521
with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
1522
torch.linalg.cond(a, p)
1525
a = torch.ones((2, 2), dtype=dtype, device=device)
1526
for p in ['fro', 2]:
1527
real_dtype = a.real.dtype if dtype.is_complex else dtype
1528
out = torch.empty(a.shape, dtype=real_dtype, device=device)
1529
with warnings.catch_warnings(record=True) as w:
1531
torch.linalg.cond(a, p, out=out)
1533
self.assertEqual(len(w), 1)
1534
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
1537
out = torch.empty(0, dtype=torch.int, device=device)
1538
for p in ['fro', 2]:
1539
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
1540
torch.linalg.cond(a, p, out=out)
1543
if torch.cuda.is_available():
1544
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
1545
out = torch.empty(0, dtype=dtype, device=wrong_device)
1546
for p in ['fro', 2]:
1547
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
1548
torch.linalg.cond(a, p, out=out)
1556
a = torch.eye(3, 3, dtype=dtype, device=device)
1557
a = a.reshape((1, 3, 3))
1558
a = a.repeat(batch_dim, 1, 1)
1560
for p in [1, -1, inf, -inf, 'fro', 'nuc']:
1561
result = torch.linalg.cond(a, p)
1562
self.assertEqual(result[1], float('inf'))
1565
a = torch.ones(3, 3, dtype=dtype, device=device)
1566
for p in ['wrong_norm', 5]:
1567
with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
1568
torch.linalg.cond(a, p)
1572
@dtypes(torch.float, torch.double)
1573
def test_norm_errors(self, device, dtype):
1574
def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex):
1576
f'test case input.size()={input.size()}, ord={ord}, dim={dim}, '
1577
f'keepdim={keepdim}, dtype={dtype}')
1579
with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info):
1580
torch.linalg.norm(input, ord, dim, keepdim)
1582
input_numpy = input.cpu().numpy()
1584
msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"'
1585
with self.assertRaises(Exception, msg=test_case_info):
1586
np.linalg.norm(input_numpy, ord, dim, keepdim)
1589
error_test_cases = [
1591
((S, ), ['fro', 'nuc'], None, RuntimeError, r'A must have at least 2 dimensions'),
1592
((S, S), [3.5], None, RuntimeError, r'matrix_norm: Order 3.5 not supported'),
1593
((S, S), [0], None, RuntimeError, r'matrix_norm: Order 0 not supported'),
1594
((S, S), ['fail'], None, RuntimeError, r'matrix_norm: Order fail not supported'),
1595
((S, S), ['fro', 'nuc'], 0, RuntimeError, r'matrix_norm: dim must be a 2-tuple'),
1596
((S, S), ['fro', 'nuc', 2], (0, 0), RuntimeError, r'dims must be different'),
1597
((S, S), ['fro', 'nuc', 2], (-1, 1), RuntimeError, r'dims must be different'),
1598
((S, S), ['fro', 'nuc', 2], (0, 4), IndexError, r'Dimension out of range'),
1599
((S, ), [0], (4, ), IndexError, r'Dimension out of range'),
1600
((S, ), [None], (0, 0), RuntimeError, r'dim 0 appears multiple times'),
1601
((S, S, S), [1], (0, 1, 2), RuntimeError, r"If dim is specified, it must be of length 1 or 2."),
1602
            ((S, S, S), [1], None, RuntimeError, r"If dim is not specified but ord is, the input must be 1D or 2D"),
        ]
1604
for keepdim in [True, False]:
1605
for input_size, ord_settings, dim, error_type, error_regex in error_test_cases:
1606
input = torch.randn(*input_size, dtype=dtype, device=device)
1607
for ord in ord_settings:
1608
run_error_test_case(input, ord, dim, keepdim, error_type, error_regex)
1613
@dtypes(torch.cfloat, torch.cdouble)
1614
@precisionOverride({torch.cfloat: 5e-4})
1615
def test_norm_complex(self, device, dtype):
1616
def gen_error_message(input_size, ord, keepdim, dim=None):
1617
return f"complex norm failed for input size {input_size}, ord={ord}, keepdim={keepdim}, dim={dim}"
1619
vector_ords = [None, 0, 1, 2, 3, inf, -1, -2, -3, -inf]
1620
matrix_ords = [None, 'fro', 'nuc', 1, 2, inf, -1, -2, -inf]
1623
for keepdim in [False, True]:
1625
x = torch.randn(25, device=device, dtype=dtype)
1626
xn = x.cpu().numpy()
1627
for ord in vector_ords:
1628
res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1629
expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1630
msg = gen_error_message(x.size(), ord, keepdim)
1631
self.assertEqual(res.shape, expected.shape, msg=msg)
1632
self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1634
res_out = torch.tensor([], device=device, dtype=res.dtype)
1635
torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1636
self.assertEqual(res_out.shape, expected.shape, msg=msg)
1637
self.assertEqual(res_out, expected, msg=msg)
1640
x = torch.randn(25, 25, device=device, dtype=dtype)
1641
xn = x.cpu().numpy()
1642
for ord in matrix_ords:
1643
res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu()
1644
expected = np.linalg.norm(xn, ord, keepdims=keepdim)
1645
msg = gen_error_message(x.size(), ord, keepdim)
1646
self.assertEqual(res.shape, expected.shape, msg=msg)
1647
self.assertEqual(res, expected, msg=msg, exact_dtype=False)
1649
res_out = torch.tensor([], device=device, dtype=res.dtype)
1650
torch.linalg.norm(x, ord, keepdim=keepdim, out=res_out)
1651
self.assertEqual(res_out.shape, expected.shape, msg=msg)
1652
self.assertEqual(res_out, expected, msg=msg)
1656
    def test_vector_norm_extreme_values(self, device):
        vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
        vectors = []
        for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
            vectors.append(list(pair))
1661
for vector in vectors:
1662
x = torch.tensor(vector, device=device)
1663
x_n = x.cpu().numpy()
1664
for ord in vector_ords:
1665
msg = f'ord={ord}, vector={vector}'
1666
result = torch.linalg.vector_norm(x, ord=ord)
1667
result_n = np.linalg.norm(x_n, ord=ord)
1668
self.assertEqual(result, result_n, msg=msg)
1670
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1671
def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
1672
input_sizes_and_dims = [
1674
((3, 1, 2, 1), (1, 3)),
1677
orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
1678
keepdims = [True, False]
1680
for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
1681
input_size = input_size_and_dim[0]
1682
dim = input_size_and_dim[1]
1683
if type(dim) is tuple and ord == 0:
1686
input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
1687
result = torch.linalg.vector_norm(input, ord, dim, keepdim)
1688
result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
1690
msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1691
self.assertEqual(result, result_numpy, msg=msg)
1693
@skipCUDAIfNoMagmaAndNoCusolver
1695
@dtypes(torch.float, torch.double)
1696
@precisionOverride({torch.float32: 2e-5})
1697
def test_matrix_norm(self, device, dtype):
1699
A = make_tensor((2, 2, 2), dtype=dtype, device=device)
1701
with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must have at least 2 dimensions.*'):
1702
torch.linalg.matrix_norm(make_tensor((2,), dtype=dtype, device=device))
1703
with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm:.*must be a 2-tuple.*'):
1704
torch.linalg.matrix_norm(A, dim=(0,))
1705
with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1706
torch.linalg.matrix_norm(A, ord=0)
1707
with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'):
1708
torch.linalg.matrix_norm(A, ord=3.0)
1711
ref = torch.linalg.norm(A, dim=(-2, -1))
1712
res = torch.linalg.matrix_norm(A)
1713
self.assertEqual(ref, res)
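# A minimal sketch (not part of the test suite) showing why the unqualified
# matrix_norm call above matches linalg.norm over the last two dims: both
# default to the Frobenius norm, i.e. the vector 2-norm of the flattened
# matrix. The helper name `_fro_norm_sketch` is hypothetical.
def _fro_norm_sketch(A):
    # sqrt of the sum of squared absolute entries, batched over everything
    # except the last two dimensions
    return A.abs().pow(2).sum(dim=(-2, -1)).sqrt()

# Usage sketch:
#   A = torch.randn(2, 2, 2)
#   torch.testing.assert_close(_fro_norm_sketch(A), torch.linalg.matrix_norm(A))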
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
1718
@unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
1721
def test_norm_extreme_values(self, device):
    vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf]
    matrix_ords = ['fro', 1, inf, -1, -inf]
    vectors = []
    matrices = []
    for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2):
        vectors.append(list(pair))
        matrices.append([[pair[0], pair[1]]])
        matrices.append([[pair[0]], [pair[1]]])

    for vector in vectors:
        x = torch.tensor(vector).to(device)
        x_n = x.cpu().numpy()
        for ord in vector_ords:
            msg = f'ord={ord}, vector={vector}'
            result = torch.linalg.norm(x, ord=ord)
            result_n = np.linalg.norm(x_n, ord=ord)
            self.assertEqual(result, result_n, msg=msg)

    # Some SVD-based matrix norm cases with non-finite entries are known to
    # disagree with NumPy on CUDA; those are skipped below.
    def is_broken_matrix_norm_case(ord, x):
        if self.device_type == 'cuda':
            if x.size() == torch.Size([1, 2]):
                if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1:
                    return True
            if ord in ['nuc', 2, -2]:
                return True
        return False

    for matrix in matrices:
        x = torch.tensor(matrix).to(device)
        x_n = x.cpu().numpy()
        for ord in matrix_ords:
            msg = f'ord={ord}, matrix={matrix}'
            if is_broken_matrix_norm_case(ord, x):
                continue
            result_n = np.linalg.norm(x_n, ord=ord)
            result = torch.linalg.norm(x, ord=ord)
            self.assertEqual(result, result_n, msg=msg)
1770
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1771
def test_norm_vector_degenerate_shapes(self, device, dtype):
1772
def run_test_case(input, ord, dim, keepdim):
1773
msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
1774
if (input.numel() == 0 and
1775
(ord < 0. or ord == inf) and
1776
(dim is None or input.shape[dim] == 0)):
1777
with self.assertRaises(RuntimeError):
1778
torch.linalg.norm(input, ord, dim, keepdim)
1780
input_numpy = input.cpu().numpy()
1781
result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
1782
result = torch.linalg.norm(input, ord, dim, keepdim)
1783
self.assertEqual(result, result_numpy, msg=msg)
1785
ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf]
1795
for keepdim in [True, False]:
1796
for input_size, dim in test_cases:
1797
input = torch.randn(*input_size, dtype=dtype, device=device)
1798
for ord in ord_vector:
1799
run_test_case(input, ord, dim, keepdim)
1804
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
1805
def test_norm_matrix_degenerate_shapes(self, device, dtype):
1806
def run_test_case(input, ord, dim, keepdim, should_error):
    msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
    input_numpy = input.cpu().numpy()
    ops = [torch.linalg.norm]

    if ord is not None and dim is not None:
        ops.append(torch.linalg.matrix_norm)

    if should_error:
        with self.assertRaises(ValueError):
            np.linalg.norm(input_numpy, ord, dim, keepdim)
        for op in ops:
            with self.assertRaises(IndexError):
                op(input, ord, dim, keepdim)
    else:
        result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim)
        for op in ops:
            result = op(input, ord, dim, keepdim)
            self.assertEqual(result, result_numpy, msg=msg)
1826
ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None]
1830
((0, 0), [1, 2, inf, -1, -2, -inf], None),
1831
((0, S), [2, inf, -2, -inf], None),
1832
((S, 0), [1, 2, -1, -2], None),
1833
((S, S, 0), [], (0, 1)),
1834
((1, S, 0), [], (0, 1)),
1835
((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)),
1836
((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)),
1839
for keepdim in [True, False]:
1840
for input_size, error_ords, dim in test_cases:
1841
input = torch.randn(*input_size, dtype=dtype, device=device)
1842
for ord in ord_matrix:
1843
run_test_case(input, ord, dim, keepdim, ord in error_ords)
1845
def test_norm_fastpaths(self, device):
1846
x = torch.randn(3, 5, device=device)
1849
result = torch.linalg.norm(x, 4.5, 1)
1850
expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5)
1851
self.assertEqual(result, expected)
1854
result = torch.linalg.norm(x, 0, 1)
1855
expected = (x != 0).type_as(x).sum(1)
1856
self.assertEqual(result, expected)
1859
result = torch.linalg.norm(x, 1, 1)
1860
expected = x.abs().sum(1)
1861
self.assertEqual(result, expected)
1864
result = torch.linalg.norm(x, 2, 1)
1865
expected = torch.sqrt(x.pow(2).sum(1))
1866
self.assertEqual(result, expected)
1869
result = torch.linalg.norm(x, 3, 1)
1870
expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
1871
self.assertEqual(result, expected)
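# A minimal sketch (not part of the test suite) of the reference formulas the
# fast-path checks above compare against: ord=0 counts non-zero entries,
# ord=1 sums absolute values, and a general p uses (sum |x|^p)^(1/p), which
# covers the ord=2 Euclidean case. The helper name
# `_vector_norm_reference_sketch` is hypothetical.
def _vector_norm_reference_sketch(x, p, dim):
    if p == 0:
        return (x != 0).type_as(x).sum(dim)
    if p == 1:
        return x.abs().sum(dim)
    return torch.pow(x.abs().pow(p).sum(dim), 1.0 / p)

# Usage sketch:
#   x = torch.randn(3, 5)
#   torch.testing.assert_close(_vector_norm_reference_sketch(x, 3, 1),
#                              torch.linalg.norm(x, 3, 1))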
@dtypes(torch.float64, torch.complex128)
1878
def test_eig_numpy(self, device, dtype):
1879
def run_test(shape, *, symmetric=False):
1880
from torch.testing._internal.common_utils import random_symmetric_matrix
1882
if not dtype.is_complex and symmetric:
1885
a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1887
a = make_tensor(shape, dtype=dtype, device=device)
1889
actual = torch.linalg.eig(a)
1894
expected = np.linalg.eig(a.cpu().numpy())
1897
ind = np.argsort(expected[0], axis=-1)[::-1]
1898
expected = (np.take_along_axis(expected[0], ind, axis=-1), np.take_along_axis(expected[1], ind[:, None], axis=-1))
1904
ind = np.argsort(actual[0].cpu().numpy(), axis=-1)[::-1]
1905
actual_np = [x.cpu().numpy() for x in actual]
sorted_actual = (
    np.take_along_axis(actual_np[0], ind, axis=-1),
    np.take_along_axis(actual_np[1], ind[:, None], axis=-1))
1910
self.assertEqual(expected[0], sorted_actual[0], exact_dtype=False)
1911
self.assertEqual(abs(expected[1]), abs(sorted_actual[1]), exact_dtype=False)
1915
(0, 0, 0), (0, 5, 5),
1918
for shape in shapes:
1920
run_test(shape, symmetric=True)
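# A minimal sketch (not part of the test suite) of the property the eig tests
# above rely on: each eigenpair satisfies A @ v == lambda * v, so the full
# decomposition satisfies A @ V == V @ diag(L). The helper name
# `_check_eig_sketch` is hypothetical.
def _check_eig_sketch(A):
    L, V = torch.linalg.eig(A)
    # eig always returns complex outputs, so cast A before comparing
    torch.testing.assert_close(A.to(V.dtype) @ V, V @ torch.diag_embed(L))

# Usage sketch:
#   _check_eig_sketch(torch.randn(4, 4, dtype=torch.float64))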
@dtypes(*floating_and_complex_types())
1925
def test_eig_compare_backends(self, device, dtype):
1926
def run_test(shape, *, symmetric=False):
1927
from torch.testing._internal.common_utils import random_symmetric_matrix
1929
if not dtype.is_complex and symmetric:
1931
a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
1933
a = make_tensor(shape, dtype=dtype, device=device)
1935
actual = torch.linalg.eig(a)
1937
complementary_device = 'cpu'
1940
expected = torch.linalg.eig(a.to(complementary_device))
1941
self.assertEqual(expected[0], actual[0])
1942
self.assertEqual(expected[1], actual[1])
1946
(0, 0, 0), (0, 5, 5),
1949
for shape in shapes:
1951
run_test(shape, symmetric=True)
1956
@dtypes(torch.float32)
1957
def test_eig_check_magma(self, device, dtype):
1959
shape = (2049, 2049)
1960
a = make_tensor(shape, dtype=dtype, device=device)
1961
w, v = torch.linalg.eig(a)
1963
self.assertEqual(a.to(v.dtype) @ v, w * v, atol=1e-3, rtol=1e-3)
1967
@dtypes(*floating_and_complex_types())
1968
def test_eig_errors_and_warnings(self, device, dtype):
1970
a = make_tensor(2, dtype=dtype, device=device)
1971
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
1975
a = make_tensor((2, 3), dtype=dtype, device=device)
1976
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
1980
if not dtype.is_complex:
1982
a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
1983
out0 = torch.empty(0, device=device, dtype=dtype)
1984
out1 = torch.empty(0, device=device, dtype=dtype)
1985
with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
1986
torch.linalg.eig(a, out=(out0, out1))
1988
out0 = torch.empty(0, device=device, dtype=torch.complex128)
1989
with self.assertRaisesRegex(RuntimeError, "Expected eigenvectors to be safely castable"):
1990
torch.linalg.eig(a, out=(out0, out1))
1993
a = make_tensor((3, 3), dtype=dtype, device=device)
1994
out0 = torch.empty(0, dtype=torch.int, device=device)
1995
out1 = torch.empty(0, dtype=torch.int, device=device)
1996
with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
1997
torch.linalg.eig(a, out=(out0, out1))
1999
out0 = torch.empty(0, dtype=torch.complex128, device=device)
2000
with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
2001
torch.linalg.eig(a, out=(out0, out1))
2004
a = make_tensor((3, 3), dtype=dtype, device=device)
2005
out0 = torch.empty(1, device=device, dtype=torch.complex128)
2006
out1 = torch.empty(1, device=device, dtype=torch.complex128)
2007
with warnings.catch_warnings(record=True) as w:
2009
torch.linalg.eig(a, out=(out0, out1))
2011
self.assertEqual(len(w), 2)
2012
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2013
self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
2016
if torch.cuda.is_available():
2017
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2018
out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2019
out_v = torch.empty(0, device=device, dtype=torch.complex128)
2020
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2021
torch.linalg.eig(a, out=(out_w, out_v))
2022
out_w = torch.empty(0, device=device, dtype=torch.complex128)
2023
out_v = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2024
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2025
torch.linalg.eig(a, out=(out_w, out_v))
2029
@dtypes(*floating_and_complex_types())
2030
def test_eig_with_nan(self, device, dtype):
2031
for val in [np.inf, np.nan]:
2032
for batch_dim in [(), (10,)]:
2033
a = make_tensor((*batch_dim, 5, 5), device=device, dtype=dtype)
2034
a[..., -1, -1] = val
2036
with self.assertRaisesRegex(RuntimeError, "torch.linalg.eig: input tensor should not"):
2043
@dtypes(torch.float64, torch.complex128)
2044
def test_eigvals_numpy(self, device, dtype):
2045
def run_test(shape, *, symmetric=False):
2046
from torch.testing._internal.common_utils import random_symmetric_matrix
2048
if not dtype.is_complex and symmetric:
2051
a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2053
a = make_tensor(shape, dtype=dtype, device=device)
2055
actual = torch.linalg.eigvals(a)
2060
expected = np.linalg.eigvals(a.cpu().numpy())
2063
ind = np.argsort(expected, axis=-1)[::-1]
2064
expected = np.take_along_axis(expected, ind, axis=-1)
2070
ind = np.argsort(actual.cpu().numpy(), axis=-1)[::-1]
2071
actual_np = actual.cpu().numpy()
2072
sorted_actual = np.take_along_axis(actual_np, ind, axis=-1)
2074
self.assertEqual(expected, sorted_actual, exact_dtype=False)
2078
(0, 0, 0), (0, 5, 5),
2081
for shape in shapes:
2083
run_test(shape, symmetric=True)
2087
@dtypes(*floating_and_complex_types())
2088
def test_eigvals_compare_backends(self, device, dtype):
2089
def run_test(shape, *, symmetric=False):
2090
from torch.testing._internal.common_utils import random_symmetric_matrix
2092
if not dtype.is_complex and symmetric:
2094
a = random_symmetric_matrix(shape[-1], *shape[:-2], dtype=dtype, device=device)
2096
a = make_tensor(shape, dtype=dtype, device=device)
2098
actual = torch.linalg.eigvals(a)
2100
complementary_device = 'cpu'
2103
expected = torch.linalg.eigvals(a.to(complementary_device))
2104
self.assertEqual(expected, actual)
2107
complex_dtype = dtype
2108
if not dtype.is_complex:
2109
complex_dtype = torch.complex128 if dtype == torch.float64 else torch.complex64
2110
out = torch.empty(0, dtype=complex_dtype, device=device)
2111
ans = torch.linalg.eigvals(a, out=out)
2112
self.assertEqual(ans, out)
2113
self.assertEqual(expected.to(complex_dtype), out)
2117
out = torch.empty(2 * shape[0], *shape[1:-1], dtype=complex_dtype, device=device)[::2]
2118
self.assertFalse(out.is_contiguous())
2119
ans = torch.linalg.eigvals(a, out=out)
2120
self.assertEqual(ans, out)
2121
self.assertEqual(expected.to(complex_dtype), out)
2125
(0, 0, 0), (0, 5, 5),
2128
for shape in shapes:
2130
run_test(shape, symmetric=True)
2134
@dtypes(*floating_and_complex_types())
2135
def test_eigvals_errors_and_warnings(self, device, dtype):
2137
a = make_tensor(2, dtype=dtype, device=device)
2138
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
2139
torch.linalg.eigvals(a)
2142
a = make_tensor((2, 3), dtype=dtype, device=device)
2143
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2144
torch.linalg.eigvals(a)
2147
if not dtype.is_complex:
2149
a = torch.tensor([[3., -2.], [4., -1.]], dtype=dtype, device=device)
2150
out = torch.empty(0, device=device, dtype=dtype)
2151
with self.assertRaisesRegex(RuntimeError, "Expected eigenvalues to be safely castable"):
2152
torch.linalg.eigvals(a, out=out)
2155
a = make_tensor((3, 3), dtype=dtype, device=device)
2156
out = torch.empty(0, dtype=torch.int, device=device)
2157
with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
2158
torch.linalg.eigvals(a, out=out)
2161
out = torch.empty(1, device=device, dtype=torch.complex128)
2162
with warnings.catch_warnings(record=True) as w:
2164
torch.linalg.eigvals(a, out=out)
2166
self.assertEqual(len(w), 1)
2167
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2170
if torch.cuda.is_available():
2171
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2172
out_w = torch.empty(0, device=wrong_device, dtype=torch.complex128)
2173
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2174
torch.linalg.eigvals(a, out=out_w)
2178
def test_norm_old(self, device):
2179
def gen_error_message(input_size, p, keepdim, dim=None):
2180
return f"norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2185
class PrecisionContext:
    def __init__(self, test, norm):
        self.norm = norm
        self.saved_overrides = getattr(test, 'precision_overrides', None)
        self.target_test = test

    def __enter__(self):
        if 'nuc' != self.norm:
            self.target_test.precision_overrides = {torch.float: 1e-4, torch.cfloat: 2e-4}
        return self.target_test.precision_overrides

    def __exit__(self, type, value, tb) -> bool:
        if 'nuc' != self.norm:
            if self.saved_overrides is None:
                delattr(self.target_test, 'precision_overrides')
            else:
                self.target_test.precision_overrides = self.saved_overrides
        return False
2206
for keepdim in [False, True]:
2208
x = torch.randn(25, device=device)
2209
xn = x.cpu().numpy()
2210
for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]:
2211
res = x.norm(p, keepdim=keepdim).cpu()
2212
expected = np.linalg.norm(xn, p, keepdims=keepdim)
2213
self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim))
2216
x = torch.randn(25, 25, device=device)
2217
xn = x.cpu().numpy()
2218
for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]:
2220
res = x.norm(p, dim, keepdim=keepdim).cpu()
2221
expected = np.linalg.norm(xn, p, dim, keepdims=keepdim)
2222
msg = gen_error_message(x.size(), p, keepdim, dim)
2223
self.assertEqual(res.shape, expected.shape, msg=msg)
2224
self.assertEqual(res, expected, msg=msg)
2227
for p in ['fro', 'nuc']:
2228
res = x.norm(p, keepdim=keepdim).cpu()
2229
expected = np.linalg.norm(xn, p, keepdims=keepdim)
2230
msg = gen_error_message(x.size(), p, keepdim)
2231
with PrecisionContext(self, p):
2232
self.assertEqual(res.shape, expected.shape, msg=msg)
2233
self.assertEqual(res, expected, msg=msg)
2236
x = torch.randn((), device=device)
2237
xn = x.cpu().numpy()
2238
res = x.norm(keepdim=keepdim).cpu()
2239
expected = np.linalg.norm(xn, keepdims=keepdim)
2240
msg = gen_error_message(x.size(), None, keepdim)
2241
self.assertEqual(res.shape, expected.shape, msg=msg)
2242
self.assertEqual(res, expected, msg=msg)
2246
2 * torch.norm(torch.ones(10000), keepdim=keepdim),
2247
torch.norm(torch.ones(40000), keepdim=keepdim))
2250
x = torch.randn(5, 6, 7, 8, device=device)
2251
xn = x.cpu().numpy()
2252
for p in ['fro', 'nuc']:
2253
for dim in itertools.product(*[list(range(4))] * 2):
2254
if dim[0] == dim[1]:
2256
res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu()
2257
expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim)
2258
msg = gen_error_message(x.size(), p, keepdim, dim)
2259
with PrecisionContext(self, p):
2260
self.assertEqual(res.shape, expected.shape, msg=msg)
2261
self.assertEqual(res, expected, msg=msg)
2264
def test_norm_old_nan_propagation(self, device):
2266
for pair in itertools.product([0.0, nan, 1.0], repeat=2):
2267
x = torch.tensor(list(pair), device=device)
2269
result = torch.norm(x, p=ord)
2270
result_check = torch.linalg.norm(x, ord=ord)
2271
self.assertEqual(result, result_check)
2275
def test_norm_complex_old(self, device):
2276
def gen_error_message(input_size, p, keepdim, dim=None):
2277
return f"complex norm failed for input size {input_size}, p={p}, keepdim={keepdim}, dim={dim}"
2279
for keepdim in [False, True]:
2281
x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device)
2282
xn = x.cpu().numpy()
2283
for p in [0, 1, 2, 3, inf, -1, -2, -3, -inf]:
2284
res = x.norm(p, keepdim=keepdim).cpu()
2285
expected = np.linalg.norm(xn, p, keepdims=keepdim)
2286
msg = gen_error_message(x.size(), p, keepdim)
2287
self.assertEqual(res.shape, expected.shape, msg=msg)
2288
self.assertEqual(res, expected, msg=msg)
2291
x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device)
2292
xn = x.cpu().numpy()
2293
for p in ['nuc', 'fro']:
2294
res = x.norm(p, keepdim=keepdim).cpu()
2295
expected = np.linalg.norm(xn, p, keepdims=keepdim)
2296
msg = gen_error_message(x.size(), p, keepdim)
2297
self.assertEqual(res.shape, expected.shape, msg=msg)
2298
self.assertEqual(res, expected, msg=msg, rtol=4e-6, atol=6e-4)
2301
@dtypes(torch.float)
2302
def test_norm_fro_2_equivalence_old(self, device, dtype):
2319
for input_size in input_sizes:
2320
a = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
2323
dim_settings = [None]
2326
dim_settings += list(range(-a.dim(), a.dim()))
2328
def wrap_dim(dim, ndims):
2329
assert (dim < ndims) and (dim >= -ndims)
2337
(d0, d1) for d0, d1 in itertools.combinations(range(-a.dim(), a.dim()), 2)
2338
if wrap_dim(d0, a.dim()) != wrap_dim(d1, a.dim())]
2340
for dim in dim_settings:
2341
for keepdim in [True, False]:
2342
a_norm_2 = torch.norm(a, p=2, dim=dim, keepdim=keepdim)
2343
a_norm_fro = torch.norm(a, p='fro', dim=dim, keepdim=keepdim)
2344
self.assertEqual(a_norm_fro, a_norm_2)
2346
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
2349
def test_nuclear_norm_axes_small_brute_force_old(self, device):
2350
def check_single_nuclear_norm(x, axes):
2351
if self.device_type != 'cpu' and randrange(100) < 95:
2354
a = np.array(x.cpu(), copy=False)
2355
expected = np.linalg.norm(a, "nuc", axis=axes)
2357
ans = torch.norm(x, "nuc", dim=axes)
2358
self.assertTrue(ans.is_contiguous())
2359
self.assertEqual(ans.shape, expected.shape)
2360
self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2362
out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
2363
ans = torch.norm(x, "nuc", dim=axes, out=out)
2364
self.assertIs(ans, out)
2365
self.assertTrue(ans.is_contiguous())
2366
self.assertEqual(ans.shape, expected.shape)
2367
self.assertEqual(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)
2369
for n in range(1, 3):
2370
for m in range(1, 3):
2371
for axes in itertools.permutations([0, 1], 2):
2373
x = torch.randn(n, m, device=device)
2374
check_single_nuclear_norm(x, axes)
2377
x = torch.randn(m, n, device=device).mT
2378
check_single_nuclear_norm(x, axes)
2381
x = torch.randn(n, 2 * m, device=device)[:, ::2]
2382
check_single_nuclear_norm(x, axes)
2385
x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
2386
check_single_nuclear_norm(x, axes)
2388
for o in range(1, 3):
2389
for axes in itertools.permutations([0, 1, 2], 2):
2391
x = torch.randn(o, n, m, device=device)
2392
check_single_nuclear_norm(x, axes)
2395
x = torch.randn(o, m, n, device=device).mT
2396
check_single_nuclear_norm(x, axes)
2399
x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
2400
check_single_nuclear_norm(x, axes)
2403
x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
2404
check_single_nuclear_norm(x, axes)
2406
for r in range(1, 3):
2407
for axes in itertools.permutations([0, 1, 2, 3], 2):
2409
x = torch.randn(r, o, n, m, device=device)
2410
check_single_nuclear_norm(x, axes)
2413
x = torch.randn(r, o, n, m, device=device).mT
2414
check_single_nuclear_norm(x, axes)
2417
x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
2418
check_single_nuclear_norm(x, axes)
2421
x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
2422
check_single_nuclear_norm(x, axes)
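# A minimal sketch (not part of the test suite) of what the "nuc" ord used
# above means: the nuclear norm of a matrix is the sum of its singular
# values. The helper name `_nuclear_norm_sketch` is hypothetical.
def _nuclear_norm_sketch(A):
    return torch.linalg.svdvals(A).sum(-1)

# Usage sketch:
#   A = torch.randn(3, 4)
#   torch.testing.assert_close(_nuclear_norm_sketch(A),
#                              torch.linalg.matrix_norm(A, ord='nuc'))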
def test_nuclear_norm_exceptions_old(self, device):
2426
for lst in [], [1], [1, 2]:
2427
x = torch.tensor(lst, dtype=torch.double, device=device)
2428
for axes in (), (0,):
2429
self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
2430
self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1))
2432
x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
2433
self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0))
2434
self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2))
2436
@skipCUDAIfNoCusolver
2438
@dtypes(torch.double, torch.cdouble)
2439
def test_svd_lowrank(self, device, dtype):
2440
from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
2442
def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **options):
2443
density = options.pop('density', 1)
2444
if isinstance(matrix_size, int):
2445
rows = columns = matrix_size
2447
rows, columns = matrix_size
2449
a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
2452
assert batches == ()
2453
a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
2454
a = a_input.to_dense()
2457
u, s, v = svd_lowrank(a_input, q=q, **options)
2460
u, s, v = u[..., :q], s[..., :q], v[..., :q]
2461
A = (u * s.unsqueeze(-2)).matmul(v.mH)
2462
self.assertEqual(A, a, rtol=1e-7, atol=2e-7)
2465
U, S, Vh = torch.linalg.svd(a, full_matrices=False)
2467
self.assertEqual(s, S)
2474
u, v = u[..., :actual_rank], v[..., :actual_rank]
2475
U, V = U[..., :actual_rank], V[..., :actual_rank]
2476
expected_ones = u.mH.matmul(U).det().abs()
2477
self.assertEqual(expected_ones, torch.ones_like(expected_ones))
2478
self.assertEqual(v.mH.matmul(V).det().abs(), torch.ones_like(expected_ones))
2480
all_batches = [(), (1,), (3,), (2, 3)]
2481
for actual_rank, size, all_batches in [
2482
(2, (17, 4), all_batches),
2483
(4, (17, 4), all_batches),
2484
(4, (17, 17), all_batches),
2485
(10, (100, 40), all_batches),
2486
(7, (1000, 1000), [()]),
2489
for batches in all_batches:
2490
run_subtest(actual_rank, size, batches, device, torch.svd_lowrank)
2491
if size != size[::-1]:
2492
run_subtest(actual_rank, size[::-1], batches, device, torch.svd_lowrank)
2495
for size in [(17, 4), (4, 17), (17, 17), (100, 40), (40, 100), (1000, 1000)]:
2496
for density in [0.005, 0.1]:
2497
run_subtest(None, size, (), device, torch.svd_lowrank, density=density)
2500
jitted = torch.jit.script(torch.svd_lowrank)
2501
actual_rank, size, batches = 2, (17, 4), ()
2502
run_subtest(actual_rank, size, batches, device, jitted)
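# A minimal sketch (not part of the test suite) of the reconstruction checked
# in run_subtest above: with q at least the true rank, the truncated factors
# from torch.svd_lowrank reproduce the input as u @ diag(s) @ v^H. The helper
# name `_lowrank_reconstruct_sketch` is hypothetical.
def _lowrank_reconstruct_sketch(a, q):
    u, s, v = torch.svd_lowrank(a, q=q)
    # broadcasting s over the columns of u is equivalent to u @ diag(s)
    return (u * s.unsqueeze(-2)) @ v.mH

# Usage sketch (a rank-2 matrix recovered with q=2):
#   a = torch.randn(17, 2, dtype=torch.float64) @ torch.randn(2, 4, dtype=torch.float64)
#   torch.testing.assert_close(_lowrank_reconstruct_sketch(a, q=2), a)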
@skipCUDAIfNoMagmaAndNoCusolver
2506
@precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4})
2507
@setLinalgBackendsToDefaultFinally
2508
@dtypes(*floating_and_complex_types())
2510
def test_svd(self, device, dtype):
2512
make_arg = partial(make_tensor, dtype=dtype, device=device)
2514
backends = ["default"]
2516
if torch.device(device).type == 'cuda':
2517
if torch.cuda.has_magma:
2518
backends.append("magma")
2519
if has_cusolver() or has_hipsolver():
2520
backends.append("cusolver")
2523
batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
2524
drivers = (None, 'gesvd', 'gesvdj', 'gesvda')
2526
for backend in backends:
2527
torch.backends.cuda.preferred_linalg_library(backend)
2529
for batch, m, n, driver in product(batches, ns, ns, drivers):
2530
if not (backend == 'cusolver' or driver is None):
2536
shape = batch + (m, n)
2539
U, S, Vh = torch.linalg.svd(A, full_matrices=False, driver=driver)
2540
self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ Vh, A)
2542
U_f, S_f, Vh_f = torch.linalg.svd(A, full_matrices=True, driver=driver)
2543
self.assertEqual(S_f, S)
2544
self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ Vh_f[..., :k, :], A)
2546
S_s = torch.linalg.svdvals(A, driver=driver)
2547
self.assertEqual(S_s, S)
2549
U, S, V = torch.svd(A, some=True)
2550
self.assertEqual((U @ S.to(A.dtype).diag_embed()) @ V.mH, A)
2552
U_f, S_f, V_f = torch.svd(A, some=False)
2553
self.assertEqual(S_f, S)
2554
self.assertEqual((U_f[..., :k] @ S_f.to(A.dtype).diag_embed()) @ V_f[..., :k].mH, A)
2556
S_s = torch.svd(A, compute_uv=False).S
2557
self.assertEqual(S_s, S)
2559
@skipCUDAIfNoMagmaAndNoCusolver
2561
@dtypes(torch.complex128)
2562
def test_invariance_error_spectral_decompositions(self, device, dtype):
2563
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
2564
A = make_arg((3, 3))
2565
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2566
U, _, Vh = torch.linalg.svd(A, full_matrices=False)
2567
(U + Vh).sum().abs().backward()
2569
A = make_arg((3, 3))
2570
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2571
V = torch.linalg.eig(A).eigenvectors
2572
V.sum().abs().backward()
2574
A = make_arg((3, 3))
2576
with self.assertRaisesRegex(RuntimeError, "ill-defined"):
2577
Q = torch.linalg.eigh(A).eigenvectors
2578
Q.sum().abs().backward()
2580
@skipCUDAIfNoCusolver
2581
@precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
2583
@dtypes(*floating_and_complex_types())
2584
def test_svd_memory_allocation(self, device, dtype):
2589
a = make_tensor((m, n), dtype=dtype, device=device)
2591
S = torch.linalg.svdvals(a)
2592
result = torch.linalg.svd(a, full_matrices=False)
2593
self.assertEqual(result.S, S)
2595
def cholesky_solve_test_helper(self, A_dims, b_dims, upper, device, dtype):
2596
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2598
b = torch.randn(*b_dims, dtype=dtype, device=device)
2599
A = random_hermitian_pd_matrix(*A_dims, dtype=dtype, device=device)
2600
L = torch.cholesky(A, upper=upper)
2605
@dtypes(*floating_and_complex_types())
2606
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2607
torch.float64: 1e-8, torch.complex128: 1e-8})
2608
def test_cholesky_solve(self, device, dtype):
2609
for (k, n), upper in itertools.product(zip([2, 3, 5], [3, 5, 7]), [True, False]):
2610
b, A, L = self.cholesky_solve_test_helper((n,), (n, k), upper, device, dtype)
2611
x = torch.cholesky_solve(b, L, upper=upper)
2612
self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
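# A minimal sketch (not part of the test suite) of the round trip exercised
# above: factor a Hermitian positive-definite A as L @ L^H, then
# cholesky_solve uses that factor to solve A x = b. The helper name
# `_cholesky_solve_sketch` is hypothetical.
def _cholesky_solve_sketch(A, b):
    L = torch.linalg.cholesky(A)    # lower-triangular factor
    x = torch.cholesky_solve(b, L)  # solves (L @ L.mH) x = b
    return x

# Usage sketch:
#   A = torch.randn(5, 5, dtype=torch.float64)
#   A = A @ A.mT + 5 * torch.eye(5, dtype=torch.float64)
#   b = torch.randn(5, 2, dtype=torch.float64)
#   torch.testing.assert_close(A @ _cholesky_solve_sketch(A, b), b)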
@dtypes(*floating_and_complex_types())
2617
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2618
torch.float64: 1e-8, torch.complex128: 1e-8})
2619
def test_cholesky_solve_batched(self, device, dtype):
2620
def cholesky_solve_batch_helper(A_dims, b_dims, upper):
2621
b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2623
for i in range(b_dims[0]):
2624
x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper))
2625
x_exp = torch.stack(x_exp_list)
2626
x_act = torch.cholesky_solve(b, L, upper=upper)
2627
self.assertEqual(x_act, x_exp)
2628
Ax = np.matmul(A.cpu(), x_act.cpu())
2629
self.assertEqual(b, Ax)
2631
for upper, batchsize in itertools.product([True, False], [1, 3, 4]):
2632
cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), upper)
2637
@dtypes(*floating_and_complex_types())
2638
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2639
torch.float64: 1e-8, torch.complex128: 1e-8})
2640
def test_cholesky_solve_batched_many_batches(self, device, dtype):
2641
for A_dims, b_dims in zip([(5, 256, 256), (5,)], [(5, 10), (512, 512, 5, 10)]):
2642
for upper in [True, False]:
2643
b, A, L = self.cholesky_solve_test_helper(A_dims, b_dims, upper, device, dtype)
2644
x = torch.cholesky_solve(b, L, upper)
2645
Ax = torch.matmul(A, x)
2646
self.assertEqual(Ax, b.expand_as(Ax))
2650
@dtypes(*floating_and_complex_types())
2651
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2652
torch.float64: 1e-8, torch.complex128: 1e-8})
2653
def test_cholesky_solve_batched_broadcasting(self, device, dtype):
2654
from numpy.linalg import solve
2655
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2657
def run_test(A_dims, b_dims, upper):
2658
A_matrix_size = A_dims[-1]
2659
A_batch_dims = A_dims[:-2]
2660
A = random_hermitian_pd_matrix(A_matrix_size, *A_batch_dims,
2661
dtype=dtype, device='cpu')
2662
b = torch.randn(*b_dims, dtype=dtype, device='cpu')
2663
x_exp = torch.tensor(solve(A.numpy(), b.numpy()), dtype=dtype, device=device)
2664
A, b = A.to(dtype=dtype, device=device), b.to(dtype=dtype, device=device)
2665
L = torch.linalg.cholesky(A, upper=upper)
2666
x = torch.cholesky_solve(b, L, upper=upper)
2667
self.assertEqual(x, x_exp)
2669
x = torch.cholesky_solve(b, L, upper=upper, out=x)
2670
self.assertEqual(x, x_exp)
2673
for upper in [True, False]:
2674
run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), upper)
2675
run_test((2, 1, 3, 4, 4), (4, 6), upper)
2676
run_test((4, 4), (2, 1, 3, 4, 2), upper)
2677
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), upper)
2681
@dtypes(*floating_and_complex_types())
2682
def test_cholesky_solve_out_errors_and_warnings(self, device, dtype):
2684
a = torch.eye(2, dtype=dtype, device=device)
2685
b = torch.randn(2, 1, dtype=dtype, device=device)
2686
out = torch.empty(0, dtype=torch.int, device=device)
2687
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2688
torch.cholesky_solve(b, a, out=out)
2691
if torch.cuda.is_available():
2692
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2693
out = torch.empty(0, dtype=dtype, device=wrong_device)
2694
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
2695
torch.cholesky_solve(b, a, out=out)
2698
with warnings.catch_warnings(record=True) as w:
2699
out = torch.empty(1, dtype=dtype, device=device)
2701
torch.cholesky_solve(b, a, out=out)
2703
self.assertEqual(len(w), 1)
2704
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2708
@dtypes(torch.double)
2709
def test_cholesky_solve_backward(self, device, dtype):
2713
for test_L_grad in (False, True):
2714
b = torch.randn(*b_dims, dtype=dtype, device=device, requires_grad=True)
2715
L = torch.randn(*L_dims, dtype=dtype, device=device, requires_grad=test_L_grad)
2717
torch.autograd.gradcheck(lambda b, L: torch.cholesky_solve(b, torch.tril(L), upper=False), (b, L))
2719
torch.autograd.gradcheck(lambda b: torch.cholesky_solve(b, L, upper=False), (b,))
2721
@skipCUDAIfNoMagmaAndNoCusolver
2723
@dtypes(*floating_and_complex_types())
2724
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2725
torch.float64: 1e-8, torch.complex128: 1e-8})
2726
def test_inverse(self, device, dtype):
2727
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2728
make_arg = partial(make_fullrank, device=device, dtype=dtype)
2730
def run_test(torch_inverse, matrix, batches, n):
2731
matrix_inverse = torch_inverse(matrix)
2736
expected = np.linalg.inv(matrix.cpu().numpy())
2737
self.assertEqual(matrix_inverse, expected, atol=self.precision, rtol=self.precision)
2740
identity = torch.eye(n, dtype=dtype, device=device)
2741
self.assertEqual(identity.expand_as(matrix), np.matmul(matrix.cpu(), matrix_inverse.cpu()))
2742
self.assertEqual(identity.expand_as(matrix), np.matmul(matrix_inverse.cpu(), matrix.cpu()))
2746
matrix_inverse_out = torch.empty(*batches, n, n, dtype=dtype, device=device)
2747
matrix_inverse_out_t = matrix_inverse_out.mT.clone(memory_format=torch.contiguous_format)
2748
matrix_inverse_out = matrix_inverse_out_t.mT
2749
ans = torch_inverse(matrix, out=matrix_inverse_out)
2750
self.assertEqual(matrix_inverse_out, ans, atol=0, rtol=0)
2751
self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)
2754
if matrix.ndim > 2 and batches[0] != 0:
2755
expected_inv_list = []
2756
p = int(np.prod(batches))
2757
for mat in matrix.contiguous().view(p, n, n):
2758
expected_inv_list.append(torch_inverse(mat))
2759
expected_inv = torch.stack(expected_inv_list).view(*batches, n, n)
2760
if self.device_type == 'cuda' and dtype in [torch.float32, torch.complex64]:
2764
self.assertEqual(matrix_inverse, expected_inv, atol=1e-1, rtol=1e-2)
2766
self.assertEqual(matrix_inverse, expected_inv)
2769
def test_inv_ex(input, out=None):
    if out is not None:
        info = torch.empty(0, dtype=torch.int32, device=device)
        return torch.linalg.inv_ex(input, out=(out, info)).inverse
    return torch.linalg.inv_ex(input).inverse
2775
for torch_inverse in [torch.inverse, torch.linalg.inv, test_inv_ex]:
2776
for batches, n in itertools.product(
2777
[[], [0], [2], [2, 1]],
2780
matrices = make_arg(*batches, n, n)
2781
run_test(torch_inverse, matrices, batches, n)
2784
run_test(torch_inverse, matrices.mT, batches, n)
2788
make_arg(*batches, 2 * n, 2 * n)
2789
.view(-1, n * 2, n * 2)[:, ::2, ::2].view(*batches, n, n),
2793
@skipCUDAIfNoMagmaAndNoCusolver
2795
@dtypes(*floating_and_complex_types())
2796
def test_inv_ex_info_device(self, device, dtype):
2797
A = torch.eye(3, 3, dtype=dtype, device=device)
2798
info = torch.linalg.inv_ex(A).info
2799
self.assertTrue(info.device == A.device)
2801
@skipCUDAIfNoMagmaAndNoCusolver
2803
@dtypes(*floating_and_complex_types())
2804
def test_inv_ex_singular(self, device, dtype):
2806
A = torch.eye(3, 3, dtype=dtype, device=device)
2808
info = torch.linalg.inv_ex(A).info
2809
self.assertEqual(info, 3)
2810
with self.assertRaisesRegex(torch.linalg.LinAlgError,
2811
r'diagonal element 3 is zero, the inversion could not be completed'):
2812
torch.linalg.inv_ex(A, check_errors=True)
2816
A = torch.eye(3, 3, dtype=dtype, device=device)
2817
A = A.reshape((1, 3, 3))
2818
A = A.repeat(5, 1, 1)
2820
info = torch.linalg.inv_ex(A).info
2822
expected_info = torch.zeros(A.shape[:-2], dtype=torch.int32, device=device)
2823
expected_info[3] = 2
2824
self.assertEqual(info, expected_info)
2825
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 3\): The diagonal element 2 is zero'):
2826
torch.linalg.inv_ex(A, check_errors=True)
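# A minimal sketch (not part of the test suite) of the check_errors=False
# contract relied on above: inv_ex reports failure through `info` instead of
# raising, with a non-zero entry per failed batch element. The helper name
# `_safe_inverse_sketch` is hypothetical.
def _safe_inverse_sketch(A):
    inverse, info = torch.linalg.inv_ex(A)
    # info == 0 means the factorization succeeded for that batch element
    ok = (info == 0)
    return inverse, ok

# Usage sketch:
#   _, ok = _safe_inverse_sketch(torch.zeros(3, 3))   # ok is tensor(False)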
@skipCUDAIfNoMagmaAndNoCusolver
2831
@dtypes(*floating_and_complex_types())
2832
@precisionOverride({torch.float32: 2e-3, torch.complex64: 2e-3,
2833
torch.float64: 1e-5, torch.complex128: 1e-5})
2834
def test_inverse_many_batches(self, device, dtype):
2835
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
2836
make_arg = partial(make_fullrank, device=device, dtype=dtype)
2838
def test_inverse_many_batches_helper(torch_inverse, b, n):
2839
matrices = make_arg(b, n, n)
2840
matrices_inverse = torch_inverse(matrices)
2843
expected = np.linalg.inv(matrices.cpu().numpy())
2844
self.assertEqual(matrices_inverse, expected, atol=self.precision, rtol=1e-3)
2846
for torch_inverse in [torch.inverse, torch.linalg.inv]:
2847
test_inverse_many_batches_helper(torch_inverse, 5, 256)
2848
test_inverse_many_batches_helper(torch_inverse, 3, 512)
2850
@skipCUDAIfNoMagmaAndNoCusolver
2852
@onlyNativeDeviceTypes
2853
@dtypes(*floating_and_complex_types())
2854
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
2855
def test_inverse_errors(self, device, dtype):
2857
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
2858
torch.inverse(torch.randn(2, 3, 4, 3))
2861
def run_test_singular_input(batch_dim, n):
2862
x = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
2864
with self.assertRaisesRegex(torch.linalg.LinAlgError, rf'\(Batch element {n}\): The diagonal element 3 is zero'):
2867
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
2868
run_test_singular_input(*params)
2870
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
2871
@skipCUDAIfNoMagmaAndNoCusolver
2873
@onlyNativeDeviceTypes
2874
@dtypes(*floating_and_complex_types())
2875
def test_inverse_errors_large(self, device, dtype):
2877
x = torch.empty((8, 10, 616, 616), dtype=dtype, device=device)
2878
x[:] = torch.eye(616, dtype=dtype, device=device)
2880
with self.assertRaisesRegex(torch.linalg.LinAlgError, r'\(Batch element 0\): The diagonal element 11 is zero'):
2883
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
2886
@dtypes(*floating_and_complex_types())
2887
def test_pinv(self, device, dtype):
2888
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
2890
def run_test_main(A, hermitian):
2892
A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
2893
np_A = A.cpu().numpy()
2894
np_A_pinv = A_pinv.cpu().numpy()
2896
self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=self.precision, rtol=self.precision)
2897
self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=self.precision, rtol=self.precision)
2898
self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1))
2899
self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1))
2901
self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
2904
out = torch.empty_like(A_pinv)
2905
ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
2906
self.assertEqual(ans, out)
2907
self.assertEqual(ans, A_pinv)
2909
def run_test_numpy(A, hermitian):
2912
rconds = [float(torch.rand(1)), ]
2914
for rcond_type in all_types():
2915
rconds.append(torch.rand(A.shape[:-2], dtype=torch.double, device=device).to(rcond_type))
2918
rconds.append(torch.rand(A.shape[-3], device=device))
2919
for rcond in rconds:
2920
actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
2921
torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
2922
self.assertEqual(actual, torch_rtol)
2923
numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
2924
expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
2925
self.assertEqual(actual, expected, atol=self.precision, rtol=1e-5)
2927
for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),
2928
(3, 2), (5, 3, 2), (2, 5, 3, 2),
2929
(2, 3), (5, 2, 3), (2, 5, 2, 3),
2930
(0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:
2931
A = torch.randn(*sizes, dtype=dtype, device=device)
2933
run_test_main(A, hermitian)
2934
run_test_numpy(A, hermitian)
2937
for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),
2938
(0, 0), (3, 0, 0), ]:
2939
A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
2941
run_test_main(A, hermitian)
2942
run_test_numpy(A, hermitian)
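# A minimal sketch (not part of the test suite) of how the rcond/rtol
# arguments tested above act: singular values below a threshold relative to
# the largest one are treated as zero when forming the pseudoinverse. The
# helper name `_pinv_via_svd_sketch` and its default cutoff are hypothetical
# simplifications.
def _pinv_via_svd_sketch(A, rcond=1e-15):
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    cutoff = rcond * S.amax(dim=-1, keepdim=True)
    S_inv = torch.where(S > cutoff, S.reciprocal(), torch.zeros_like(S))
    # pinv(A) = V @ diag(S^+) @ U^H
    return Vh.mH @ (S_inv.unsqueeze(-1) * U.mH)

# Usage sketch (matches torch.linalg.pinv for well-conditioned inputs):
#   A = torch.randn(5, 3, dtype=torch.float64)
#   torch.testing.assert_close(_pinv_via_svd_sketch(A), torch.linalg.pinv(A))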
@dtypes(*floating_and_complex_types())
2947
def test_pinv_errors_and_warnings(self, device, dtype):
2949
a = torch.randn(1, device=device, dtype=dtype)
2950
with self.assertRaisesRegex(RuntimeError, "expected a tensor with 2 or more dimensions"):
2951
torch.linalg.pinv(a)
2954
a = torch.randn(3, 3, dtype=dtype, device=device)
2955
out = torch.empty(7, 7, dtype=dtype, device=device)
2956
with warnings.catch_warnings(record=True) as w:
2958
torch.linalg.pinv(a, out=out)
2960
self.assertEqual(len(w), 1)
2961
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
2964
out = torch.empty_like(a).to(torch.int)
2965
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
2966
torch.linalg.pinv(a, out=out)
2968
if torch.cuda.is_available():
2970
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2971
out = torch.empty_like(a).to(wrong_device)
2972
with self.assertRaisesRegex(RuntimeError, "Expected result and input tensors to be on the same device"):
2973
torch.linalg.pinv(a, out=out)
2976
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
2977
rcond = torch.full((), 1e-2, device=wrong_device)
2978
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
2979
torch.linalg.pinv(a, rcond=rcond)
2982
rcond = torch.full((), 1j, device=device)
2983
with self.assertRaisesRegex(RuntimeError, "rcond tensor of complex type is not supported"):
2984
torch.linalg.pinv(a, rcond=rcond)
2987
atol = torch.full((), 1j, device=device)
2988
with self.assertRaisesRegex(RuntimeError, "atol tensor of complex type is not supported"):
2989
torch.linalg.pinv(a, atol=atol)
2992
rtol = torch.full((), 1j, device=device)
2993
with self.assertRaisesRegex(RuntimeError, "rtol tensor of complex type is not supported"):
2994
torch.linalg.pinv(a, rtol=rtol)
2996
@skipCUDAIfNoMagmaAndNoCusolver
2998
@dtypes(*floating_and_complex_types())
2999
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/129882")
3000
def test_inv_errors_and_warnings(self, device, dtype):
3002
a = torch.randn(2, 3, 4, 3, dtype=dtype, device=device)
3003
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
3007
a = torch.randn(2, device=device, dtype=dtype)
3008
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
3012
def run_test_singular_input(batch_dim, n):
3013
a = torch.eye(3, 3, dtype=dtype, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
3015
with self.assertRaisesRegex(torch.linalg.LinAlgError, rf"\(Batch element {n}\): The diagonal element 3 is zero"):
3018
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
3019
run_test_singular_input(*params)
3022
a = torch.eye(2, dtype=dtype, device=device)
3023
out = torch.empty(0, dtype=torch.int, device=device)
3024
with self.assertRaisesRegex(RuntimeError, "but got int instead"):
3025
torch.linalg.inv(a, out=out)
3028
if torch.cuda.is_available():
3029
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3030
out = torch.empty(0, device=wrong_device, dtype=dtype)
3031
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3032
torch.linalg.inv(a, out=out)
3035
with warnings.catch_warnings(record=True) as w:
3036
a = torch.eye(2, dtype=dtype, device=device)
3037
out = torch.empty(1, dtype=dtype, device=device)
3039
torch.linalg.inv(a, out=out)
3041
self.assertEqual(len(w), 1)
3042
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3045
with warnings.catch_warnings(record=True) as w:
3046
a = torch.eye(2, dtype=dtype, device=device)
3047
out = torch.empty(3, 3, dtype=dtype, device=device)
3048
out = out.mT.clone(memory_format=torch.contiguous_format)
3050
self.assertTrue(out.mT.is_contiguous())
3052
torch.linalg.inv(a, out=out)
3054
self.assertEqual(len(w), 1)
3055
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3057
def solve_test_helper(self, A_dims, b_dims, device, dtype):
    make_fullrank = make_fullrank_matrices_with_distinct_singular_values
    make_A = partial(make_fullrank, device=device, dtype=dtype)

    b = torch.randn(*b_dims, dtype=dtype, device=device)
    A = make_A(*A_dims)
    return b, A
3067
@dtypes(*floating_and_complex_types())
3068
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3})
3069
def test_solve(self, device, dtype):
3070
def run_test(n, batch, rhs):
3071
A_dims = (*batch, n, n)
3072
b_dims = (*batch, n, *rhs)
3073
b, A = self.solve_test_helper(A_dims, b_dims, device, dtype)
3076
x = torch.linalg.solve(A, b)
3078
Ax = np.matmul(A.cpu(), x.unsqueeze(-1).cpu())
3081
Ax = np.matmul(A.cpu(), x.cpu())
3082
self.assertEqual(b.expand_as(Ax), Ax)
3085
expected = np.linalg.solve(A.cpu().numpy(), b.expand_as(x).cpu().numpy())
3086
self.assertEqual(x, expected)
3088
batches = [(), (0, ), (3, ), (2, 3)]
3090
nrhs = [(), (1, ), (5, )]
3091
for n, batch, rhs in itertools.product(ns, batches, nrhs):
3092
run_test(n, batch, rhs)
3094
@skipCUDAIfNoMagmaAndNoCusolver
3096
@dtypes(*floating_and_complex_types())
3097
def test_solve_batched_broadcasting(self, device, dtype):
3098
from numpy.linalg import solve
3100
def run_test(A_dims, B_dims):
3101
A_matrix_size = A_dims[-1]
3102
A_batch_dims = A_dims[:-2]
3103
B, A = self.solve_test_helper(A_batch_dims + (A_matrix_size, A_matrix_size), B_dims, device, dtype)
3104
actual = torch.linalg.solve(A, B)
3105
expected = solve(A.cpu().numpy(), B.cpu().numpy())
3106
self.assertEqual(actual, expected)
3109
run_test((5, 5), (2, 0, 5, 3))
3110
run_test((2, 0, 5, 5), (5, 3))
3111
run_test((2, 1, 3, 4, 4), (4, 6))
3112
run_test((4, 4), (2, 1, 3, 4, 2))
3113
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))
3117
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3118
@precisionOverride({torch.float: 1e-4, torch.cfloat: 1e-4})
3119
def test_tensorsolve(self, device, dtype):
3120
def run_test(a_shape, dims):
3121
a = torch.randn(a_shape, dtype=dtype, device=device)
3122
b = torch.randn(a_shape[:2], dtype=dtype, device=device)
3123
result = torch.linalg.tensorsolve(a, b, dims=dims)
3124
expected = np.linalg.tensorsolve(a.cpu().numpy(), b.cpu().numpy(), axes=dims)
3125
self.assertEqual(result, expected)
3128
out = torch.empty_like(result)
3129
ans = torch.linalg.tensorsolve(a, b, dims=dims, out=out)
3130
self.assertEqual(ans, out)
3131
self.assertEqual(ans, result)
3133
a_shapes = [(2, 3, 6), (3, 4, 4, 3)]
3134
dims = [None, (0, 2)]
3135
for a_shape, d in itertools.product(a_shapes, dims):
3136
run_test(a_shape, d)
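# A minimal sketch (not part of the test suite) of the reshaping view of
# tensorsolve used above: flatten the leading axes of `a` into a square
# matrix, solve an ordinary linear system, and reshape the solution back.
# The helper name `_tensorsolve_via_solve_sketch` is hypothetical and only
# covers the default dims=None case.
def _tensorsolve_via_solve_sketch(a, b):
    n = b.numel()
    x = torch.linalg.solve(a.reshape(n, -1), b.reshape(n))
    return x.reshape(a.shape[b.ndim:])

# Usage sketch:
#   a = torch.randn(2, 3, 6, dtype=torch.float64)
#   b = torch.randn(2, 3, dtype=torch.float64)
#   torch.testing.assert_close(_tensorsolve_via_solve_sketch(a, b),
#                              torch.linalg.tensorsolve(a, b))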
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3141
def test_tensorsolve_empty(self, device, dtype):
3143
a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3144
b = torch.empty(a.shape[:2], dtype=dtype, device=device)
3145
x = torch.linalg.tensorsolve(a, b)
3146
self.assertEqual(torch.tensordot(a, x, dims=len(x.shape)), b)
3150
@dtypes(torch.float32)
3151
def test_tensorsolve_errors_and_warnings(self, device, dtype):
3153
a = torch.eye(2 * 3 * 4, dtype=dtype, device=device).reshape((2 * 3, 4, 2, 3, 4))
3154
b = torch.randn(8, 4, dtype=dtype, device=device)
3155
self.assertTrue(np.prod(a.shape[2:]) != np.prod(b.shape))
3156
with self.assertRaisesRegex(RuntimeError, r'Expected self to satisfy the requirement'):
3157
torch.linalg.tensorsolve(a, b)
3160
out = torch.empty_like(a)
3161
b = torch.randn(6, 4, dtype=dtype, device=device)
3162
with warnings.catch_warnings(record=True) as w:
3164
torch.linalg.tensorsolve(a, b, out=out)
3166
self.assertEqual(len(w), 1)
3167
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3170
out = torch.empty_like(a).to(torch.int)
3171
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3172
torch.linalg.tensorsolve(a, b, out=out)
3175
if torch.cuda.is_available():
3176
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3177
out = torch.empty(0, dtype=dtype, device=wrong_device)
3178
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3179
torch.linalg.tensorsolve(a, b, out=out)
3183
@dtypes(*floating_and_complex_types())
3184
@precisionOverride({torch.float: 1e-3, torch.cfloat: 1e-3})
3185
def test_tensorinv(self, device, dtype):
3187
def run_test(a_shape, ind):
3188
a = torch.randn(a_shape, dtype=dtype, device=device)
3189
a_numpy = a.cpu().numpy()
3190
result = torch.linalg.tensorinv(a, ind=ind)
3191
expected = np.linalg.tensorinv(a_numpy, ind=ind)
3192
self.assertEqual(result, expected)
3195
out = torch.empty_like(result)
3196
ans = torch.linalg.tensorinv(a, ind=ind, out=out)
3197
self.assertEqual(ans, out)
3198
self.assertEqual(ans, result)
3201
run_test((12, 3, 4), ind=1)
3202
run_test((3, 8, 24), ind=2)
3203
run_test((18, 3, 3, 2), ind=1)
3204
run_test((1, 4, 2, 2), ind=2)
3205
run_test((2, 3, 5, 30), ind=3)
3206
run_test((24, 2, 2, 3, 2), ind=1)
3207
run_test((3, 4, 2, 3, 2), ind=2)
3208
run_test((1, 2, 3, 2, 3), ind=3)
3209
run_test((3, 2, 1, 2, 12), ind=4)
3214
@dtypes(*floating_and_complex_types())
3215
def test_tensorinv_empty(self, device, dtype):
3216
for ind in range(1, 4):
3218
a = torch.empty(0, 0, 1, 2, 3, 0, dtype=dtype, device=device)
3219
a_inv = torch.linalg.tensorinv(a, ind=ind)
3220
self.assertEqual(a_inv.shape, a.shape[ind:] + a.shape[:ind])
3225
@dtypes(*floating_and_complex_types())
3226
def test_tensorinv_errors_and_warnings(self, device, dtype):
3228
def check_shape(a_shape, ind):
3231
a = torch.randn(a_shape, dtype=dtype, device=device)
3232
with self.assertRaisesRegex(RuntimeError, "Expected self to satisfy the requirement"):
3233
torch.linalg.tensorinv(a, ind=ind)
3235
def check_ind(a_shape, ind):
3236
a = torch.randn(a_shape, dtype=dtype, device=device)
3237
with self.assertRaisesRegex(RuntimeError, "Expected a strictly positive integer"):
3238
torch.linalg.tensorinv(a, ind=ind)
3240
def check_out(a_shape, ind):
3242
a = torch.randn(a_shape, dtype=dtype, device=device)
3243
out = torch.empty_like(a)
3244
with warnings.catch_warnings(record=True) as w:
3246
torch.linalg.tensorinv(a, ind=ind, out=out)
3248
self.assertEqual(len(w), 1)
3249
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
3252
out = torch.empty(0, dtype=torch.int, device=device)
3253
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
3254
torch.linalg.tensorinv(a, ind=ind, out=out)
3257
if torch.cuda.is_available():
3258
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
3259
out = torch.empty(0, dtype=dtype, device=wrong_device)
3260
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
3261
torch.linalg.tensorinv(a, ind=ind, out=out)
3264
check_shape((2, 3, 4), ind=1)
3265
check_shape((1, 2, 3, 4), ind=3)
3268
check_ind((12, 3, 4), ind=-1)
3269
check_ind((18, 3, 3, 2), ind=0)
3272
check_out((12, 3, 4), ind=1)
3273
check_out((3, 8, 24), ind=2)
3277
@dtypes(*floating_and_complex_types())
3278
def test_tensorinv_singular_input(self, device, dtype):
3280
def check_singular_input(a_shape, ind):
3281
prod_ind_end = np.prod(a_shape[ind:])
3282
a = torch.eye(prod_ind_end, dtype=dtype, device=device)
3284
a = a.reshape(a_shape)
3285
with self.assertRaisesRegex(torch.linalg.LinAlgError, "The diagonal element"):
3286
torch.linalg.tensorinv(a, ind=ind)
3289
check_singular_input((12, 3, 4), ind=1)
3290
check_singular_input((3, 6, 18), ind=2)
3292
def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
3295
res = torch_fn(x, y)
3296
if x.dtype == torch.bfloat16:
3297
ref = torch.from_numpy(np.array(np_fn(x.cpu().float().numpy(), y.cpu().float().numpy())))
3299
ref = torch.from_numpy(np.array(np_fn(x.cpu().numpy(), y.cpu().numpy())))
3300
if res.dtype == torch.bfloat16:
3301
self.assertEqual(res.cpu(), ref.bfloat16())
3303
self.assertEqual(res.cpu(), ref)
3306
out = torch.empty_like(res)
3307
torch_fn(x, y, out=out)
3308
self.assertEqual(out, res)
3311
x = torch.tensor([], dtype=dtype, device=device)
3312
y = torch.tensor([], dtype=dtype, device=device)
3316
x = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3317
y = 0.1 * torch.randn(5000, dtype=dtype, device=device)
3321
y = 0.1 * torch.randn(1, dtype=dtype, device=device).expand(5000)
3325
check(x[::2], y[::2])
3327
@dtypes(torch.float, torch.cfloat, torch.bfloat16, torch.float16)
3328
@dtypesIfCUDA(torch.float, torch.cfloat)
3329
@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5, torch.bfloat16: 1e-0})
3330
def test_dot_vs_numpy(self, device, dtype):
3331
self._test_dot_vdot_vs_numpy(device, dtype, torch.dot, np.dot)
3333
@dtypes(torch.float, torch.cfloat)
3334
@precisionOverride({torch.cfloat: 1e-4, torch.float32: 5e-5})
3335
def test_vdot_vs_numpy(self, device, dtype):
3336
self._test_dot_vdot_vs_numpy(device, dtype, torch.vdot, np.vdot)
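# A minimal sketch (not part of the test suite) of the difference between the
# two reductions compared against NumPy above: dot sums x_i * y_i, while vdot
# conjugates its first argument, sum(conj(x_i) * y_i). The helper name
# `_vdot_sketch` is hypothetical.
def _vdot_sketch(x, y):
    return (x.conj() * y).sum()

# Usage sketch:
#   x = torch.randn(5, dtype=torch.cfloat)
#   y = torch.randn(5, dtype=torch.cfloat)
#   torch.testing.assert_close(_vdot_sketch(x, y), torch.vdot(x, y))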
def _test_dot_vdot_invalid_args(self, device, torch_fn, complex_dtypes=False):
3339
def check(x, y, regex):
3340
with self.assertRaisesRegex(RuntimeError, regex):
3344
x = torch.randn(1, dtype=torch.cfloat, device=device)
3345
y = torch.randn(3, dtype=torch.cdouble, device=device)
3347
x = torch.randn(1, dtype=torch.float, device=device)
3348
y = torch.randn(3, dtype=torch.double, device=device)
3350
check(x, y, 'dot : expected both vectors to have same dtype')
3351
check(x.reshape(1, 1), y, '1D tensors expected')
3352
check(x.expand(9), y.to(x.dtype), 'inconsistent tensor size')
3354
if self.device_type != 'cpu':
3355
x_cpu = x.expand(3).cpu()
3356
check(x_cpu, y.to(x.dtype), 'Expected all tensors to be on the same device')
3358
@onlyNativeDeviceTypes
3359
def test_vdot_invalid_args(self, device):
3360
self._test_dot_vdot_invalid_args(device, torch.vdot)
3361
self._test_dot_vdot_invalid_args(device, torch.vdot, complex_dtypes=True)
3363
@onlyNativeDeviceTypes
3364
def test_dot_invalid_args(self, device):
3365
self._test_dot_vdot_invalid_args(device, torch.dot)
3366
self._test_dot_vdot_invalid_args(device, torch.dot, complex_dtypes=True)
3370
@dtypes(*floating_and_complex_types())
3371
def test_matrix_rank(self, device, dtype):
3372
matrix_rank = torch.linalg.matrix_rank
3374
def run_test(shape0, shape1, batch):
3375
a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
3376
rank_a = matrix_rank(a)
3378
self.assertEqual(rank_a, matrix_rank(a.mH))
3379
aaH = torch.matmul(a, a.mH)
3380
rank_aaH = matrix_rank(aaH)
3381
rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
3382
self.assertEqual(rank_aaH, rank_aaH_hermitian)
3383
aHa = torch.matmul(a.mH, a)
3384
self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
3387
self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
3388
self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
3390
self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
3391
self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
3394
if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
3395
self.assertEqual(rank_aaH_hermitian,
3396
np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
3397
self.assertEqual(matrix_rank(aaH, 0.01, True),
3398
np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
3401
out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
3402
ans = matrix_rank(a, out=out)
3403
self.assertEqual(ans, out)
3404
self.assertEqual(ans, rank_a)
3407
batches = ((), (0, ), (4, ), (3, 5, ))
3408
for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3409
run_test(shape0, shape1, batch)
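# A minimal sketch (not part of the test suite) of the definition the rank
# checks above rely on: the matrix rank is the number of singular values
# above a tolerance. The helper name `_matrix_rank_sketch` and its default
# tolerance are hypothetical simplifications of the atol/rtol handling.
def _matrix_rank_sketch(A, tol=None):
    S = torch.linalg.svdvals(A)
    if tol is None:
        tol = S.amax(dim=-1, keepdim=True) * max(A.shape[-2:]) * torch.finfo(S.dtype).eps
    return (S > tol).sum(dim=-1)

# Usage sketch:
#   A = torch.randn(4, 6, dtype=torch.float64)
#   assert _matrix_rank_sketch(A) == torch.linalg.matrix_rank(A)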
@dtypes(*floating_and_complex_types())
3414
def test_matrix_rank_atol(self, device, dtype):
3416
def run_test_atol(shape0, shape1, batch):
3417
a = make_tensor((*batch, shape0, shape1), dtype=dtype, device=device)
3420
tolerances = [float(torch.rand(1)), ]
3422
for tol_type in all_types():
3423
tolerances.append(make_tensor(a.shape[:-2], dtype=tol_type, device=device, low=0))
3426
tolerances.append(make_tensor(a.shape[-3], dtype=torch.float32, device=device, low=0))
3427
for tol in tolerances:
3428
actual = torch.linalg.matrix_rank(a, atol=tol)
3429
actual_tol = torch.linalg.matrix_rank(a, tol=tol)
3430
self.assertEqual(actual, actual_tol)
3431
numpy_tol = tol if isinstance(tol, float) else tol.cpu().numpy()
3432
expected = np.linalg.matrix_rank(a.cpu().numpy(), tol=numpy_tol)
3433
self.assertEqual(actual, expected)
3436
batches = ((), (0, ), (4, ), (3, 5, ))
3437
for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
3438
run_test_atol(shape0, shape1, batch)
3442
@dtypes(torch.float64)
3443
def test_matrix_rank_atol_rtol(self, device, dtype):
3444
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
3445
make_arg = partial(make_fullrank, device=device, dtype=dtype)
3453
for tol_value in [0.81, torch.tensor(0.81, device=device)]:
3455
result = torch.linalg.matrix_rank(a, rtol=tol_value)
3456
self.assertEqual(result, 2)
3459
result = torch.linalg.matrix_rank(a, atol=tol_value)
3460
self.assertEqual(result, 7)
3463
result = torch.linalg.matrix_rank(a, atol=tol_value, rtol=tol_value)
3464
self.assertEqual(result, 2)
@skipCUDAVersionIn([(11, 6), (11, 7)])
@dtypes(*floating_and_complex_types())
def test_matrix_rank_empty(self, device, dtype):
matrix_rank = torch.linalg.matrix_rank
def run_test(shape0, shape1, batch):
a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
rank_a = matrix_rank(a)
expected = torch.zeros(batch, dtype=torch.int64, device=device)
self.assertEqual(rank_a, matrix_rank(a.mH))
aaH = torch.matmul(a, a.mH)
rank_aaH = matrix_rank(aaH)
rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
self.assertEqual(rank_aaH, rank_aaH_hermitian)
aHa = torch.matmul(a.mH, a)
self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
self.assertEqual(rank_a, expected)
self.assertEqual(matrix_rank(a, 0.01), expected)
self.assertEqual(rank_aaH, expected)
self.assertEqual(matrix_rank(aaH, 0.01), expected)
self.assertEqual(rank_aaH_hermitian, expected)
self.assertEqual(matrix_rank(aaH, 0.01, True), expected)
batches = ((), (4, ), (3, 5, ))
for batch in batches:
run_test(0, 0, batch)
run_test(0, 3, batch)
run_test(3, 0, batch)
@dtypes(*floating_and_complex_types())
def test_matrix_rank_out_errors_and_warnings(self, device, dtype):
a = torch.eye(2, dtype=dtype, device=device)
out = torch.empty(0, dtype=torch.bool, device=device)
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Bool"):
torch.linalg.matrix_rank(a, out=out)
if torch.cuda.is_available():
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
out = torch.empty(0, dtype=dtype, device=wrong_device)
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
torch.linalg.matrix_rank(a, out=out)
with warnings.catch_warnings(record=True) as w:
out = torch.empty(3, dtype=dtype, device=device)
torch.linalg.matrix_rank(a, out=out)
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
@dtypes(*floating_and_complex_types())
def test_matrix_rank_basic(self, device, dtype):
matrix_rank = torch.linalg.matrix_rank
a = torch.eye(10, dtype=dtype, device=device)
self.assertEqual(matrix_rank(a).item(), 10)
self.assertEqual(matrix_rank(a, hermitian=True).item(), 10)
self.assertEqual(matrix_rank(a).item(), 9)
self.assertEqual(matrix_rank(a, hermitian=True).item(), 9)
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_chain_matmul(self, device, dtype):
t = make_tensor((2, 2), dtype=dtype, device=device)
self.assertEqual(t, torch.chain_matmul(t))
with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()
with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
torch.chain_matmul(make_tensor(1, dtype=dtype, device=device), make_tensor(1, dtype=dtype, device=device))
@onlyNativeDeviceTypes
@dtypes(torch.double, torch.cdouble)
def test_multi_dot(self, device, dtype):
tensors = [make_tensor(shape, dtype=dtype, device=device) for shape in shapes]
np_arrays = [tensor.cpu().numpy() for tensor in tensors]
res = torch.linalg.multi_dot(tensors).cpu()
ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
self.assertEqual(res, ref)
check([0, 2], [2, 1])
check([2, 2], [2, 0])
check([2, 0], [0, 3])
check([0, 0], [0, 1])
check([4, 2], [2, 0], [0, 3], [3, 2])
check([1, 2], [2, 1])
check([3, 2], [2, 4])
check([3], [3, 4], [4, 2], [2, 5], [5])
check([1, 2], [2, 2], [2, 3], [3, 1])
check([10, 100], [100, 5], [5, 50])
check([10, 20], [20, 30], [30, 5])
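# Note: torch.linalg.multi_dot chooses the association order that minimizes the total
# number of scalar multiplications (matrix-chain ordering), so the chains above exercise
# different bracketing decisions while producing the same result.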
@onlyNativeDeviceTypes
@dtypes(torch.float)
def test_multi_dot_errors(self, device, dtype):
def check(tensors, out, msg):
with self.assertRaisesRegex(RuntimeError, msg):
torch.linalg.multi_dot(tensors, out=out)
a = make_tensor(2, dtype=dtype, device=device)
check([], None, "expected at least 2 tensors")
check([a], None, "expected at least 2 tensors")
check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
check([a, a, a], None, "tensor 1 must be 2D")
check([a, make_tensor((2, 2, 2), dtype=dtype, device=device), a], None, "tensor 1 must be 2D")
check([a, make_tensor(2, dtype=torch.double, device=device)], None, "all tensors must have be the same dtype")
check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
if self.device_type == 'cuda':
check([a, make_tensor(2, dtype=dtype, device="cpu")], None, "all tensors must be on the same device")
check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
check([a, make_tensor(3, dtype=dtype, device=device)], None, "cannot be multiplied")
check([a, make_tensor((3, 2), dtype=dtype, device=device), a], None, "cannot be multiplied")
@precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
@skipCUDAIfNoCusolver
@dtypes(*floating_and_complex_types())
def test_qr(self, device, dtype):
def run_test(tensor_dims, some):
A = torch.randn(*tensor_dims, dtype=dtype, device=device)
Q, R = torch.qr(A, some=some)
m, n = tensor_dims[-2:]
n_columns = m if (not some) and m > n else min(m, n)
self.assertEqual(Q.size(-2), m)
self.assertEqual(R.size(-1), n)
self.assertEqual(Q.size(-1), n_columns)
A_ = A.cpu().numpy()
Q_ = Q.cpu().numpy()
R_ = R.cpu().numpy()
self.assertEqual(A_, np.matmul(Q_, R_))
Q_out, R_out = torch.full_like(Q, math.nan), torch.full_like(R, math.nan)
torch.qr(A, some=some, out=(Q_out, R_out))
Q_out_ = Q_out.cpu().numpy()
R_out_ = R_out.cpu().numpy()
self.assertEqual(A_, np.matmul(Q_out_, R_out_))
self.assertEqual(Q_, Q_out_)
self.assertEqual(R_, R_out_)
eye = torch.eye(n_columns, device=device, dtype=dtype).expand(Q.shape[:-2] + (n_columns, n_columns)).cpu().numpy()
self.assertEqual(np.matmul(Q_.swapaxes(-1, -2).conj(), Q_), eye)
self.assertEqual(R.triu(), R)
tensor_dims_list = [(0, 5), (0, 0), (5, 0),
(2, 1, 0, 5), (2, 1, 0, 0), (2, 1, 5, 0), (2, 0, 5, 5),
(3, 5), (5, 5), (5, 3),
(7, 3, 5), (7, 5, 5), (7, 5, 3),
(7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)]
for tensor_dims, some in itertools.product(tensor_dims_list, [True, False]):
run_test(tensor_dims, some)
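# torch.qr is the legacy API: some=True corresponds to torch.linalg.qr(..., mode='reduced')
# and some=False to mode='complete'.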
@skipCUDAIfNoCusolver
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_qr_vs_numpy(self, device, dtype):
test torch.linalg.qr vs numpy.linalg.qr
for size in sizes_to_test:
t = torch.randn(size, device=device, dtype=dtype)
np_t = t.cpu().numpy()
for mode in ['reduced', 'complete']:
exp_q, exp_r = np.linalg.qr(np_t, mode=mode)
q, r = torch.linalg.qr(t, mode=mode)
self.assertEqual(q, exp_q)
self.assertEqual(r, exp_r)
exp_r = np.linalg.qr(np_t, mode='r')
q, r = torch.linalg.qr(t, mode='r')
self.assertEqual(q.shape, (0,))
self.assertEqual(q.dtype, t.dtype)
self.assertEqual(q.device, t.device)
self.assertEqual(r, exp_r)
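# With mode='r' only the R factor is computed; torch.linalg.qr still returns a (Q, R)
# named tuple, but Q comes back as an empty tensor, which is what the shape/dtype/device
# checks above verify.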
3701
@skipCUDAIfNoCusolver
3703
@dtypes(torch.float)
3704
def test_linalg_qr_autograd_errors(self, device, dtype):
3708
inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
3709
q, r = torch.linalg.qr(inp, mode='r')
3710
self.assertEqual(q.shape, (0,))
3712
with self.assertRaisesRegex(RuntimeError,
3713
"The derivative of linalg.qr depends on Q"):
3715
inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
3716
q, r = torch.linalg.qr(inp, mode='complete')
3718
with self.assertRaisesRegex(RuntimeError,
3719
"The QR decomposition is not differentiable when mode='complete' and nrows > ncols"):
3722
@skipCUDAIfNoCusolver
3724
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
3725
def test_qr_batched(self, device, dtype):
test torch.linalg.qr vs numpy.linalg.qr. We need some special logic
because numpy does not support batched qr
def np_qr_batched(a, mode):
"""poor man's batched version of np.linalg.qr"""
result = np.linalg.qr(matrix, mode=mode)
3737
all_r.append(result)
3743
return np.array(all_r)
3745
return np.array(all_q), np.array(all_r)
3747
t = torch.randn((3, 7, 5), device=device, dtype=dtype)
3748
np_t = t.cpu().numpy()
3749
for mode in ['reduced', 'complete']:
3750
exp_q, exp_r = np_qr_batched(np_t, mode=mode)
3751
q, r = torch.linalg.qr(t, mode=mode)
3752
self.assertEqual(q, exp_q)
3753
self.assertEqual(r, exp_r)
3755
exp_r = np_qr_batched(np_t, mode='r')
3756
q, r = torch.linalg.qr(t, mode='r')
3758
self.assertEqual(q.shape, (0,))
3759
self.assertEqual(q.dtype, t.dtype)
3760
self.assertEqual(q.device, t.device)
3762
self.assertEqual(r, exp_r)
3764
@skipCUDAIfNoCusolver
3766
@dtypes(torch.float)
3767
def test_qr_error_cases(self, device, dtype):
3768
t1 = torch.randn(5, device=device, dtype=dtype)
3769
with self.assertRaisesRegex(RuntimeError, 'linalg.qr: The input tensor A must have at least 2 dimensions.'):
3771
t2 = torch.randn((5, 7), device=device, dtype=dtype)
3772
with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
3773
torch.linalg.qr(t2, mode='hello')
def _check_einsum(self, *args, np_args=None):
np_args = [arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg for arg in args]
ref = np.einsum(*np_args)
res = torch.einsum(*args)
self.assertEqual(ref, res)
with opt_einsum.flags(enabled=False):
res = torch.einsum(*args)
self.assertEqual(ref, res)
with opt_einsum.flags(enabled=True, strategy='greedy'):
res = torch.einsum(*args)
self.assertEqual(ref, res)
with opt_einsum.flags(enabled=True, strategy='optimal'):
res = torch.einsum(*args)
self.assertEqual(ref, res)
@dtypes(torch.double, torch.cdouble)
3797
def test_einsum(self, device, dtype):
3799
x = make_tensor((5,), dtype=dtype, device=device)
3800
y = make_tensor((7,), dtype=dtype, device=device)
3801
A = make_tensor((3, 5), dtype=dtype, device=device)
3802
B = make_tensor((2, 5), dtype=dtype, device=device)
3803
C = make_tensor((2, 3, 5), dtype=dtype, device=device)
3804
D = make_tensor((2, 5, 7), dtype=dtype, device=device)
3805
E = make_tensor((7, 9), dtype=dtype, device=device)
3806
F = make_tensor((2, 3, 3, 5), dtype=dtype, device=device)
3807
G = make_tensor((5, 4, 6), dtype=dtype, device=device)
3808
H = make_tensor((4, 4), dtype=dtype, device=device)
3809
I = make_tensor((2, 3, 2), dtype=dtype, device=device)
3812
self._check_einsum('i->', x)
3813
self._check_einsum('i,i->', x, x)
3814
self._check_einsum('i,i->i', x, x)
3815
self._check_einsum('i,j->ij', x, y)
3818
self._check_einsum("ij->ji", A)
3819
self._check_einsum("ij->j", A)
3820
self._check_einsum("ij->i", A)
3821
self._check_einsum("ij,ij->ij", A, A)
3822
self._check_einsum("ij,j->i", A, x)
3823
self._check_einsum("ij,kj->ik", A, B)
3824
self._check_einsum("ij,ab->ijab", A, E)
3827
self._check_einsum("Aij,Ajk->Aik", C, D)
3828
self._check_einsum("ijk,jk->i", C, A)
3829
self._check_einsum("aij,jk->aik", D, E)
3830
self._check_einsum("abCd,dFg->abCFg", F, G)
3831
self._check_einsum("ijk,jk->ik", C, A)
3832
self._check_einsum("ijk,jk->ij", C, A)
3833
self._check_einsum("ijk,ik->j", C, B)
3834
self._check_einsum("ijk,ik->jk", C, B)
3837
self._check_einsum("ii", H)
3838
self._check_einsum("ii->i", H)
3839
self._check_einsum('iji->j', I)
3840
self._check_einsum('ngrg...->nrg...', make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device))
3843
self._check_einsum("i...->...", H)
3844
self._check_einsum("ki,...k->i...", A.t(), B)
3845
self._check_einsum("k...,jk->...", A.t(), B)
3846
self._check_einsum('...ik, ...j -> ...ij', C, x)
3847
self._check_einsum('Bik,k...j->i...j', C, make_tensor((5, 3), dtype=dtype, device=device))
3848
self._check_einsum('i...j, ij... -> ...ij', C, make_tensor((2, 5, 2, 3), dtype=dtype, device=device))
3851
l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3852
r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3853
w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3854
self._check_einsum("bn,anm,bm->ba", l, w, r)
3857
self._check_einsum("bn,Anm,bm->bA", l[:, ::2], w[:, ::2, ::2], r[:, ::2])
3860
self._check_einsum("...,be,b...,beg,gi,bc...->bi...", A, B, C, D, E, F)
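# The sublist format exercised below expresses the same contractions with integer labels
# and Ellipsis instead of subscript letters; e.g. torch.einsum(A, [0, 1], x, [1], [0])
# is equivalent to torch.einsum('ij,j->i', A, x).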
3862
@dtypes(torch.double, torch.cdouble)
3863
def test_einsum_sublist_format(self, device, dtype):
3864
x = make_tensor((5,), dtype=dtype, device=device)
3865
y = make_tensor((7,), dtype=dtype, device=device)
3866
A = make_tensor((3, 5), dtype=dtype, device=device)
3867
B = make_tensor((2, 5), dtype=dtype, device=device)
3868
C = make_tensor((2, 1, 3, 1, 4), dtype=dtype, device=device)
3870
self._check_einsum(x, [0])
3871
self._check_einsum(x, [0], [])
3872
self._check_einsum(x, [0], y, [1], [0, 1])
3873
self._check_einsum(A, [0, 1], [1, 0])
3874
self._check_einsum(A, [0, 1], x, [1], [0])
3875
self._check_einsum(A, [0, 1], B, [2, 1])
3876
self._check_einsum(A, [0, 1], B, [2, 1], [0, 2])
3877
self._check_einsum(C, [0, 1, 2, 1, Ellipsis], [0, 2, 1, Ellipsis])
3878
self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0])
3879
self._check_einsum(A.t(), [0, 1], B, [Ellipsis, 0], [1, Ellipsis])
3880
self._check_einsum(A.t(), [0, Ellipsis], B, [1, 0], [Ellipsis])
3883
l = make_tensor((5, 10), dtype=dtype, device=device, noncontiguous=True)
3884
r = make_tensor((5, 20), dtype=dtype, device=device, noncontiguous=True)
3885
w = make_tensor((15, 10, 20), dtype=dtype, device=device)
3886
self._check_einsum(l, [40, 41], w, [2, 41, 50], r, [40, 50], [40, 2])
3888
@dtypes(torch.double, torch.cdouble)
3889
def test_einsum_random(self, device, dtype):
3890
def convert_label(label):
3894
return chr(ord('A') + label)
3896
return chr(ord('a') + label - 26)
3898
def convert_sublist(sublist):
3899
return ''.join(convert_label(label) for label in sublist)
3903
min_ops=1, max_ops=4,
3904
min_dims=1, max_dims=3,
3905
min_size=1, max_size=8,
3907
enable_diagonals=True,
3909
broadcasting_prob=0.1):
3911
all_labels = torch.arange(52)
3914
assert 0 <= n_labels < len(all_labels)
3915
assert 0 < min_ops <= max_ops
3916
assert 0 <= min_dims <= max_dims
3917
assert 0 <= min_size <= max_size
3918
assert 0 <= max_out_dim
3919
assert enable_diagonals or max_dims <= n_labels
3924
possible_labels = all_labels[torch.randperm(len(all_labels))[:n_labels]]
3925
labels_size = torch.randint_like(all_labels, min_size, max_size + 1)
3926
ellipsis_shape = torch.randint(min_size, max_size + 1, (max_dims - min_dims,))
3932
valid_labels = set()
3935
for _ in range(random.randint(min_ops, max_ops)):
3936
n_dim = random.randint(min_dims, max_dims)
3937
labels_idx = torch.ones(len(possible_labels)).multinomial(n_dim, enable_diagonals)
3938
labels = possible_labels[labels_idx]
3939
valid_labels.update(labels.tolist())
3940
shape = labels_size[labels]
3943
mask = Binomial(probs=broadcasting_prob).sample((n_dim,))
3944
broadcast_labels = torch.unique(labels[mask == 1])
3945
shape[(labels[..., None] == broadcast_labels).any(-1)] = 1
3947
labels = labels.tolist()
3948
shape = shape.tolist()
3951
if n_dim < max_dims and torch.rand(1) < ellipsis_prob:
3952
ell_num_dim = random.randint(1, max_dims - n_dim)
3953
ell_size = max(ell_size, ell_num_dim)
3954
ell_shape = ellipsis_shape[-ell_num_dim:]
3956
mask = Binomial(probs=broadcasting_prob).sample((ell_num_dim,))
3957
ell_shape[mask == 1] = 1
3958
ell_index = random.randint(0, n_dim)
3959
shape[ell_index:ell_index] = ell_shape
3960
labels.insert(ell_index, ...)
3962
operands.append(make_tensor(shape, dtype=dtype, device=device))
3963
sublists.append(labels)
3968
np_operands = [op.cpu().numpy() for op in operands]
3971
equation = ','.join(convert_sublist(l) for l in sublists)
3972
self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3975
args = list(itertools.chain.from_iterable(zip(operands, sublists)))
3976
self._check_einsum(*args, np_args=(equation, *np_operands))
3980
num_out_labels = max(0, random.randint(0, min(max_out_dim, len(valid_labels))) - ell_size)
3981
if num_out_labels > 0:
3982
out_labels_idx = torch.ones(len(valid_labels)).multinomial(num_out_labels)
3983
out_sublist = torch.tensor(list(valid_labels))[out_labels_idx].tolist()
3984
out_sublist.insert(random.randint(0, num_out_labels), ...)
3987
equation += '->' + convert_sublist(out_sublist)
3988
self._check_einsum(equation, *operands, np_args=(equation, *np_operands))
3991
args.append(out_sublist)
3992
self._check_einsum(*args, np_args=(equation, *np_operands))
3996
def test_einsum_corner_cases(self, device):
3997
def check(equation, *operands, expected_output):
3998
tensors = [torch.tensor(operand, device=device, dtype=torch.float32) if not isinstance(operand, tuple)
3999
else make_tensor(operand, dtype=torch.float32, device=device) for operand in operands]
4000
output = torch.einsum(equation, tensors)
4001
self.assertEqual(output, torch.tensor(expected_output, dtype=torch.float32, device=device))
4004
check(' ', 1, expected_output=1)
4005
check(' -> ', 1, expected_output=1)
4006
check(' , ', 2, 2, expected_output=4)
4007
check(' , , ', 2, 2, 2, expected_output=8)
4008
check(' , -> ', 2, 2, expected_output=4)
4009
check(' i ', [1], expected_output=[1])
4010
check(' i -> ', [1], expected_output=1)
4011
check(' i -> i ', [1], expected_output=[1])
4012
check(' i , i ', [2], [2], expected_output=4)
4013
check(' i , i -> i ', [2], [2], expected_output=[4])
4016
check('i', [], expected_output=[])
4017
check(' i j -> j', [[], []], expected_output=[])
4018
check('ij->i', [[], []], expected_output=[0., 0.])
4019
check(' i j k , k -> i j ', (3, 0, 6), (6,), expected_output=[[], [], []])
4022
check('i,j', [2], [1, 2], expected_output=[[2, 4]])
4023
check('i,ij->ij', [1, 2], [[1, 2, 3], [2, 3, 4]], expected_output=[[1, 2, 3], [4, 6, 8]])
4026
check('...', 1, expected_output=1)
4027
check('...->', 1, expected_output=1)
4028
check('...->...', 1, expected_output=1)
4029
check('...', [1], expected_output=[1])
4030
check('...->', [1], expected_output=1)
4031
check('z...->z', [1], expected_output=[1])
4032
check('Z...->...Z', [1], expected_output=[1])
4033
check('...a->', [[2], [4]], expected_output=6)
4034
check('a...b->ab', [[[1], [2]], [[3], [4]]], expected_output=[[3], [7]])
4036
def test_einsum_error_cases(self, device):
4037
def check(*args, regex, exception=RuntimeError):
4038
with self.assertRaisesRegex(exception, r'einsum\(\):.*' + regex):
4041
x = make_tensor((2,), dtype=torch.float32, device=device)
4042
y = make_tensor((2, 3), dtype=torch.float32, device=device)
4044
check('', [], regex=r'at least one operand', exception=ValueError)
4045
check('. ..', [x], regex=r'found \'.\' for operand 0 that is not part of any ellipsis')
4046
check('... ...', [x], regex=r'found \'.\' for operand 0 for which an ellipsis was already found')
4047
check('1', [x], regex=r'invalid subscript given at index 0')
4048
check(',', [x], regex=r'fewer operands were provided than specified in the equation')
4049
check('', [x, x], regex=r'more operands were provided than specified in the equation')
4050
check('', [x], regex=r'the number of subscripts in the equation \(0\) does not match the number '
4051
r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4052
check('ai', [x], regex=r'the number of subscripts in the equation \(2\) does not match the number '
4053
r'of dimensions \(1\) for operand 0 and no ellipsis was given')
4054
check('ai...', [x], regex=r'the number of subscripts in the equation \(2\) is more than the number '
4055
r'of dimensions \(1\) for operand 0')
4056
check('a->... .', [x], regex=r'found \'.\' for output but an ellipsis \(...\) was already found')
4057
check('a->..', [x], regex=r'found \'.\' for output that is not part of any ellipsis \(...\)')
4058
check('a->1', [x], regex=r'invalid subscript given at index 3')
4059
check('a->aa', [x], regex=r'output subscript a appears more than once in the output')
4060
check('a->i', [x], regex=r'output subscript i does not appear in the equation for any input operand')
4061
check('aa', [y], regex=r'subscript a is repeated for operand 0 but the sizes don\'t match, 3 != 2')
4062
check('...,...', [x, y], regex=r'does not broadcast')
4063
check('a,a', [x, make_tensor((3,), dtype=torch.float32, device=device)], regex=r'does not broadcast')
4064
check('a, ba', [x, y], regex=r'subscript a has size 3 for operand 1 which does not broadcast with previously'
4067
check(x, [-1], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4068
check(x, [52], regex=r'not within the valid range \[0, 52\)', exception=ValueError)
4070
def _gen_shape_inputs_linalg_triangular_solve(self, shape, dtype, device, well_conditioned=False):
4071
make_arg = partial(make_tensor, dtype=dtype, device=device)
4072
make_fullrank = partial(make_fullrank_matrices_with_distinct_singular_values, dtype=dtype, device=device)
4074
for left, uni, expand_a, tr_a, conj_a, expand_b, tr_b, conj_b in product((True, False), repeat=8):
4076
if (conj_a or conj_b) and not dtype.is_complex:
4079
if (expand_a or expand_b) and b == 1:
4082
size_a = (b, n, n) if left else (b, k, k)
4083
size_b = (b, n, k) if not tr_b else (b, k, n)
4086
if b == 1 or expand_a:
4088
if b == 1 or expand_b:
4091
if well_conditioned:
4092
PLU = torch.linalg.lu(make_fullrank(*size_a))
4095
A = PLU[1].transpose(-2, -1).contiguous()
4098
A = PLU[2].contiguous()
4100
A = make_arg(size_a)
4103
diag = A.diagonal(0, -2, -1)
4107
diag[diag.abs() < 1e-6] = 1.
4109
B = make_arg(size_b)
4112
A.transpose_(-2, -1)
4114
B.transpose_(-2, -1)
4120
A = A.expand(b, *size_a)
4122
B = B.expand(b, n, k)
4123
yield A, B, left, not tr_a, uni
4125
def _test_linalg_solve_triangular(self, A, B, upper, left, uni):
4126
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4128
self.assertEqual(A @ X, B)
4130
self.assertEqual(X @ A, B)
4133
if not B.is_contiguous() and not B.transpose(-2, -1).is_contiguous():
4135
torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni, out=out)
4136
self.assertEqual(X, out)
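# Semantics exercised here: with left=True, solve_triangular returns X such that A @ X == B;
# with left=False it returns X such that X @ A == B. unitriangular=True makes the solver
# treat the diagonal of A as ones regardless of the values actually stored there.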
4139
@dtypes(*floating_and_complex_types())
4140
@precisionOverride({torch.float32: 1e-3 if TEST_WITH_ROCM else 1e-1,
4141
torch.float64: 1e-8,
4142
torch.complex64: 1e-1,
4143
torch.complex128: 1e-8})
4144
def test_linalg_solve_triangular(self, device, dtype):
4150
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4151
for b, n, k in product(bs, ns, ks):
4152
for A, B, left, upper, uni in gen_inputs((b, n, k), dtype, device, well_conditioned=True):
4153
self._test_linalg_solve_triangular(A, B, upper, left, uni)
4156
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Test fails for float64 on GPU (P100, V100) on Meta infra")
4159
@dtypes(*floating_and_complex_types())
4160
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4161
torch.float64: 1e-8, torch.complex128: 1e-8})
4162
def test_linalg_solve_triangular_large(self, device, dtype):
4165
iterative_cublas = (2, 64, 1)
4167
gen_inputs = self._gen_shape_inputs_linalg_triangular_solve
4168
for shape in (magma, iterative_cublas):
4169
for A, B, left, upper, uni in gen_inputs(shape, dtype, device, well_conditioned=True):
4170
self._test_linalg_solve_triangular(A, B, upper, left, uni)
4172
@dtypes(*floating_and_complex_types())
4173
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2,
4174
torch.float64: 1e-8, torch.complex128: 1e-8})
4175
def test_linalg_solve_triangular_broadcasting(self, device, dtype):
4176
make_arg = partial(make_tensor, dtype=dtype, device=device)
4178
sizes = (((2, 1, 3, 4, 4), (2, 1, 3, 4, 6)),
4179
((2, 1, 3, 4, 4), (4, 6)),
4180
((4, 4), (2, 1, 3, 4, 2)),
4181
((1, 3, 1, 4, 4), (2, 1, 3, 4, 5)))
4182
for size_A, size_B in sizes:
4183
for left, upper, uni in itertools.product([True, False], repeat=3):
4184
A = make_arg(size_A)
4189
diag = A.diagonal(0, -2, -1)
4193
diag[diag.abs() < 1e-6] = 1.
4194
B = make_arg(size_B)
4196
B.transpose_(-2, -1)
4198
X = torch.linalg.solve_triangular(A, B, upper=upper, left=left, unitriangular=uni)
4204
self.assertEqual(*torch.broadcast_tensors(B, B_other))
4206
def triangular_solve_test_helper(self, A_dims, b_dims, upper, unitriangular,
4208
triangle_function = torch.triu if upper else torch.tril
4209
b = torch.randn(*b_dims, dtype=dtype, device=device)
4210
A = torch.randn(*A_dims, dtype=dtype, device=device)
4212
A = torch.matmul(A, A.mT)
4213
A_triangular = triangle_function(A)
4215
A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.)
4216
return b, A_triangular
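# Forming A @ A.mT before extracting the triangle keeps the triangular factor reasonably
# well conditioned (its diagonal is positive), which the solves below rely on.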
4220
@skipIfTorchDynamo("flaky, needs investigation")
4221
@dtypes(*floating_and_complex_types())
4222
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4223
torch.float64: 1e-8, torch.complex128: 1e-8})
4224
def test_triangular_solve(self, device, dtype):
4227
for k, n, (upper, unitriangular, transpose) in itertools.product(ks, ns,
4228
itertools.product([True, False], repeat=3)):
4229
b, A = self.triangular_solve_test_helper((n, n), (n, k), upper,
4230
unitriangular, device, dtype)
4231
x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
4233
self.assertEqual(b, np.matmul(A.t().cpu(), x.cpu()))
4235
self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
4239
@dtypes(*floating_and_complex_types())
4240
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4241
torch.float64: 1e-8, torch.complex128: 1e-8})
4242
def test_triangular_solve_batched(self, device, dtype):
4243
def triangular_solve_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4244
b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4245
unitriangular, device, dtype)
4247
for i in range(b_dims[0]):
4248
x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper,
4249
unitriangular=unitriangular,
4250
transpose=transpose)[0])
4251
x_exp = torch.stack(x_exp_list)
4252
x_act = torch.triangular_solve(b, A, upper=upper,
4253
unitriangular=unitriangular,
4254
transpose=transpose)[0]
4255
self.assertEqual(x_act, x_exp)
4259
Ax = np.matmul(A.cpu(), x_act.cpu())
4260
self.assertEqual(b, Ax)
4262
def triangular_solve_zero_batch_helper(A_dims, b_dims, upper, unitriangular, transpose):
4263
b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4264
unitriangular, device, dtype)
4265
x = torch.triangular_solve(b, A, upper=upper,
4266
unitriangular=unitriangular,
4267
transpose=transpose)[0]
4268
self.assertTrue(x.shape == b.shape)
4270
for upper, unitriangular, transpose in itertools.product([True, False], repeat=3):
4272
triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4273
upper, unitriangular, transpose)
4276
triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 10),
4277
upper, unitriangular, transpose)
4278
triangular_solve_batch_helper((batchsize, 0, 0), (batchsize, 0, 0),
4279
upper, unitriangular, transpose)
4283
triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
4284
upper, unitriangular, transpose)
4290
@dtypes(*floating_and_complex_types())
4291
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
4292
torch.float64: 1e-8, torch.complex128: 1e-8})
4293
def test_triangular_solve_batched_many_batches(self, device, dtype):
4294
for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4296
b, A = self.triangular_solve_test_helper((256, 256, 5, 5), (5, 1),
4297
upper, unitriangular, device, dtype)
4298
x, _ = torch.triangular_solve(b, A,
4299
upper=upper, transpose=transpose, unitriangular=unitriangular)
4303
Ax = torch.matmul(A, x)
4305
rtol = 1e-2 if dtype in [torch.float32, torch.complex64] else self.precision
4306
self.assertEqual(Ax, b.expand_as(Ax), atol=self.precision, rtol=rtol)
4309
b, A = self.triangular_solve_test_helper((3, 3), (512, 512, 3, 1),
4310
upper, unitriangular, device, dtype)
4311
x, _ = torch.triangular_solve(b, A, upper=upper, transpose=transpose,
4312
unitriangular=unitriangular)
4316
self.assertEqual(torch.matmul(A, x), b)
4320
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
4321
@skipIfTorchDynamo("flaky, needs investigation")
4322
@dtypes(*floating_and_complex_types())
4323
def test_triangular_solve_batched_broadcasting(self, device, dtype):
4324
from scipy.linalg import solve_triangular as tri_solve
4326
def scipy_tri_solve_batched(A, B, upper, trans, diag):
4327
batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
4328
single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
4329
expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
4330
torch.Size(batch_dims_B)))
4331
expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
4332
expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
4333
flat_A = expand_A.reshape((-1,) + single_dim_A)
4334
flat_B = expand_B.reshape((-1,) + single_dim_B)
4335
flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
4336
for a, b in zip(flat_A, flat_B)])
4337
return flat_X.reshape(expand_B.shape)
4339
def run_test(A_dims, b_dims, device, upper, transpose, unitriangular):
4340
b, A = self.triangular_solve_test_helper(A_dims, b_dims, upper,
4341
unitriangular, device, dtype)
4342
x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(),
4343
upper, transpose, unitriangular))
4344
x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
4346
self.assertEqual(x, x_exp.to(device))
4348
for upper, transpose, unitriangular in itertools.product([True, False], repeat=3):
4350
run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular)
4351
run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular)
4352
run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)
4353
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)
4356
@dtypes(torch.float)
4357
def test_triangular_solve_large(self, device, dtype):
4359
A = torch.randn(1, 2, 2, device=device, dtype=dtype).tril_()
4360
B = torch.randn(1, 2, 524281, device=device, dtype=dtype)
4361
X = torch.linalg.solve_triangular(A, B, upper=False)
4362
self.assertEqual(A @ X, B)
4366
@dtypes(*floating_and_complex_types())
4367
def test_triangular_solve_out_errors_and_warnings(self, device, dtype):
4369
a = torch.eye(2, dtype=dtype, device=device)
4370
b = torch.randn(2, 1, dtype=dtype, device=device)
4371
out = torch.empty_like(b).to(torch.int)
4372
clone_a = torch.empty_like(a)
4373
with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4374
torch.triangular_solve(b, a, out=(out, clone_a))
4376
out = torch.empty_like(b)
4377
clone_a = clone_a.to(torch.int)
4378
with self.assertRaisesRegex(RuntimeError, "Expected out tensor to have dtype"):
4379
torch.triangular_solve(b, a, out=(out, clone_a))
4382
if torch.cuda.is_available():
4383
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
4384
out = torch.empty(0, dtype=dtype, device=wrong_device)
4385
clone_a = torch.empty_like(a)
4386
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4387
torch.triangular_solve(b, a, out=(out, clone_a))
4388
out = torch.empty(0, dtype=dtype, device=device)
4389
clone_a = torch.empty_like(a).to(wrong_device)
4390
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
4391
torch.triangular_solve(b, a, out=(out, clone_a))
4394
torch.triangular_solve(b, a)
4397
with warnings.catch_warnings(record=True) as w:
4398
out = torch.empty(1, dtype=dtype, device=device)
4399
clone_a = torch.empty(1, dtype=dtype, device=device)
4401
torch.triangular_solve(b, a, out=(out, clone_a))
4403
self.assertEqual(len(w), 2)
4404
self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
4405
self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
4408
def check_single_matmul(self, x, y):
4410
def assertEqual(answer, expected):
4411
if x.dtype.is_floating_point or x.dtype.is_complex:
4412
k = max(x.shape[-1], 1)
4413
self.assertEqual(answer, expected,
4414
msg=f"{x.shape} x {y.shape} = {answer.shape}",
4418
self.assertEqual(answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}")
4421
expected = np.matmul(x.cpu(), y.cpu())
4422
ans = torch.matmul(x, y)
4423
self.assertTrue(ans.is_contiguous())
4424
assertEqual(ans, expected)
4427
out = torch.empty_like(ans)
4428
ans = torch.matmul(x, y, out=out)
4429
self.assertIs(ans, out)
4430
self.assertTrue(ans.is_contiguous())
4431
assertEqual(ans, expected)
4433
def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
Generates sequences of tuples (x, y) with size(x) = x_dim and
size(y) <= y_dim that are compatible w.r.t. matmul
for y in range(1, y_dim + 1):
4442
for batch, mn in product(product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
4443
product(range(matrix_size), repeat=min(y, 2))):
4447
yield size_x, size_y
4449
for k in range(matrix_size):
4450
size_x = (k,) + mn[:1]
4452
size_x = batch[-(x - 2):] + size_x
4455
size_y = batch[-(y - 2):] + size_y
4456
yield size_x, size_y
4458
@dtypesIfCUDA(torch.float, torch.complex64)
4459
@dtypes(torch.int64, torch.float, torch.complex64)
4460
@setBlasBackendsToDefaultFinally
4461
def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
4462
for backend in ["cublas", "cublaslt"]:
4463
if torch.device(device).type == 'cuda':
4464
torch.backends.cuda.preferred_blas_library(backend)
4466
make_arg = partial(make_tensor, device=device, dtype=dtype)
4468
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4469
x = make_arg(size_x, noncontiguous=nctg_x)
4470
y = make_arg(size_y, noncontiguous=nctg_y)
4471
self.check_single_matmul(x, y)
4473
@dtypesIfCUDA(torch.float, torch.complex64)
4474
@dtypes(torch.int64, torch.float, torch.complex64)
4475
@setBlasBackendsToDefaultFinally
4476
def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
4477
for backend in ["cublas", "cublaslt"]:
4478
if torch.device(device).type == 'cuda':
4479
torch.backends.cuda.preferred_blas_library(backend)
4481
make_arg = partial(make_tensor, device=device, dtype=dtype)
4483
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)):
4484
x = make_arg(size_x, noncontiguous=nctg_x)
4485
y = make_arg(size_y, noncontiguous=nctg_y)
4486
self.check_single_matmul(x, y)
4488
@dtypesIfCUDA(torch.float, torch.complex64)
4489
@dtypes(torch.int64, torch.float, torch.complex64)
4490
@setBlasBackendsToDefaultFinally
4491
def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
4492
for backend in ["cublas", "cublaslt"]:
4493
if torch.device(device).type == 'cuda':
4494
torch.backends.cuda.preferred_blas_library(backend)
4496
make_arg = partial(make_tensor, device=device, dtype=dtype)
4498
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(3), (True, False), (True, False)):
4499
x = make_arg(size_x, noncontiguous=nctg_x)
4500
y = make_arg(size_y, noncontiguous=nctg_y)
4501
self.check_single_matmul(x, y)
4504
@dtypes(*floating_types_and(torch.half))
4505
def test_matmul_small_brute_force_tunableop(self, device, dtype):
4508
os.environ["PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE"] = "0"
4509
set_tunableop_defaults()
4511
torch.cuda.tunable.enable()
4513
torch.cuda.tunable.set_max_tuning_duration(1)
4514
torch.cuda.tunable.set_max_tuning_iterations(1)
4516
make_arg = partial(make_tensor, device=device, dtype=dtype)
4518
for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)):
4519
x = make_arg(size_x, noncontiguous=nctg_x)
4520
y = make_arg(size_y, noncontiguous=nctg_y)
4521
self.check_single_matmul(x, y)
4523
filename1 = torch.cuda.tunable.get_filename()
4524
filename2 = "tunableop_results_tmp1.csv"
4525
filename3 = "tunableop_results_tmp2.csv"
4526
ordinal = torch.cuda.current_device()
4527
assert filename1 == f"tunableop_results{ordinal}.csv"
4528
assert len(torch.cuda.tunable.get_validators()) > 0
4530
for key, value in torch.cuda.tunable.get_validators():
4531
validators[key] = value
4532
if torch.version.hip:
4533
assert "HIPBLASLT_VERSION" in validators
4534
assert re.match(r'^\d{3}-[a-z0-9]{8}$', validators["HIPBLASLT_VERSION"])
4535
assert len(torch.cuda.tunable.get_results()) > 0
4537
assert torch.cuda.tunable.write_file()
4538
assert torch.cuda.tunable.write_file(filename2)
4539
torch.cuda.tunable.set_filename(filename3)
4540
assert torch.cuda.tunable.write_file()
4541
assert torch.cuda.tunable.read_file()
4543
with open(filename1) as file1:
4544
file1_contents = file1.read()
4545
with open(filename2) as file2:
4546
file2_contents = file2.read()
4547
with open(filename3) as file3:
4548
file3_contents = file3.read()
4549
assert file1_contents == file2_contents
4550
assert file1_contents == file3_contents
4553
for filename in [filename1, filename2, filename3]:
4557
except FileNotFoundError:
4561
torch.cuda.tunable.enable(False)
4565
@dtypes(torch.float)
4566
def test_bmm_tunableop_rocm(self, device, dtype):
4568
set_tunableop_defaults()
4569
torch.cuda.tunable.enable(True)
4570
torch.cuda.tunable.set_max_tuning_iterations(10)
4574
dtype = torch.bfloat16
4575
device = torch.device("cuda:0")
4577
i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4578
i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4579
out = torch.bmm(i1, i2)
4581
i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4582
i1 = torch.permute(i1, (1, 2, 0))
4583
i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4584
i2 = torch.permute(i2, (1, 0, 2))
4585
out = torch.bmm(i1, i2)
4587
i1 = torch.randn((N, B, M), device=device, dtype=dtype)
4588
i1 = torch.permute(i1, (1, 0, 2))
4589
i2 = torch.randn((M, B, K), device=device, dtype=dtype)
4590
i2 = torch.permute(i2, (1, 2, 0))
4591
out = torch.bmm(i1, i2)
4593
input_tensor = torch.rand((1920, 1, 100), device=device, dtype=dtype)
4594
input_tensor = torch.as_strided(
4595
input_tensor, size=(1920, 1, 100), stride=(100, 100, 1)
4597
batch1_tensor = torch.rand((1920, 256, 512), device=device, dtype=dtype)
4598
batch1_tensor = torch.as_strided(
4599
batch1_tensor, size=(1920, 256, 512), stride=(512, 983040, 1)
4601
batch2_tensor = torch.rand((1920, 512, 100), device=device, dtype=dtype)
4602
batch2_tensor = torch.as_strided(
4603
batch2_tensor, size=(1920, 512, 100), stride=(51200, 100, 1)
4605
out = torch.baddbmm(input_tensor, batch1_tensor, batch2_tensor)
4609
filename = torch.cuda.tunable.get_filename()
4611
except FileNotFoundError:
4615
torch.cuda.tunable.enable(False)
4619
@dtypes(torch.float)
4620
def test_numeric_check_leak_tunableop_rocm(self, device, dtype):
4621
from torch.testing._internal.common_utils import CudaMemoryLeakCheck
4627
dtype = torch.bfloat16
4628
device = torch.device("cuda:0")
4629
i1 = torch.randn((B, N, M), device=device, dtype=dtype)
4630
i2 = torch.randn((B, M, K), device=device, dtype=dtype)
4631
out = torch.bmm(i1, i2)
4633
PYTORCH_TUNABLEOP_NUMERICAL_CHECK = "PYTORCH_TUNABLEOP_NUMERICAL_CHECK"
4634
prev_val = os.getenv(PYTORCH_TUNABLEOP_NUMERICAL_CHECK)
4636
os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = "1"
4637
torch.cuda.tunable.enable(True)
4638
ordinal = torch.cuda.current_device()
4639
filename = f"tunableop_results{ordinal}.csv"
4640
torch.cuda.tunable.set_filename(filename)
4641
iterations = torch.cuda.tunable.get_max_tuning_iterations()
4642
torch.cuda.tunable.set_max_tuning_iterations(10)
4643
with CudaMemoryLeakCheck(self):
4644
out = torch.bmm(i1, i2)
4645
torch.cuda.tunable.set_max_tuning_iterations(iterations)
4646
torch.cuda.tunable.enable(False)
4650
except FileNotFoundError:
4653
if prev_val is None:
4654
del os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK]
4656
os.environ[PYTORCH_TUNABLEOP_NUMERICAL_CHECK] = prev_val
4660
@dtypes(torch.float)
4661
def test_validator_tunableop_rocm(self, device, dtype):
4669
validator_num_lines = 5
4674
set_tunableop_defaults()
4675
torch.cuda.tunable.enable()
4677
torch.cuda.tunable.set_max_tuning_iterations(1)
4680
A = torch.randn(N, K, device=device, dtype=dtype)
4681
B = torch.randn(K, M, device=device, dtype=dtype)
4682
C = torch.matmul(A, B)
4683
self.assertEqual(len(torch.cuda.tunable.get_validators()), validator_num_lines)
4686
torch.cuda.tunable.enable(False)
4691
filename = torch.cuda.tunable.get_filename()
4693
except FileNotFoundError:
4698
def test_minimum_tuning_iteration_tunableop(self, device, dtype):
4704
set_tunableop_defaults()
4705
torch.cuda.tunable.enable()
4707
torch.cuda.tunable.set_max_tuning_iterations(1)
4712
os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "0"
4713
self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4714
os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS"] = "30"
4717
ref_num_results = len(torch.cuda.tunable.get_results())
4720
A = torch.randn(N, K, device=device, dtype=dtype)
4721
B = torch.randn(K, M, device=device, dtype=dtype)
4722
C = torch.matmul(A, B)
4725
total_num_results = len(torch.cuda.tunable.get_results())
4728
self.assertEqual((total_num_results - ref_num_results), 1)
4732
os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "0"
4733
self.assertGreater(torch.cuda.tunable.get_max_tuning_iterations(), 0)
4734
os.environ["PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS"] = "100"
4737
ref_num_results = total_num_results
4740
A = torch.randn(N, K, device=device, dtype=dtype)
4741
B = torch.randn(K, M, device=device, dtype=dtype)
4742
C = torch.matmul(A, B)
4745
total_num_results = len(torch.cuda.tunable.get_results())
4748
self.assertEqual((total_num_results - ref_num_results), 1)
4752
torch.cuda.tunable.enable(False)
4757
filename = torch.cuda.tunable.get_filename()
4759
except FileNotFoundError:
4764
def test_matmul_check_entries_tunableop(self, device, dtype):
4769
set_tunableop_defaults()
4770
torch.cuda.tunable.enable()
4772
torch.cuda.tunable.set_max_tuning_iterations(1)
4775
ref_num_results = len(torch.cuda.tunable.get_results())
4781
for M in [32, 64, 32]:
4783
A = torch.randn(N, K, device=device, dtype=dtype)
4784
B = torch.randn(K, M, device=device, dtype=dtype)
4785
C = torch.matmul(A, B)
4788
total_num_results = len(torch.cuda.tunable.get_results())
4793
self.assertEqual((total_num_results - ref_num_results), count_matmul)
4797
torch.cuda.tunable.enable(False)
4802
filename = torch.cuda.tunable.get_filename()
4804
except FileNotFoundError:
4807
@dtypes(torch.float, torch.complex64)
4808
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
4809
a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0)
4810
b = torch.empty((4, 128, 512), device=device, dtype=dtype, requires_grad=True).transpose(-1, -2)
4811
c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
4813
torch.matmul(a.detach(), b.detach(), out=c)
4815
with self.assertRaisesRegex(RuntimeError, "functions with out=... arguments don't support automatic differentiation"):
4816
torch.matmul(a, b, out=c)
4818
with torch.no_grad():
4819
torch.matmul(a, b, out=c)
4822
@largeTensorTest('16GB', device='cuda')
4823
def test_large_bmm_mm_backward(self, device):
4824
A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4825
B = torch.randn([1024, 65536], device="cuda", requires_grad=True)
4826
G = torch.randn([1024, 2, 65536], device="cuda")
4832
@largeTensorTest('16GB', device='cuda')
4833
def test_large_bmm_backward(self, device):
4834
A = torch.randn([1024, 2, 1024], device="cuda").mT.contiguous().mT
4835
B = torch.randn([1, 1024, 65536], device="cuda", requires_grad=True)
4836
G = torch.randn([1024, 2, 65536], device="cuda")
4841
def test_linear_algebra_scalar_raises(self, device) -> None:
4842
m = torch.randn(5, 5, device=device)
4843
v = torch.randn(5, device=device)
4844
s = torch.tensor(7, device=device)
4845
self.assertRaises(RuntimeError, lambda: torch.mv(m, s))
4846
self.assertRaises(RuntimeError, lambda: torch.addmv(v, m, s))
@dtypes(torch.float32, torch.complex64)
def test_cross(self, device, dtype):
x = torch.rand(100, 3, 100, dtype=dtype, device=device)
y = torch.rand(100, 3, 100, dtype=dtype, device=device)
res1 = torch.cross(x, y)
res2 = torch.tensor((), dtype=dtype, device=device)
torch.cross(x, y, out=res2)
self.assertEqual(res1, res2)
@dtypes(torch.float32, torch.complex64)
def test_linalg_cross(self, device, dtype):
x = torch.rand(100, 3, 100, dtype=dtype, device=device)
y = torch.rand(100, 3, 100, dtype=dtype, device=device)
res1 = torch.linalg.cross(x, y, dim=1)
res2 = torch.tensor((), dtype=dtype, device=device)
torch.linalg.cross(x, y, dim=1, out=res2)
self.assertEqual(res1, res2)
x = torch.rand(1, 3, 2, dtype=dtype, device=device)
y = torch.rand(4, 3, 1, dtype=dtype, device=device)
res1 = torch.linalg.cross(x, y, dim=1)
res2 = torch.tensor((), dtype=dtype, device=device)
torch.linalg.cross(x, y, dim=1, out=res2)
self.assertEqual(res1, res2)
@dtypes(torch.float32, torch.complex64)
4875
def test_cross_with_and_without_dim(self, device, dtype):
4876
x = torch.rand(100, 3, dtype=dtype, device=device)
4877
y = torch.rand(100, 3, dtype=dtype, device=device)
4878
res1 = torch.cross(x, y, dim=1)
4879
res2 = torch.cross(x, y, dim=-1)
4880
res3 = torch.cross(x, y)
4881
self.assertEqual(res1, res2)
4882
self.assertEqual(res1, res3)
4884
@dtypes(torch.float32, torch.complex64)
4885
def test_linalg_cross_with_and_without_dim(self, device, dtype):
4886
x = torch.rand(100, 3, dtype=dtype, device=device)
4887
y = torch.rand(100, 3, dtype=dtype, device=device)
4888
res1 = torch.linalg.cross(x, y, dim=1)
4889
res2 = torch.linalg.cross(x, y, dim=-1)
4890
res3 = torch.linalg.cross(x, y)
4891
self.assertEqual(res1, res2)
4892
self.assertEqual(res1, res3)
def test_renorm(self, device):
m1 = torch.randn(20, 20, device=device)
res1 = torch.tensor((), device=device)
def renorm(matrix, value, dim, max_norm):
m1 = matrix.transpose(dim, 0).contiguous()
m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0))))
norms = m2.norm(value, 1, True)
new_norms = norms.clone()
new_norms[torch.gt(norms, max_norm)] = max_norm
new_norms.div_(norms.add_(1e-7))
m1.mul_(new_norms.expand_as(m1))
return m1.transpose(dim, 0)
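# renorm(p, dim, maxnorm) rescales each slice along `dim` whose p-norm exceeds maxnorm so
# that its norm equals maxnorm, leaving slices that are already within the bound unchanged;
# the reference implementation above mirrors that behaviour (up to the 1e-7 stabilizer).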
4912
maxnorm = m1.norm(2, 1).mean()
4913
m2 = renorm(m1, 2, 1, maxnorm)
4914
m1.renorm_(2, 1, maxnorm)
4915
self.assertEqual(m1, m2, atol=1e-5, rtol=0)
4916
self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), atol=1e-5, rtol=0)
4918
m1 = torch.randn(3, 4, 5, device=device)
4919
m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
4920
maxnorm = m2.norm(2, 0).mean()
4921
m2 = renorm(m2, 2, 1, maxnorm)
4922
m1.renorm_(2, 1, maxnorm)
4923
m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4)
4924
self.assertEqual(m3, m2)
4925
self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
4928
@skipCUDAIfNoCusolver
4929
@dtypes(*floating_and_complex_types())
4930
def test_ormqr(self, device, dtype):
4932
def run_test(batch, m, n, fortran_contiguous):
4933
A = make_tensor((*batch, m, n), dtype=dtype, device=device)
4934
reflectors, tau = torch.geqrf(A)
4935
if not fortran_contiguous:
4936
self.assertTrue(reflectors.mT.is_contiguous())
4937
reflectors = reflectors.contiguous()
4940
Q, _ = torch.linalg.qr(A, mode='complete')
4941
C_right = make_tensor((*batch, m, n), dtype=dtype, device=device)
4942
C_left = make_tensor((*batch, n, m), dtype=dtype, device=device)
4944
expected = Q @ C_right
4945
actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=False)
4946
self.assertEqual(expected, actual)
4948
expected = C_left @ Q
4949
actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=False)
4950
self.assertEqual(expected, actual)
4952
expected = Q.mH @ C_right
4953
actual = torch.ormqr(reflectors, tau, C_right, left=True, transpose=True)
4954
self.assertEqual(expected, actual)
4956
expected = C_left @ Q.mH
4957
actual = torch.ormqr(reflectors, tau, C_left, left=False, transpose=True)
4958
self.assertEqual(expected, actual)
4962
zero_tau = torch.zeros_like(tau)
4963
actual = torch.ormqr(reflectors, zero_tau, C_right, left=True, transpose=False)
4964
self.assertEqual(C_right, actual)
4966
batches = [(), (0, ), (2, ), (2, 1)]
4968
for batch, (m, n), fortran_contiguous in product(batches, product(ns, ns), [True, False]):
4969
run_test(batch, m, n, fortran_contiguous)
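# ormqr applies the Q factor encoded by (reflectors, tau) from geqrf without materializing Q:
# depending on the left/transpose flags it computes Q @ C, C @ Q, Q^H @ C or C @ Q^H, which
# is what run_test checks against an explicitly formed Q.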
4972
@skipCUDAIfNoCusolver
4973
@dtypes(*floating_and_complex_types())
4974
def test_ormqr_errors_and_warnings(self, device, dtype):
4977
((10,), (2,), (2,), r"input must have at least 2 dimensions"),
4978
((2, 2), (2,), (2,), r"other must have at least 2 dimensions"),
4979
((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"),
4980
((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"),
4981
((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"),
4982
((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"),
4984
for a_size, tau_size, c_size, error_regex in test_cases:
4985
a = make_tensor(a_size, dtype=dtype, device=device)
4986
tau = make_tensor(tau_size, dtype=dtype, device=device)
4987
c = make_tensor(c_size, dtype=dtype, device=device)
4988
with self.assertRaisesRegex(RuntimeError, error_regex):
4989
torch.ormqr(a, tau, c)
4991
def test_blas_empty(self, device):
4992
def fn(torchfn, *args, test_out=False, **kwargs):
4993
def call_torch_fn(*args, **kwargs):
4994
return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
4995
for shape in args), **kwargs)
4996
result = call_torch_fn(*args, **kwargs)
5000
out = torch.full_like(result, math.nan)
5001
out1 = call_torch_fn(*args, **kwargs, out=out)
5005
self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
5006
self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
5007
self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
5008
self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
5009
self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)))
5010
self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True))
5012
self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
5013
self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape)
5014
t = torch.randn((5, 6), device=device)
5015
self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
5016
self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
5019
self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
5020
self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
5021
self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
5022
self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True))
5024
self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
5025
t = torch.randn((3,), device=device)
5026
self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
5027
self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
5030
self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
5031
self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
5032
self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
5033
self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)))
5034
self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True))
5036
self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape)
5037
self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape)
5038
self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape)
5039
self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape)
5040
c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
5041
self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2))
5042
self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True))
5045
self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
5046
self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
5047
t = torch.randn((5, 6), device=device)
5048
self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
5049
self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
5052
self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,)))
5053
self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True))
5054
self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
5055
self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
5056
self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
5057
self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4)))
5058
self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True))
5061
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
5062
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True))
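# Summary of the empty-input conventions checked above: products over a zero-sized inner
# dimension are zero-filled, and the addmm/addmv/addbmm/baddbmm variants degenerate to
# beta * input when the matrix product contributes nothing.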
5064
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
5065
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
5066
@dtypesIfCUDA(*floating_and_complex_types_and(
5068
*[torch.bfloat16] if SM53OrLater else []
5070
@dtypes(*all_types_and_complex_and(torch.bfloat16))
5071
def test_corner_cases_of_cublasltmatmul(self, device, dtype):
5073
M = torch.randn(128, device=device).to(dtype)
5074
m1 = torch.randn(2048, 2400, device=device).to(dtype)
5075
m2 = torch.randn(128, 2400, device=device).to(dtype)
5076
torch.nn.functional.linear(m1, m2, M)
5078
m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
5079
m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
5080
M = torch.rand([128]).to(dtype).to(device)
5081
torch.addmm(M, m2.t(), m1)
5083
m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
5084
m2 = torch.randn(2048, 2400, device=device).to(dtype)
5085
M = torch.rand([128]).to(dtype).to(device)
5086
torch.addmm(M, m2, m1)
5088
M = torch.randn(16, device=device).to(dtype)
5089
m1 = torch.randn(32, 131071 , device=device).to(dtype)
5090
m2 = torch.randn(16, 131071, device=device).to(dtype)
5091
torch.nn.functional.linear(m1, m2, M)
5095
@dtypes(*floating_types_and(torch.bfloat16, torch.half))
5096
def test_hipblaslt_corner_cases_rocm(self, device, dtype):
5097
if dtype == torch.double:
5098
raise unittest.SkipTest("hipblasLt doesn't support doubles yet")
5102
DISABLE_ADDMM_HIP_LT = "DISABLE_ADDMM_HIP_LT"
5103
prev_val = os.getenv(DISABLE_ADDMM_HIP_LT)
5105
os.environ[DISABLE_ADDMM_HIP_LT] = "0"
5107
M = torch.randn(128, device=device, dtype=dtype)
5108
m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5109
m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5110
out1 = torch.nn.functional.linear(m1, m2, M)
5112
m1_cpu = m1.to('cpu')
5113
m2_cpu = m2.to('cpu')
5114
out1_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, M_cpu)
5115
self.assertTrue(torch.allclose(out1_cpu, out1.cpu(), rtol=1e-2, atol=1e-2))
5118
m1 = torch.randn(2048, 2400, device=device, dtype=dtype)
5119
m2 = torch.randn(128, 2400, device=device, dtype=dtype)
5120
out2 = torch.nn.functional.linear(m1, m2, bias=None)
5121
m1_cpu = m1.to('cpu')
5122
m2_cpu = m2.to('cpu')
5123
out2_cpu = torch.nn.functional.linear(m1_cpu, m2_cpu, bias=None)
5124
self.assertTrue(torch.allclose(out2_cpu, out2.cpu(), rtol=1e-2, atol=1e-2))
5126
if prev_val is None:
5127
del os.environ[DISABLE_ADDMM_HIP_LT]
5129
os.environ[DISABLE_ADDMM_HIP_LT] = prev_val
5131
@dtypesIfCUDA(*floating_and_complex_types_and(
5133
*[torch.bfloat16] if SM53OrLater else []
5135
@dtypes(*all_types_and_complex_and(torch.bfloat16, torch.half))
5136
def test_blas_alpha_beta_empty(self, device, dtype):
5139
if dtype is torch.bfloat16 and self.device_type == 'xla':
input = torch.full((2,), value, dtype=dtype, device=device)
mat = torch.ones((2, 0), dtype=dtype, device=device)
vec = torch.ones((0,), dtype=dtype, device=device)
out = torch.empty((2,), dtype=dtype, device=device)
if dtype.is_complex:
self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta))
self.assertEqual(torch.full((2,), beta * value, dtype=dtype, device=device),
torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out))
input = torch.full((2, 3), value, dtype=dtype, device=device)
mat2 = torch.ones((0, 3), dtype=dtype, device=device)
out = torch.empty((2, 3), dtype=dtype, device=device)
self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta))
self.assertEqual(torch.full((2, 3), beta * value, dtype=dtype, device=device),
torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
@dtypes(*floating_and_complex_types_and(torch.half, torch.bfloat16))
def test_blas_nan_out(self, device, dtype):
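# The out= buffers below are pre-filled with NaN so the test fails if a BLAS kernel
# reads the uninitialized output instead of overwriting it.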
nm = torch.randn((m, n), device=device).t()
_m = torch.randn((), device=device).expand(m)
_m_out = torch.full((m,), float('nan'), device=device)
self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum())
mp = torch.randn((p, m), device=device).t()
np_out = torch.full((n, p), float('nan'), device=device)
self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))
bnm = torch.randn((b, m, n), device=device).transpose(1, 2)
bmp = torch.randn((b, p, m), device=device).transpose(1, 2)
bnp_out = torch.full((b, n, p), float('nan'), device=device)
self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out))
def test_blas_mv_large_input(self, device):
nm = torch.randn((m, n), device=device).t()
_m = torch.randn((), device=device).expand(m)
_m_out = torch.full((m,), 0., device=device)
self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
def test_renorm_ps(self, device):
x = torch.randn(5, 5)
for p in [1, 2, 3, 4, inf]:
res = x.renorm(p, 1, 1)
expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
self.assertEqual(res, expected, msg=f"renorm failed for {p}-norm")
@skipCUDAIfNoCusolver
@dtypes(*floating_and_complex_types())
def test_householder_product(self, device, dtype):
def generate_reflectors_and_tau(A):
"""
This function uses numpy.linalg.qr with mode "raw" to extract the output of LAPACK's geqrf.
There is a torch.geqrf function, but it doesn't work with complex-valued input.
"""
flattened_batch_shape = [-1, *A_cpu.shape[-2:]]
reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape)
tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]]
tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1])
for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau):
reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw'))
reflectors_i[:] = reflectors_tmp.T
reflectors = reflectors.view(*A_cpu.shape)
tau = tau.view(tau_shape)
return reflectors.to(A.device), tau.to(A.device)
reflectors = torch.empty_like(A)
tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device)
return reflectors, tau
def run_test(shape):
A = torch.randn(*shape, dtype=dtype, device=device)
reflectors, tau = generate_reflectors_and_tau(A)
expected, _ = torch.linalg.qr(A)
actual = torch.linalg.householder_product(reflectors, tau)
self.assertEqual(expected, actual)
self.assertTrue(actual.shape == shape)
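# An empty tau corresponds to a product of zero Householder reflectors,
# so householder_product should return the identity matrix.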
tau_empty = torch.empty(*shape[:-2], 0, dtype=dtype, device=device)
identity_mat = torch.zeros_like(reflectors)
identity_mat.diagonal(dim1=-1, dim2=-2)[:] = 1
actual = torch.linalg.householder_product(reflectors, tau_empty)
self.assertEqual(actual, identity_mat)
out = torch.empty_like(A)
ans = torch.linalg.householder_product(reflectors, tau, out=out)
self.assertEqual(ans, out)
self.assertEqual(expected, out)
shapes = [(0, 0), (5, 0),
(0, 0, 0), (0, 5, 5), (0, 5, 3),
(2, 5, 5), (2, 5, 3),
(2, 1, 5, 5), (2, 1, 5, 3)]
for shape in shapes:
@skipCUDAIfNoCusolver
def test_householder_product_errors_and_warnings(self, device):
((10,), (2,), r"input must have at least 2 dimensions"),
((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"),
((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"),
for a_size, tau_size, error_regex in test_cases:
a = torch.rand(*a_size, device=device)
tau = torch.rand(*tau_size, device=device)
with self.assertRaisesRegex(RuntimeError, error_regex):
torch.linalg.householder_product(a, tau)
reflectors = torch.randn(3, 3, device=device)
tau = torch.randn(3, device=device)
out = torch.empty(2, 3, device=device)
with warnings.catch_warnings(record=True) as w:
torch.linalg.householder_product(reflectors, tau, out=out)
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
out = torch.empty_like(reflectors).to(torch.int)
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
torch.linalg.householder_product(reflectors, tau, out=out)
with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"):
torch.linalg.householder_product(reflectors, tau.to(torch.int))
if torch.cuda.is_available():
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
out = torch.empty_like(reflectors).to(wrong_device)
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch.linalg.householder_product(reflectors, tau, out=out)
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
tau = tau.to(wrong_device)
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch.linalg.householder_product(reflectors, tau)
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
@skipCUDAIfNoMagmaAndNoCusolver
@skipIfTorchDynamo("Runtime error with torch._C._linalg.linalg_lu_factor")
@dtypes(*floating_and_complex_types())
def test_linalg_lu_family(self, device, dtype):
make_arg_full = partial(make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype)
make_arg = partial(make_tensor, device=device, dtype=dtype)
def run_test(A, pivot, singular, fn):
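# Factorize A, check the LU/pivot shapes, reconstruct A from the unpacked P, L, U factors,
# and (for square, non-singular A) cross-check lu_solve and linalg.solve against it.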
k = min(A.shape[-2:])
batch = A.shape[:-2]
check_errors = (fn == torch.linalg.lu_factor)
if singular and check_errors:
LU, pivots = fn(A, pivot=pivot)
except RuntimeError:
LU, pivots = fn(A, pivot=pivot)[:2]
self.assertEqual(LU.size(), A.shape)
self.assertEqual(pivots.size(), batch + (k,))
self.assertEqual(pivots, torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand(batch + (k, )))
P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot)
self.assertEqual(P @ L @ U if pivot else L @ U, A)
PLU = torch.linalg.lu(A, pivot=pivot)
self.assertEqual(P, PLU.P)
self.assertEqual(L, PLU.L)
self.assertEqual(U, PLU.U)
if not singular and A.size(-2) == A.size(-1):
nrhs = ((), (1,), (3,))
for left, rhs in product((True, False), nrhs):
if not left and rhs == ():
shape_B = A.shape[:-1] + rhs
shape_B = A.shape[:-2] + rhs + A.shape[-1:]
B = make_arg(shape_B)
for adjoint in (True, False):
X = torch.linalg.lu_solve(LU, pivots, B, left=left, adjoint=adjoint)
A_adj = A.mH if adjoint else A
self.assertEqual(B, A_adj @ X)
self.assertEqual(B, X @ A_adj)
X = torch.linalg.solve(A, B, left=left)
X_ = X.unsqueeze(-1) if rhs == () else X
B_ = B.unsqueeze(-1) if rhs == () else B
self.assertEqual(B_, A @ X_)
self.assertEqual(B_, X_ @ A)
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
pivots = (True, False) if self.device_type == "cuda" else (True,)
fns = (partial(torch.lu, get_infos=True), torch.linalg.lu_factor, torch.linalg.lu_factor_ex)
for ms, batch, pivot, singular, fn in itertools.product(sizes, batches, pivots, (True, False), fns):
A = make_arg(shape) if singular else make_arg_full(*shape)
if A.numel() == 0 and not singular:
run_test(A, pivot, singular, fn)
if (dtype == torch.double
A = torch.ones(batch + ms, dtype=dtype, device=device)
run_test(A, pivot, singular, fn)
A = torch.ones(5, 3, 3, device=device)
self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all())
if self.device_type == 'cpu':
fns = [torch.lu, torch.linalg.lu_factor, torch.linalg.lu_factor_ex, torch.linalg.lu]
with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
f(torch.empty(1, 2, 2), pivot=False)
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
@skipCUDAIfNoMagmaAndNoCusolver
@setLinalgBackendsToDefaultFinally
@dtypes(*floating_and_complex_types())
def test_linalg_lu_solve(self, device, dtype):
make_arg = partial(make_tensor, dtype=dtype, device=device)
backends = ["default"]
if torch.device(device).type == 'cuda':
if torch.cuda.has_magma:
backends.append("magma")
backends.append("cusolver")
batches = ((), (0,), (1,), (2,), (2, 1), (0, 2))
for batch, n in product(batches, ns):
yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs))
shapes = ((1, 64), (2, 128), (1025, 2))
yield make_arg((b, n, n)), make_arg((b, n, rhs))
for A, B in gen_matrices():
LU, pivots = torch.linalg.lu_factor(A)
for backend in backends:
torch.backends.cuda.preferred_linalg_library(backend)
for left, adjoint in product((True, False), repeat=2):
B_left = B if left else B.mT
X = torch.linalg.lu_solve(LU, pivots, B_left, left=left, adjoint=adjoint)
A_adj = A.mH if adjoint else A
self.assertEqual(B_left, A_adj @ X)
self.assertEqual(B_left, X @ A_adj)
@dtypes(*floating_and_complex_types())
def test_linalg_lu_cpu_errors(self, device, dtype):
sample = torch.randn(3, 2, 2, device=device, dtype=dtype)
B = torch.randn(3, 2, 2, device=device, dtype=dtype)
LU, pivots = torch.linalg.lu_factor(sample)
torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"greater or equal to 1"):
torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"smaller or equal to LU.size\(-2\)"):
torch.linalg.lu_solve(LU, pivots, B, adjoint=True)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
sample = torch.randn(3, 4, 2, device=device, dtype=dtype)
B = torch.randn(3, 4, 2, device=device, dtype=dtype)
LU, pivots = torch.linalg.lu_factor(sample)
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
B = torch.randn(2, 3, 5, device=device, dtype=dtype)
LU, pivots = torch.linalg.lu_factor(sample)
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
@dtypes(torch.double)
def test_lu_unpack_check_input(self, device, dtype):
x = torch.rand(5, 5, 5, device=device, dtype=dtype)
lu_data, lu_pivots = torch.linalg.lu_factor(x)
with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
torch.lu_unpack(lu_data, lu_pivots.long())
p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
self.assertTrue(l.numel() == 0 and u.numel() == 0)
p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
self.assertTrue(p.numel() == 0)
p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0)
@dtypes(torch.double)
def test_lobpcg_basic(self, device, dtype):
self._test_lobpcg_method(device, dtype, 'basic')
@skipCUDAIfNoCusolver
@dtypes(torch.double)
def test_lobpcg_ortho(self, device, dtype):
if torch.version.hip:
torch.backends.cuda.preferred_linalg_library('magma')
self._test_lobpcg_method(device, dtype, 'ortho')
if torch.version.hip:
torch.backends.cuda.preferred_linalg_library('default')
def _test_lobpcg_method(self, device, dtype, method):
from torch.testing._internal.common_utils import random_symmetric_pd_matrix, random_sparse_pd_matrix
from torch._linalg_utils import matmul, qform
from torch._lobpcg import lobpcg
def test_tracker(worker):
k = worker.iparams['k']
nc = worker.ivars['converged_count']
tol = worker.fparams['tol']
rerr = worker.tvars['rerr']
self.assertLessEqual(rerr[:k].max(), tol)
I = torch.eye(k, k, dtype=dtype, device=device)
self.assertEqual(qform(B, X[:, :k]), I)
self.assertEqual(qform(A, X[:, :k]) / E[:k], I, atol=0.2, rtol=0)
orig_lobpcg = lobpcg
def lobpcg(*args, **kwargs):
kwargs['tracker'] = test_tracker
kwargs['niter'] = 1000
kwargs['method'] = method
kwargs['tol'] = 1e-8
return orig_lobpcg(*args, **kwargs)
for batches in [(), (2,), (2, 3)]:
if method == 'basic' and (m, n, k) in [(9, 2, 2), (100, 15, 5)]:
A = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
B = random_symmetric_pd_matrix(m, *batches, device=device, dtype=dtype)
E, V = lobpcg(A, k=k, n=n, largest=False)
self.assertEqual(E.shape, batches + (k,))
self.assertEqual(V.shape, batches + (m, k))
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
e = torch.linalg.eigvalsh(A)
e_smallest = e[..., :k]
self.assertEqual(E, e_smallest)
E, V = lobpcg(A, k=k, n=n, largest=True)
e_largest, _ = torch.sort(e[..., -k:], descending=True)
self.assertEqual(E, e_largest, atol=prec, rtol=0)
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
self.assertEqual(matmul(A, V), mm(matmul(B, V), E.diag_embed()), atol=prec, rtol=0)
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
for m, n, k, density in [
if method == 'basic' and (m, n, k, density) in [(1000, 7, 3, 0.01)]:
A = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
B = random_sparse_pd_matrix(m, density=density, device=device, dtype=dtype)
A_eigenvalues = torch.arange(1, m + 1, dtype=dtype) / m
e_smallest = A_eigenvalues[..., :k]
e_largest, _ = torch.sort(A_eigenvalues[..., -k:], descending=True)
E, V = lobpcg(A, k=k, n=n, largest=False)
self.assertEqual(E, e_smallest)
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
E, V = lobpcg(A, k=k, n=n, largest=True)
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
self.assertEqual(E, e_largest)
E, V = lobpcg(A, B=B, k=k, n=n, largest=False)
self.assertEqual(matmul(A, V), matmul(B, mm(V, E.diag_embed())), atol=prec, rtol=0)
E, V = lobpcg(A, B=B, k=k, n=n, largest=True)
self.assertEqual(matmul(A, V) / E.max(), mm(matmul(B, V), (E / E.max()).diag_embed()),
@dtypes(torch.double)
def test_lobpcg_torchscript(self, device, dtype):
from torch.testing._internal.common_utils import random_sparse_pd_matrix
from torch._linalg_utils import matmul as mm
lobpcg = torch.jit.script(torch.lobpcg)
A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
X1 = torch.randn((m, k), dtype=dtype, device=device)
E1, V1 = lobpcg(A1, X=X1)
eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
self.assertLess(eq_err, 1e-6)
@unittest.skipIf(not TEST_SCIPY or (TEST_SCIPY and scipy.__version__ < '1.4.1'), "Scipy not found or older than 1.4.1")
@skipIfTorchDynamo("fails in tracing scipy.sparse.lobpcg")
@dtypes(torch.double)
def test_lobpcg_scipy(self, device, dtype):
"""Compare torch and scipy.sparse.linalg implementations of lobpcg
from torch.testing._internal.common_utils import random_sparse_pd_matrix
from torch._linalg_utils import matmul as mm
from scipy.sparse.linalg import lobpcg as scipy_lobpcg
if A.layout == torch.sparse_coo:
values = A.coalesce().values().cpu().numpy().copy()
indices = A.coalesce().indices().cpu().numpy().copy()
return scipy.sparse.coo_matrix((values, (indices[0], indices[1])), A.shape)
return A.cpu().numpy().copy()
A1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
B1 = random_sparse_pd_matrix(m, density=2.0 / m, device=device, dtype=dtype)
X1 = torch.randn((m, k), dtype=dtype, device=device)
def tracker(worker):
lambdas1.append(worker.E[:])
E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=1.1 * tol)
iters1 = len(lambdas1)
iters2 = len(lambdas2)
self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
E2a, V2a = scipy_lobpcg(A2, X2, maxiter=niter, largest=False)
eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
self.assertLess(eq_err, 1e-6)
self.assertLess(eq_err_scipy, 1e-6)
self.assertEqual(E1, torch.from_numpy(E2.copy()))
def tracker(worker):
lambdas1.append(worker.E[:])
E1, V1 = torch.lobpcg(A1, B=B1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=39 * tol)
E2a, V2a = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=False)
iters1 = len(lambdas1)
iters2 = len(lambdas2)
self.assertLess(abs(iters1 - iters2), 0.05 * max(iters1, iters2))
eq_err = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
eq_err_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
self.assertLess(eq_err, 1e-6)
self.assertLess(eq_err_scipy, 1e-6)
self.assertEqual(E1, torch.from_numpy(E2.copy()))
elapsed_ortho_general = 0
elapsed_general_scipy = 0
for i in range(repeat):
torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol)
elapsed_ortho += end - start
torch.lobpcg(A1, X=X1, B=B1, niter=niter, method='ortho', tol=tol)
elapsed_ortho_general += end - start
scipy_lobpcg(A2, X2, maxiter=niter, tol=1.1 * tol)
elapsed_scipy += end - start
scipy_lobpcg(A2, X2, B=B2, maxiter=niter, tol=39 * tol)
elapsed_general_scipy += end - start
elapsed_ortho_ms = 1000.0 * elapsed_ortho / repeat
elapsed_ortho_general_ms = 1000.0 * elapsed_ortho_general / repeat
elapsed_scipy_ms = 1000.0 * elapsed_scipy / repeat
elapsed_general_scipy_ms = 1000.0 * elapsed_general_scipy / repeat
CPU timings: torch.lobpcg vs scipy.sparse.linalg.lobpcg
-------------------------------------------------------
| standard | generalized | method
torch.lobpcg | {elapsed_ortho_ms:10.2f} | {elapsed_ortho_general_ms:10.2f} | ortho
scipy_lobpcg | {elapsed_scipy_ms:10.2f} | {elapsed_general_scipy_ms:10.2f} | N/A
-(input size: {m:4}, eigenpairs:{k:2}, units: ms per call)-
def tracker(worker):
lambdas1.append(worker.E[:])
E1, V1 = torch.lobpcg(A1, X=X1, niter=niter, largest=True, tracker=tracker, tol=tol)
iters1 = len(lambdas1)
eq_err = torch.norm((mm(A1, V1) - V1 * E1), 2) / E1.max()
E2, V2, lambdas2 = scipy_lobpcg(A2, X2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
iters2 = len(lambdas2)
eq_err_scipy = (abs(A2.dot(V2) - V2 * E2)**2).sum() ** 0.5 / E2.max()
except Exception as msg:
print('Calling scipy_lobpcg failed [standard]:', msg)
def tracker(worker):
lambdas1.append(worker.E[:])
E1, V1 = torch.lobpcg(A1, X=X1, B=B1, niter=niter, largest=True, tracker=tracker, tol=tol)
iters1_general = len(lambdas1)
eq_err_general = torch.norm((mm(A1, V1) - mm(B1, V1) * E1), 2) / E1.max()
E2, V2, lambdas2 = scipy_lobpcg(A2, X2, B=B2, maxiter=niter, largest=True, retLambdaHistory=True, tol=tol)
iters2_general = len(lambdas2)
eq_err_general_scipy = (abs(A2.dot(V2) - B2.dot(V2) * E2)**2).sum() ** 0.5 / E2.max()
except Exception as msg:
print('Calling scipy_lobpcg failed [generalized]:', msg)
eq_err_general_scipy = -1
Handling of small tol={tol:6.0e}: torch.lobpcg vs scipy.sparse.linalg.lobpcg
----------------------------------------------------------------------------
| standard | generalized | niter | method
torch.lobpcg | {eq_err:10.2e} | {eq_err_general:10.2e} | {iters1:6} | ortho
scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:6} | N/A
---(input size: {m:4}, eigenpairs:{k:2}, units: relative error, maxiter={niter:4})---
def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None):
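# Shared helper for addmm/addmv-style ops: computes the op directly, via out=,
# and against a NumPy reference (with an optional relu/gelu epilogue), then compares all three.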
if dtype in {torch.bfloat16, torch.half}:
numpy_dtype = torch.float
if dtype.is_complex:
alpha = 0.9 + 0.3j if alpha is None else alpha
beta = 0.5 + 0.6j if beta is None else beta
alpha = 1.2 if alpha is None else alpha
beta = 0.8 if beta is None else beta
if activation == "gelu":
res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
res1 = f(t, m, v, alpha=alpha, beta=beta)
res2 = torch.full_like(res1, math.nan)
res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
if activation == "gelu":
f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
f(t, m, v, alpha=alpha, beta=beta, out=res2)
res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
res3 += (beta * t).to(numpy_dtype).cpu().numpy()
if activation == "relu":
res3 = res3 * (res3 > 0)
elif activation == "gelu":
res3_t = torch.from_numpy(res3).to(dtype)
approximate = "tanh" if t.is_cuda else "none"
res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
res3 = res3_t.to(numpy_dtype).cpu().numpy()
assert activation is None, f"unsupported activation {activation}"
res3 = torch.from_numpy(res3).to(dtype)
self.assertEqual(res1, res2)
self.assertEqual(res1, res3)
@precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4, torch.double: 1e-8,
torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_and_complex_types_and(
*[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [],
@dtypes(torch.bfloat16, torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_addmv(self, device, dtype):
if IS_ARM64 and device == 'cpu' and dtype == torch.float16:
raise unittest.SkipTest("Fails on ARM, see https://github.com/pytorch/pytorch/issues/125438")
0.2 * torch.randn(50, device=device).to(dtype),
0.2 * torch.randn(1, device=device).to(dtype).expand(50),
0.2 * torch.randn(100, device=device).to(dtype),
0.2 * torch.ones(1, device=device).to(dtype).expand(100),
0.2 * torch.ones((), device=device).to(dtype).expand(50, 100),
0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
0.2 * torch.randint(3, (50, 1), dtype=torch.float, device=device).to(dtype).expand(50, 100),
0.2 * torch.randn((50, 100), device=device).to(dtype),
0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
for m, v, t in itertools.product(ms, vs, ts):
self._test_addmm_addmv(torch.addmv, t, m, v)
t = torch.full((50,), math.nan, device=device).to(dtype)
for m, v in itertools.product(ms, vs):
self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
@dtypesIfCUDA(*floating_types_and(*[torch.bfloat16] if TEST_WITH_ROCM or
SM53OrLater else []))
@dtypes(torch.float, torch.double)
def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
y_data = torch.ones(o, device=device, dtype=dtype)
control = torch.tensor([15., 33., 51., 69., 87.], device=device, dtype=dtype)
def _test(row_major, incx, incy, lda_tail):
a_storage = torch.full((o, s + lda_tail), float('nan'), device=device, dtype=dtype)
a_storage = torch.full((s, o + lda_tail), float('nan'), device=device, dtype=dtype).permute(1, 0)
a = a_storage[:o, :s].copy_(a_data)
x_storage = torch.full((s, incx), float('nan'), device=device, dtype=dtype)
x = x_storage[:, 0].copy_(x_data)
y_storage = torch.full((o, incy), float('nan'), device=device, dtype=dtype)
y = y_storage[:, 0].copy_(y_data)
self._test_addmm_addmv(torch.addmv, y, a, x)
for row_major, incx, incy, lda_tail in itertools.product((False, True), (1, 2), (1, 2), (0, 1)):
_test(row_major, incx, incy, lda_tail)
def _test_addmm_impl(self, func, activation, device, dtype):
M = torch.randn(10, 25, device=device).to(dtype)
m1 = torch.randn(10, 50, device=device).to(dtype)
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, activation=activation)
V = torch.randn(25, device=device).to(dtype)
self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)
M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, activation=activation)
M = torch.full((10, 25), math.nan, device=device).to(dtype)
m1 = torch.randn(10, 50, device=device).to(dtype)
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)
for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
def maybe_transpose(cond, m):
return m.t().clone(memory_format=torch.contiguous_format).t()
M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation)
self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfMPS(torch.float32)
@dtypesIfCUDA(*floating_and_complex_types_and(
*[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm(self, device, dtype):
self._test_addmm_impl(torch.addmm, None, device, dtype)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_types_and(
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_relu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_types_and(
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_relu_tunableop_rocm(self, device, dtype):
torch.cuda.tunable.enable(True)
ordinal = torch.cuda.current_device()
filename = f"tunableop_results{ordinal}.csv"
torch.cuda.tunable.set_filename(filename)
iterations = torch.cuda.tunable.get_max_tuning_iterations()
torch.cuda.tunable.set_max_tuning_iterations(10)
self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
except FileNotFoundError:
torch.cuda.tunable.set_max_tuning_iterations(iterations)
torch.cuda.tunable.enable(False)
@precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 5e-2,
torch.half: 5e-2, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
@dtypesIfCUDA(*floating_types_and(
*[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else []))
@dtypes(*floating_types_and(torch.bfloat16))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addmm_gelu(self, device, dtype):
self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*floating_and_complex_types())
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_addmm_sizes(self, device, dtype):
for m in [0, 1, 25]:
for n in [0, 1, 10]:
M = torch.randn(n, m, device=device).to(dtype)
m1 = torch.randn(n, k, device=device).to(dtype)
m2 = torch.randn(k, m, device=device).to(dtype)
self._test_addmm_addmv(torch.addmm, M, m1, m2)
m1 = torch.randn(n, k + 1, device=device).to(dtype)
m2 = torch.randn(k, m, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))
def test_addmm_baddbmm_overflow(self, device, dtype):
orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
inp = torch.zeros(128, 128, dtype=torch.half, device=device)
mat1 = torch.ones(128, 1000, dtype=torch.half, device=device) * 100
mat2 = torch.ones(1000, 128, dtype=torch.half, device=device) * 100
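# Each output element is 100 * 100 summed over k=1000 and scaled by alpha=0.001, i.e. exactly
# 10000; the intermediate sum would overflow fp16 unless the reduction is accumulated in fp32.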
out = torch.addmm(inp, mat1, mat2, alpha=0.001, beta=0.)
self.assertFalse(out.isinf().any())
self.assertTrue((out == 10000.).all())
inp = torch.zeros(3, 128, 128, dtype=torch.half, device=device)
mat1 = torch.ones(3, 128, 1000, dtype=torch.half, device=device) * 100
mat2 = torch.ones(3, 1000, 128, dtype=torch.half, device=device) * 100
out = torch.baddbmm(inp, mat1, mat2, alpha=0.001, beta=0.)
self.assertFalse(out.isinf().any())
self.assertTrue((out == 10000.).all())
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
@dtypes(torch.float)
def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
for shape in [[3, 2, 2], [2, 20, 20]]:
mat1, mat2 = (torch.randn(shape, dtype=dtype, device=device) for _ in range(2))
inputs = [torch.randn(shape, dtype=dtype, device=device),
torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
outs = [None, torch.randn(shape, dtype=dtype, device=device),
torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan)]
options = itertools.product(inputs, outs)
for input, out in options:
y_ref = torch.bmm(mat1, mat2)
y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
self.assertEqual(y_ref, y)
@dtypes(torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64)
def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
if dtype != torch.float32:
with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
y_ref = torch.bmm(batch1, batch2)
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
self.assertEqual(out, y_ref)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
def test_matmul_45724(self, device):
a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
torch.matmul(a, b, out=c)
self.assertEqual(c, cpu_result)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@parametrize("k", [16, 32])
@parametrize("n", [16, 32])
@parametrize("use_transpose_a", [True, False])
@parametrize("use_transpose_b", [True, False])
def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
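# torch._int_mm multiplies int8 matrices with int32 accumulation; the float reference
# below uses the same integer values, so results should match exactly where the op is supported.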
def genf_int_float(x, y, use_transpose):
x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
x_float = x_int8.to(torch.float32)
return x_int8.t(), x_float.t()
return x_int8, x_float
def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
a_int8, a_float = genf_int_float(m, k, transpose_a)
b_int8, b_float = genf_int_float(k, n, transpose_b)
c_int32 = torch._int_mm(a_int8, b_int8)
self.assertTrue(c_int32.dtype is torch.int32)
self.assertEqual(c_int32.device, torch.device(device))
self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
self.assertNotEqual(c_int32.float(), torch.mm(a_float, b_float))
c_int32_result = c_int32.new_empty(c_int32.size())
torch._int_mm(a_int8, b_int8, out=c_int32_result)
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
self.assertNotEqual(c_int32_result.float(), torch.mm(a_float, b_float))
version = _get_torch_cuda_version()
SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
_test(17, k, n, use_transpose_a, use_transpose_b, True)
elif version >= (11, 7):
if not use_transpose_a and use_transpose_b:
if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
_test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)
if use_transpose_a and not use_transpose_b:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)
if use_transpose_a and use_transpose_b:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)
if not use_transpose_a and not use_transpose_b:
if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
_test(17, k, n, use_transpose_a, use_transpose_b)
with self.assertRaisesRegex(RuntimeError,
"CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul"):
_test(17, k, n, use_transpose_a, use_transpose_b)
with self.assertRaisesRegex(RuntimeError, "_int_mm_out_cuda not compiled for CUDA"):
_test(17, k, n, use_transpose_a, use_transpose_b, False)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
def test__int_mm_errors(self, device):
self.skipTest("_int_mm not compiled for ROCM")
version = _get_torch_cuda_version()
if version < (11, 7):
self.skipTest("_int_mm only compiled for CUDA 11.7")
return torch.empty((x, y), dtype=torch.int8, device=device)
def _gen_pair(m, k, n):
return genf_int(m, k), genf_int(k, n)
self.assertRaisesRegex(RuntimeError,
r"self.size\(0\) needs to be greater than 16, but got 16",
lambda: torch._int_mm(*_gen_pair(16, 8, 32)))
self.assertRaisesRegex(RuntimeError,
r"self.size\(1\) needs to be greater than 0 and a multiple of 8, but got 7",
lambda: torch._int_mm(*_gen_pair(17, 7, 32)))
self.assertRaisesRegex(RuntimeError,
r"self.size\(1\) needs to match mat2.size\(0\) but got 8 and 7",
lambda: torch._int_mm(genf_int(17, 8), genf_int(7, 32)))
self.assertRaisesRegex(RuntimeError,
r"mat2.size\(1\) needs to be greater than 0 and a multiple of 8, but got 31",
lambda: torch._int_mm(*_gen_pair(17, 8, 31)))
self.assertRaisesRegex(RuntimeError,
r"expected scalar type Char but found Float",
lambda: torch._int_mm(genf_int(17, 8).float(), genf_int(8, 32)))
self.assertRaisesRegex(RuntimeError,
r"expected scalar type Char but found Float",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32).float()))
self.assertRaisesRegex(RuntimeError,
r"Expected result dtype to be of type kInt but got float",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 32).float()))
self.assertRaisesRegex(RuntimeError,
r"Expected result.size\(0\) to be 17 but got 15",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(15, 32).int()))
self.assertRaisesRegex(RuntimeError,
r"Expected result.size\(0\) to be 17 but got 16",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
@parametrize("m", [0, 8, 17])
@parametrize("k", [0, 16, 32])
@parametrize("n", [16, 32])
@parametrize("use_transpose_a", [True, False])
@parametrize("use_transpose_b", [True, False])
@parametrize("non_contig_type", [0, 1, 2])
def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
def genf_int_float(x, y, use_transpose, non_contig_type):
if non_contig_type != 0:
x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
x_float = x_int8.to(torch.float32)
if non_contig_type == 1:
x_int8 = x_int8[:, : y // 2]
x_float = x_float[:, : y // 2]
elif non_contig_type == 2:
x_int8 = x_int8[:, ::2]
x_float = x_float[:, ::2]
return x_int8.t(), x_float.t()
return x_int8, x_float
if non_contig_type != 0 and (m == 0 or k == 0):
a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
c_int32 = torch._int_mm(a_int8, b_int8)
self.assertTrue(c_int32.dtype is torch.int32)
self.assertEqual(c_int32.device, torch.device(device))
self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
c_int32_result = c_int32.new_empty(c_int32.size())
torch._int_mm(a_int8, b_int8, out=c_int32_result)
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
def test__convert_weight_to_int4pack(self, device):
test_list = [((64, 32), 2), ((64, 48), 2), ((64, 64), 2), ((256, 128), 4), ((256, 128), 8)]
if self.device_type == 'cuda' and not SM80OrLater:
self.skipTest("requires SM80 or later")
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
torch.manual_seed(1)
for shape, innerKTiles in test_list:
b = torch.rand(shape, dtype=torch.bfloat16, device=device)
b_uint8, _ = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
b_int4pack = torch._convert_weight_to_int4pack(b_uint8, innerKTiles=innerKTiles)
b_int4pack_meta = torch._convert_weight_to_int4pack(b_uint8.to(device="meta"), innerKTiles=innerKTiles)
self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test__int4_mm(self, device, m, k, n):
if self.device_type == 'cuda' and not SM80OrLater:
self.skipTest("requires SM80 or later")
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
torch.manual_seed(1)
a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
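# The weight matrix is group-quantized to 4 bits (one scale/zero-point per q_group columns),
# packed with _convert_weight_to_int4pack, and the packed matmul is compared against a plain mm.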
def convert_weight_to_int4pack(b):
b_uint8, b_scales_and_zeros = _group_quantize_tensor(
b, n_bit=4, q_group_size=q_group
b_int4pack = torch._convert_weight_to_int4pack(
b_uint8, inner_k_tiles
return b_int4pack, b_scales_and_zeros
def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
return torch._weight_int4pack_mm(
a, b_int4pack, q_group, b_scales_and_zeros
b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else []):
a = a_bf16.to(dtype=dtype)
b = b_bf16.to(dtype=dtype)
b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
ref = torch.mm(a, b)
res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test_compile_int4_mm(self, device, m, k, n):
if self.device_type == 'cuda' and not SM80OrLater:
self.skipTest("requires SM80 or later")
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
b_int32, b_scales_and_zeros = _group_quantize_tensor(
b, n_bit=4, q_group_size=q_group
def int4_mm(a, b_int32, b_scales_and_zeros):
b_int4pack = torch._convert_weight_to_int4pack(
b_int32, inner_k_tiles
return torch._weight_int4pack_mm(
a, b_int4pack, q_group, b_scales_and_zeros
res = int4_mm(a, b_int32, b_scales_and_zeros)
ref = torch.mm(a, b)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test__int8_mm(self, device, m, k, n):
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
def convert_weight_to_int8pack(b):
b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
b, -128, 127, torch.int8
return b_int8pack, b_scales
def weight_int8pack_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
b_int8pack, b_scales = convert_weight_to_int8pack(b)
res = weight_int8pack_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@parametrize("n", [48, 64])
def test_compile_int8_mm(self, device, m, k, n):
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
b = torch.rand((n, k), dtype=torch.bfloat16, device=device)
b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
b, -128, 127, torch.int8
def int8_mm(a, b_int8pack, b_scales):
return torch._weight_int8pack_mm(
a, b_int8pack, b_scales
res = int8_mm(a, b_int8pack, b_scales)
ref = torch.mm(a, b.transpose(0, 1))
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@parametrize("m", [32, 35, 36, 40, 64])
@parametrize("k", [32, 35, 36, 40, 64])
def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k):
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.half, device=device)
b = torch.rand((1, k), dtype=torch.half, device=device)
prev = torch._C._get_cpu_allow_fp16_reduced_precision_reduction()
torch._C._set_cpu_allow_fp16_reduced_precision_reduction(False)
ref = torch.mm(a, b.t())
torch._C._set_cpu_allow_fp16_reduced_precision_reduction(True)
except RuntimeError as e:
raise unittest.SkipTest from e
res = torch.mm(a, b.t())
torch.testing.assert_close(res, ref, atol=1e-2, rtol=1e-2)
torch._C._set_cpu_allow_fp16_reduced_precision_reduction(prev)
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
@dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble)
@tf32_on_and_off(0.01)
@bf32_on_and_off(0.01)
def test_mm(self, device, dtype):
def _test_mm(n, m, p, dtype, genf):
def matrixmultiply(mat1, mat2):
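# Naive triple-loop reference matmul; half inputs are accumulated in float32
# and the result is cast back to half before comparison.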
dtype_ = torch.float if dtype == torch.half else dtype
if dtype == torch.half:
res = torch.zeros(n, p, dtype=dtype_, device=device)
for i, j in iter_indices(res):
res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
return res.half() if dtype == torch.half else res
res = torch.mm(mat1, mat2)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
mat2 = genf(p, m).t()
res = torch.mm(mat1, mat2)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
mat1 = genf(m, n).t()
res = torch.mm(mat1, mat2)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
mat1 = genf(m, n).t()
mat2 = genf(p, m).t()
res = torch.mm(mat1, mat2)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
mat2 = genf(m, 1).expand(m, p)
res = torch.mm(mat1, mat2)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
torch.mm(mat1, mat2, out=res)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
mat1 = genf(m, n).t()
mat2 = genf(p, m).t()
torch.mm(mat1, mat2, out=res)
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
def genf_bfloat(x, y):
return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
def genf_float(x, y):
return torch.randn(x, y, dtype=dtype, device=device)
def genf_Half(x, y):
return torch.randn(x, y, dtype=dtype, device=device)
for (n, m, p) in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
if (dtype == torch.int32) or (dtype == torch.int64):
elif (dtype == torch.bfloat16):
elif (dtype == torch.half):
_test_mm(n, m, p, dtype, genf)
@onlyNativeDeviceTypes
def test_mm_bmm_non_memory_dense(self, device):
def _slice(tensor, fn):
return fn(tensor)[..., ::2]
A = torch.randn(3, 6, dtype=torch.cfloat, device=device)
B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
out = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
out1 = torch.empty(3, 3, device=device, dtype=torch.complex64).t()
A_conj = _slice(A, torch.conj)
A_conj_physical = _slice(A, torch.conj_physical)
self.assertEqual(torch.mm(A_conj, B, out=out), torch.mm(A_conj_physical, B, out=out))
self.assertEqual(torch.mm(A_conj.t(), B, out=out), torch.mm(A_conj_physical.t(), B, out=out))
Ab = torch.randn(2, 3, 6, dtype=torch.cfloat, device=device)
Bb = torch.randn(2, 3, 3, dtype=torch.cfloat, device=device)
Bb_ = torch.randn(1, 3, 3, dtype=torch.cfloat, device=device).expand(2, 3, 3)
out_b = torch.empty(2, 3, 3, device=device, dtype=torch.complex64).mT
Ab_conj = _slice(Ab, torch.conj)
Ab_conj_physical = _slice(Ab, torch.conj_physical)
self.assertEqual(torch.bmm(Ab_conj, Bb, out=out_b), torch.bmm(Ab_conj_physical, Bb, out=out_b))
self.assertEqual(torch.bmm(t_b(Ab_conj), Bb, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb, out=out_b))
self.assertEqual(torch.bmm(Ab_conj, Bb_, out=out_b), torch.bmm(Ab_conj_physical, Bb_, out=out_b))
self.assertEqual(torch.bmm(t_b(Ab_conj), Bb_, out=out_b), torch.bmm(t_b(Ab_conj_physical), Bb_, out=out_b))
@onlyNativeDeviceTypes
def test_mm_conjtranspose(self, device):
A = torch.randn(3, 3, dtype=torch.cfloat, device=device)
B = torch.randn(3, 3, dtype=torch.cfloat, device=device)
out1 = torch.mm(A.t().conj(), B)
out1_ref = torch.mm(A.t().conj_physical(), B)
self.assertEqual(out1, out1_ref)
out1 = torch.mm(A, B.t().conj())
out1_ref = torch.mm(A, B.t().conj_physical())
self.assertEqual(out1, out1_ref)
out1 = torch.mm(A.t().conj(), B.t().conj())
out1_ref = torch.mm(A.t().conj_physical(), B.t().conj_physical())
self.assertEqual(out1, out1_ref)
@onlyNativeDeviceTypes
def test_mm_empty_inputs_mixed_dtype_errors(self, device):
a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
b = torch.randn(10, 20, dtype=torch.float32, device=device)
with self.assertRaisesRegex(RuntimeError, "expected .* and .* to have the same dtype, but got:"):
@onlyNativeDeviceTypes
@dtypes(torch.float32, torch.float64)
def test_strided_mm_bmm(self, device, dtype):
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
new_shape = [2, 2, 2]
new_stride = [3, 1, 1]
sx = torch.as_strided(x, size=new_shape, stride=new_stride)
torch_fn = lambda x: torch.bmm(x, x)
np_fn = lambda x: np.matmul(x, x)
self.compare_with_numpy(torch_fn, np_fn, sx)
torch_fn = lambda x: torch.mm(x, x)
self.compare_with_numpy(torch_fn, np_fn, sx[0])
@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_bmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
batch_sizes = [1, 10]
M, N, O = 23, 15, 12
numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
if dtype == torch.bfloat16 and self.device_type == 'cuda':
is_supported = TEST_WITH_ROCM or SM53OrLater
if not is_supported:
for num_batches in batch_sizes:
b1 = torch.randn(num_batches, M, N, device=device).to(dtype)
b2 = torch.randn(num_batches, N, O, device=device).to(dtype)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.bmm(b1, b2))
d = {x: i for i, x in enumerate(p)}
return (d[0], d[1], d[2])
def generate_inputs(num_batches):
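# Yields three families of inputs: tensors permuted to a non-default memory layout
# (then restored), broadcast (expanded) tensors, and tensors with zero-sized dimensions.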
for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1)
b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1)
b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
b1 = make_tensor(shape1, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, M, N)
b2 = make_tensor(shape2, dtype=dtype, device=device, low=-0.1, high=0.1).expand(num_batches, N, O)
for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
b1 = torch.randn(shape1, dtype=dtype, device=device)
b2 = torch.randn(shape2, dtype=dtype, device=device)
for num_batches in batch_sizes:
for (b1, b2), perm3 in itertools.product(generate_inputs(num_batches), itertools.permutations((0, 1, 2))):
res1 = torch.bmm(b1, b2)
res2 = torch.full((num_batches, M, O), math.nan, dtype=dtype, device=device) \
.permute(perm3).contiguous().permute(invert_perm(perm3))
torch.bmm(b1, b2, out=res2)
expect = torch.from_numpy(
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
self.assertEqual(expect, res1)
self.assertEqual(expect, res2)
if self.device_type == 'cuda':
self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()))
def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
getattr(out_tensor, func + "_")(b1, b2)
self.assertEqual(out_tensor, ref)
res3 = out_tensor.clone()
with self.assertWarnsOnceRegex(
UserWarning, f"This overload of {func}_ is deprecated"):
getattr(out_tensor, func + "_")(1, b1, b2)
self.assertEqual(out_tensor, ref * 2)
getattr(res3, func + "_")(b1, b2, beta=1)
self.assertEqual(out_tensor, res3)
with self.assertWarnsOnceRegex(
UserWarning, f"This overload of {func}_ is deprecated"):
getattr(out_tensor, func + "_")(1., .5, b1, b2)
self.assertEqual(out_tensor, ref * 2.5)
getattr(res3, func + "_")(b1, b2, beta=1., alpha=.5)
self.assertEqual(out_tensor, res3)
with self.assertWarnsOnceRegex(
UserWarning, f"This overload of {func} is deprecated"):
self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=.5)
self.assertEqual(res4, ref * 3)
nan = torch.full_like(out_tensor, math.nan)
res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
self.assertEqual(res5, ref)
res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1j, alpha=.5j)
self.assertEqual(res6, out_tensor * .1j + .5j * ref)
res6 = getattr(torch, func)(out_tensor, b1, b2, beta=.1, alpha=.5)
self.assertEqual(res6, out_tensor * .1 + .5 * ref)
res7 = torch.full_like(out_tensor, math.nan)
getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
self.assertEqual(res7, ref)
@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_addbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
M, N, O = 16, 17, 18
if dtype == torch.bfloat16:
if self.device_type == 'cpu':
is_supported = TEST_WITH_ROCM or SM53OrLater
if not is_supported:
b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.addbmm(t, b1, b2))
d = {x: i for i, x in enumerate(p)}
return (d[0], d[1], d[2])
def generate_tensor():
numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
for perm1, perm2 in itertools.product(itertools.permutations((0, 1, 2)), repeat=2):
for perm3 in itertools.permutations((0, 1)):
b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1) * 0.1
b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1) * 0.1
b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
ref = torch.from_numpy(
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
).to(device=device, dtype=dtype).sum(0)
out_tensor = torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
yield b1, b2, ref, out_tensor
6851
for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
6852
shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
6853
shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
6854
b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N) * 0.1
6855
b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O) * 0.1
6856
ref = torch.from_numpy(
6857
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6858
).to(device=device, dtype=dtype).sum(0)
6859
out_tensor = torch.zeros_like(ref)
6860
yield b1, b2, ref, out_tensor
6862
for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
6863
shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
6864
shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
6865
b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) * 0.1
6866
b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) * 0.1
6867
ref = torch.from_numpy(
6868
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
6869
).to(device=device, dtype=dtype).sum(0)
6870
out_tensor = torch.zeros_like(ref)
6871
yield b1, b2, ref, out_tensor
6873
for b1, b2, ref, out_tensor in generate_tensor():
6874
self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
@precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
@onlyNativeDeviceTypes
@dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half))
@tf32_on_and_off(0.05)
@bf32_on_and_off(0.05)
def test_baddbmm(self, device, dtype):
if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater:
if dtype == torch.bfloat16 and self.device_type == 'cuda':
is_supported = TEST_WITH_ROCM or SM53OrLater
if not is_supported:
b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
t = make_tensor((num_batches, M, O), dtype=dtype, device=device, low=-1, high=1)
self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
lambda: torch.baddbmm(t, b1, b2))
d = {x: i for i, x in enumerate(p)}
return (d[0], d[1], d[2])
def generate_tensor():
numpy_dtype = dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
for perm1, perm2, perm3 in itertools.product(itertools.permutations((0, 1, 2)), repeat=3):
b1 = make_tensor((num_batches, M, N), dtype=dtype, device=device, low=-1, high=1)
b2 = make_tensor((num_batches, N, O), dtype=dtype, device=device, low=-1, high=1)
b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
ref = torch.from_numpy(
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
out_tensor = torch.zeros_like(ref)
out_tensor = out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
yield b1, b2, ref, out_tensor
for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
b1 = make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, M, N)
b2 = make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1).expand(num_batches, N, O)
ref = torch.from_numpy(
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
out_tensor = torch.zeros_like(ref)
yield b1, b2, ref, out_tensor
for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
ref = torch.from_numpy(
b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()).to(device=device, dtype=dtype)
out_tensor = torch.zeros_like(ref)
yield b1, b2, ref, out_tensor
for b1, b2, ref, out_tensor in generate_tensor():
self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
@precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
@dtypes(*floating_and_complex_types())
def test_pinverse(self, device, dtype):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
MPI = torch.pinverse(M)
MPI_ = MPI.cpu().numpy()
M_ = M.cpu().numpy()
self.assertEqual(M_, np.matmul(np.matmul(M_, MPI_), M_))
self.assertEqual(MPI_, np.matmul(np.matmul(MPI_, M_), MPI_))
self.assertEqual(np.matmul(M_, MPI_), np.matmul(M_, MPI_).swapaxes(-2, -1).conj())
self.assertEqual(np.matmul(MPI_, M_), np.matmul(MPI_, M_).swapaxes(-2, -1).conj())
self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2]))
for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5),
(3, 2), (5, 3, 2), (7, 5, 3, 2),
(2, 3), (5, 2, 3), (7, 5, 2, 3),
(0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:
M = torch.randn(*sizes, dtype=dtype, device=device)
for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]:
batchdims = sizes[:-2]
M = make_arg(*batchdims, matsize, matsize)
self.assertEqual(torch.eye(matsize, dtype=dtype, device=device).expand(sizes), M.pinverse().matmul(M),
atol=1e-7, rtol=0, msg='pseudo-inverse for invertible matrix')
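# NOTE: the four assertions in run_test above are the Moore-Penrose conditions defining the
# pseudoinverse M+: M @ M+ @ M == M, M+ @ M @ M+ == M+, and both M @ M+ and M+ @ M Hermitian.
# For the invertible inputs generated by make_arg, M+ coincides with the inverse, which is what
# the final comparison against the identity checks.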
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(torch.double, torch.cdouble)
def test_matrix_power_non_negative(self, device, dtype):
t = make_tensor(size, dtype=dtype, device=device)
res = torch.linalg.matrix_power(t, n)
ref = np.linalg.matrix_power(t.cpu().numpy(), n)
self.assertEqual(res.cpu(), torch.from_numpy(ref))
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(torch.double, torch.cdouble)
def test_matrix_power_negative(self, device, dtype):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_arg = partial(make_fullrank, device=device, dtype=dtype)
for n in range(-7, 0):
res = torch.linalg.matrix_power(t, n)
ref = np.linalg.matrix_power(t.cpu().numpy(), n)
self.assertEqual(res.cpu(), torch.from_numpy(ref))
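# NOTE: a negative exponent is defined via the inverse, matrix_power(A, -n) == matrix_power(inv(A), n),
# so this test builds full-rank (hence invertible) inputs with
# make_fullrank_matrices_with_distinct_singular_values, whereas the non-negative test above can
# use arbitrary square matrices.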
@dtypes(torch.float, torch.complex64)
def test_linalg_matrix_exp_utils(self, device, dtype):
def run_test(coeff_shape, data_shape):
coeffs = torch.rand(*coeff_shape, device=device, dtype=torch.float)
x = torch.rand(coeff_shape[1], *data_shape, device=device, dtype=dtype)
res1 = torch._compute_linear_combination(x, coeffs)
res2 = (x.unsqueeze(0) * coeffs.view(*coeff_shape, *([1] * len(data_shape)))).sum(1)
self.assertEqual(res1, res2, atol=1e-5, rtol=0.0)
res3 = torch.zeros(coeff_shape[0], *data_shape, device=device, dtype=dtype)
torch._compute_linear_combination(x, coeffs, out=res3)
self.assertEqual(res1, res3, atol=1e-5, rtol=0.0)
res4 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
torch._compute_linear_combination(x, coeffs, out=res4)
self.assertEqual(res1, res4 - 1.0, atol=1e-5, rtol=0.0)
res5 = torch.ones(coeff_shape[0], *data_shape, device=device, dtype=dtype)
res5_clone = res5.clone()
torch._compute_linear_combination(x, coeffs, out=res5)
self.assertEqual(res1, res5 - res5_clone, atol=1e-5, rtol=0.0)
run_test([1, 3], [2, 2])
run_test([3, 1], [2, 2])
run_test([1, 10], [10, 10])
run_test([10, 1], [10, 10])
run_test([5, 3], [2, 2])
run_test([5, 3], [100, 100])
run_test([3, 4], [3, 3, 3])
run_test([3, 4], [3, 3, 3, 3])
with self.assertRaises(RuntimeError):
x = torch.rand([], device=device, dtype=dtype)
coeffs = torch.rand([2, 2], device=device, dtype=dtype)
res = torch._compute_linear_combination(x, coeffs)
@dtypes(torch.complex64)
def test_linalg_matrix_exp_no_warnings(self, device, dtype):
with freeze_rng_state():
torch.manual_seed(42)
tens = 0.5 * torch.randn(10, 3, 3, dtype=dtype, device=device)
tens = (0.5 * (tens.transpose(-1, -2) + tens))
with warnings.catch_warnings(record=True) as w:
tens.imag = torch.matrix_exp(tens.imag)
self.assertFalse(len(w))
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_linalg_matrix_exp_boundary_cases(self, device, dtype):
expm = torch.linalg.matrix_exp
with self.assertRaisesRegex(RuntimeError, "Expected a floating point or complex tensor"):
expm(torch.randn(3, 3).type(torch.int))
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
expm(torch.randn(3))
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
expm(torch.randn(3, 2, 1))
x = torch.randn(3, 3, 1, 1)
self.assertEqual(expm(x), x.exp())
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_linalg_matrix_exp_perverse_nan_values(self, device, dtype):
expm = torch.linalg.matrix_exp
x[0, 0, 0] = torch.nan
x = with_nan(torch.randn(1, 1, 1))
self.assertTrue(torch.isnan(expm(x)).any())
x = with_nan(torch.randn(1, 2, 2))
for v in [1, 2, 3, 4, 5, 6, 7, 8, 9, 100, 1000]:
self.assertTrue(torch.isnan(expm(x / v)).any())
x = with_nan(torch.randn(2, 2, 2))
self.assertTrue(torch.isnan(expm(x)).any())
x = with_nan(torch.randn(4096, 2, 2))
self.assertTrue(torch.isnan(expm(x)).any())
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_linalg_matrix_exp_analytic(self, device, dtype):
expm = torch.linalg.matrix_exp
x = torch.zeros(20, 20, dtype=dtype, device=device)
self.assertTrue((expm(x) == torch.eye(20, 20, dtype=dtype, device=device)).all().item())
def normalize_to_1_operator_norm(sample, desired_norm):
sample_norm, _ = sample.abs().sum(-2).max(-1)
sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
return sample_to_1_norm * desired_norm
def gen_good_cond_number_matrices(*n):
Generates a diagonally-dominant matrix
with the eigenvalues centered at 1
and the radii at most (n[-1] - 1) / (n[-2] ** 2)
identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
x = (x - x * identity) + identity
if dtype == torch.float:
1.192092800768788e-07,
5.978858893805233e-04,
5.116619363445086e-02,
5.800524627688768e-01,
1.461661507209034e+00,
3.010066362817634e+00
2.220446049250313e-16,
2.580956802971767e-08,
3.397168839976962e-04,
4.991228871115323e-02,
2.996158913811580e-01,
1.090863719290036e+00
q = gen_good_cond_number_matrices(*n)
q_ = q.cpu().numpy()
qinv = torch.inverse(q)
qinv_ = qinv.cpu().numpy()
d = torch.randn(n[:-1], dtype=dtype, device=device)
x = torch.from_numpy(
np.matmul(q_, np.matmul(torch.diag_embed(d).cpu().numpy(), qinv_))).to(device)
x_norm, _ = x.abs().sum(-2).max(-1)
mexp_analytic = np.matmul(
torch.diag_embed(d.exp()).cpu().numpy(),
self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
for i in range(len(thetas) - 1):
sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
for sample_norm in sample_norms:
x_normalized = normalize_to_1_operator_norm(x, sample_norm)
mexp = expm(x_normalized)
mexp_analytic = np.matmul(
torch.diag_embed((d / x_norm.unsqueeze(-1) * sample_norm).exp()).cpu().numpy(),
self.assertEqual(mexp, mexp_analytic, atol=1e-3, rtol=0.0)
run_test(3, 100, 100)
run_test(3, 200, 200)
run_test(3, 3, 2, 2)
run_test(3, 3, 3, 3)
run_test(3, 3, 4, 4)
run_test(3, 3, 5, 5)
run_test(3, 3, 100, 100)
run_test(3, 3, 200, 200)
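# NOTE: run_test above constructs a diagonalizable x = q @ diag(d) @ inv(q) and uses the identity
# matrix_exp(q diag(d) inv(q)) == q diag(exp(d)) inv(q) as the analytic reference. The hard-coded
# theta constants appear to be the per-degree scaling thresholds of the underlying matrix-exp
# algorithm (one value per approximation degree), so the sampled norms straddle each threshold;
# this reading is an inference from the values, not something the test itself states.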
@dtypes(torch.float, torch.double)
def test_linalg_matrix_exp_batch(self, device, dtype):
tensors_batch = torch.zeros(n, dtype=dtype, device=device)
tensors_batch = tensors_batch.view(-1, n[-2], n[-1])
num_matrices = tensors_batch.size(0)
for i in range(num_matrices):
tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device))
for i in range(num_matrices):
tensors_batch[i, ...] = tensors_list[i]
tensors_exp_map = (torch.linalg.matrix_exp(x) for x in tensors_list)
tensors_exp_batch = torch.linalg.matrix_exp(tensors_batch)
for i, tensor_exp in enumerate(tensors_exp_map):
self.assertEqual(tensors_exp_batch[i, ...], tensor_exp)
run_test(3, 3, 2, 2)
run_test(3, 3, 3, 3)
run_test(3, 3, 4, 4)
run_test(3, 3, 5, 5)
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_linalg_matrix_exp_compare_with_taylor(self, device, dtype):
def normalize_to_1_operator_norm(sample, desired_norm):
sample_norm, _ = sample.abs().sum(-2).max(-1)
sample_to_1_norm = sample / sample_norm.unsqueeze(-1).unsqueeze(-1)
return sample_to_1_norm * desired_norm
def gen_good_cond_number_matrices(*n):
Generates a diagonally-dominant matrix
with the eigenvalues centered at 1
and the radii at most (n[-1] - 1) / (n[-2] ** 2)
identity = torch.eye(n[-2], n[-1], dtype=dtype, device=device).expand(*n)
x = torch.rand(*n, dtype=dtype, device=device) / (n[-1] ** 2)
x = (x - x * identity) + identity
def get_taylor_approximation(a, deg):
a_ = a.cpu().numpy()
identity = torch.eye(a.size(-2), a.size(-1), dtype=dtype, device=device).expand_as(a)
res = identity.cpu().numpy()
taylor_term = identity.cpu().numpy()
for i in range(1, deg + 1):
taylor_term = np.matmul(a_, taylor_term) / i
res = res + taylor_term
def scale_square(a, deg):
if a.abs().pow(2).sum().sqrt() < 1.0:
return get_taylor_approximation(a, 12)
s = int(torch.log2(a.abs().pow(2).sum().sqrt()).ceil().item())
b = get_taylor_approximation(b, 18)
return torch.from_numpy(b).to(a.device)
degs = [1, 2, 4, 8, 12, 18]
if dtype == torch.float:
1.192092800768788e-07,
5.978858893805233e-04,
5.116619363445086e-02,
5.800524627688768e-01,
1.461661507209034e+00,
3.010066362817634e+00
2.220446049250313e-16,
2.580956802971767e-08,
3.397168839976962e-04,
4.991228871115323e-02,
2.996158913811580e-01,
1.090863719290036e+00
for i in range(len(thetas) - 1):
sample_norms.append(0.5 * (thetas[i] + thetas[i + 1]))
sample_norms = [thetas[0] / 2] + sample_norms + [thetas[-1] * 2]
degs = [degs[0]] + degs
for sample_norm, deg in zip(sample_norms, degs):
x = gen_good_cond_number_matrices(*n)
x = normalize_to_1_operator_norm(x, sample_norm)
mexp = torch.linalg.matrix_exp(x)
mexp_taylor = scale_square(x, deg)
self.assertEqual(mexp, mexp_taylor, atol=1e-2, rtol=0.0)
run_test(3, 3, 2, 2)
run_test(3, 3, 3, 3)
run_test(3, 3, 4, 4)
run_test(3, 3, 5, 5)
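# NOTE: scale_square above is a reference scaling-and-squaring scheme based on
# exp(A) == exp(A / 2**s) ** (2**s): get_taylor_approximation evaluates the truncated series
# sum_{k <= deg} A**k / k! on the scaled matrix, and the elided remainder of scale_square
# presumably squares that result s times before comparing against torch.linalg.matrix_exp.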
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_slogdet(self, device, dtype):
from torch.testing._internal.common_utils import (random_hermitian_matrix, random_hermitian_psd_matrix,
random_hermitian_pd_matrix, random_square_matrix_of_rank)
def run_test(matsize, batchdims, mat_chars):
num_matrices = np.prod(batchdims)
list_of_matrices = []
if num_matrices != 0:
for idx in range(num_matrices):
mat_type = idx % len(mat_chars)
if mat_chars[mat_type] == 'hermitian':
list_of_matrices.append(random_hermitian_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'hermitian_psd':
list_of_matrices.append(random_hermitian_psd_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'hermitian_pd':
list_of_matrices.append(random_hermitian_pd_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'singular':
list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'non_singular':
list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
full_tensor = torch.randn(*batchdims, matsize, matsize, dtype=dtype, device=device)
actual_value = torch.linalg.slogdet(full_tensor)
expected_value = np.linalg.slogdet(full_tensor.cpu().numpy())
self.assertEqual(expected_value[0], actual_value[0], atol=self.precision, rtol=self.precision)
self.assertEqual(expected_value[1], actual_value[1], atol=self.precision, rtol=self.precision)
sign_out = torch.empty_like(actual_value[0])
logabsdet_out = torch.empty_like(actual_value[1])
ans = torch.linalg.slogdet(full_tensor, out=(sign_out, logabsdet_out))
self.assertEqual(ans[0], sign_out)
self.assertEqual(ans[1], logabsdet_out)
self.assertEqual(sign_out, actual_value[0])
self.assertEqual(logabsdet_out, actual_value[1])
for matsize, batchdims in itertools.product([0, 3, 5], [(0,), (3,), (5, 3)]):
run_test(matsize, batchdims, mat_chars=['hermitian_pd'])
run_test(matsize, batchdims, mat_chars=['singular'])
run_test(matsize, batchdims, mat_chars=['non_singular'])
run_test(matsize, batchdims, mat_chars=['hermitian', 'hermitian_pd', 'hermitian_psd'])
run_test(matsize, batchdims, mat_chars=['singular', 'non_singular'])
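# NOTE: torch.linalg.slogdet returns a (sign, logabsdet) pair with det == sign * exp(logabsdet);
# for singular inputs the sign is 0 and logabsdet is -inf, which is why the all-ones 'singular'
# matrices are included and why the comparison against np.linalg.slogdet is done component-wise.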
@dtypes(*floating_and_complex_types())
def test_slogdet_errors_and_warnings(self, device, dtype):
a = torch.randn(2, 3, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
torch.linalg.slogdet(a)
a = torch.randn(2, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, r'must have at least 2 dimensions'):
torch.linalg.slogdet(a)
a = torch.randn(2, 2, device=device, dtype=torch.bfloat16)
with self.assertRaisesRegex(RuntimeError, r'Low precision dtypes not supported'):
torch.linalg.slogdet(a)
a = torch.randn(2, 3, 3, device=device, dtype=dtype)
sign_out = torch.empty(1, device=device, dtype=dtype)
real_dtype = a.real.dtype if dtype.is_complex else dtype
logabsdet_out = torch.empty(1, device=device, dtype=real_dtype)
with warnings.catch_warnings(record=True) as w:
torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
if torch.cuda.is_available():
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
sign_out = torch.empty(0, device=wrong_device, dtype=dtype)
logabsdet_out = torch.empty(0, device=wrong_device, dtype=real_dtype)
with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
torch.linalg.slogdet(a, out=(sign_out, logabsdet_out))
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@skipCUDAIfNoCusolver
@dtypes(torch.double)
def test_det_logdet_slogdet(self, device, dtype):
def reference_slogdet(M):
sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
return M.new_tensor(sdet), M.new_tensor(logabsdet)
def test_single_det(M, target, desc):
target_sdet, target_logabsdet = target
sdet, logabsdet = M.slogdet()
linalg_sdet, linalg_logabsdet = torch.linalg.slogdet(M)
self.assertEqual(det, target_sdet * target_logabsdet.exp(),
atol=1e-6, rtol=0, msg=f'{desc} (det)')
self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(),
atol=1e-6, rtol=0, msg=f'{desc} (slogdet)')
self.assertEqual(linalg_sdet * linalg_logabsdet.exp(), target_sdet * target_logabsdet.exp(),
atol=1e-6, rtol=0, msg=f'{desc} (linalg_slogdet)')
self.assertTrue(logdet.item() != logdet.item(), f'{desc} (logdet negative case)')
self.assertEqual(logdet.exp(), target_logabsdet.exp(),
atol=1e-6, rtol=0, msg=f'{desc} (logdet non-negative case)')
eye = torch.eye(5, dtype=dtype, device=device)
test_single_det(eye, (torch.ones((), dtype=dtype, device=device), torch.zeros((), dtype=dtype, device=device)), 'identity')
for n in range(250, 551, 100):
mat = torch.randn(n, n, dtype=dtype, device=device)
q, _ = torch.qr(mat)
ref_det, ref_logabsdet = reference_slogdet(q)
test_single_det(q, (ref_det, ref_logabsdet), 'orthogonal')
assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
if ref_M_logabsdet.exp().item() >= 1e-6:
test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')
for scale in [-2, -0.1, 0, 10]:
target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
M_clone[:, x] *= scale
test_single_det(M_clone, target, 'scale a row')
M_clone[x, :] *= scale
test_single_det(M_clone, target, 'scale a column')
for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
assert x1 != x2, 'x1 and x2 need to be different for this test'
target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
M_clone[:, x2] = M_clone[:, x1]
test_single_det(M_clone, target, 'two rows are same')
M_clone[x2, :] = M_clone[x1, :]
test_single_det(M_clone, target, 'two columns are same')
for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
det_scale = scale1 * scale2 * -1
target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
elif det_scale == 0:
target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
t = M_clone[:, x1] * scale1
M_clone[:, x1] += M_clone[:, x2] * scale2
test_single_det(M_clone, target, 'exchanging rows')
t = M_clone[x1, :] * scale1
M_clone[x1, :] += M_clone[x2, :] * scale2
test_single_det(M_clone, target, 'exchanging columns')
def get_random_mat_scale(n):
return math.factorial(n - 1) ** (-1.0 / (2 * n))
for n in [5, 10, 25]:
scale = get_random_mat_scale(n)
test(torch.randn(n, n, dtype=dtype, device=device) * scale)
r = torch.randn(n, n, dtype=dtype, device=device) * scale
r = torch.randn(n, n, dtype=dtype, device=device) * scale
test(r.mm(r.t()) + torch.eye(n, dtype=dtype, device=device) * 1e-6)
r = torch.randn(n, n, dtype=dtype, device=device) * scale
test((torch.randn(n, n, n + 1, dtype=dtype, device=device) * scale)[:, 2, 1:])
r = torch.randn(n, n, dtype=dtype, device=device) * scale
if reference_slogdet(u)[0] < 0:
if reference_slogdet(v)[0] < 0:
test(u.mm(s.diag()).mm(v))
r = torch.randn(512, 512, dtype=dtype, device=device)
s.fill_(1. / (100 * s.numel()))
test(u.mm(s.diag()).mm(v))
@dtypes(torch.double)
def test_det_logdet_slogdet_batched(self, device, dtype):
from torch.testing._internal.common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix,
random_symmetric_pd_matrix, random_square_matrix_of_rank)
def run_test(matsize, batchdims, mat_chars):
num_matrices = reduce(operator.mul, batchdims, 1)
list_of_matrices = []
for idx in range(num_matrices):
mat_type = idx % len(mat_chars)
if mat_chars[mat_type] == 'sym':
list_of_matrices.append(random_symmetric_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'sym_psd':
list_of_matrices.append(random_symmetric_psd_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'sym_pd':
list_of_matrices.append(random_symmetric_pd_matrix(matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'sing':
list_of_matrices.append(torch.ones(matsize, matsize, dtype=dtype, device=device))
elif mat_chars[mat_type] == 'non_sing':
list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize, dtype=dtype, device=device))
full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize))
full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize)))
for fn in [torch.det, torch.logdet, torch.slogdet, torch.linalg.slogdet]:
actual_value = fn(full_tensor)
for full_idx in itertools.product(*(list(range(x)) for x in batchdims)):
expected_value.append(fn(full_tensor[full_idx]))
if fn == torch.slogdet or fn == torch.linalg.slogdet:
sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims)
expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims)
self.assertEqual(sign_value, actual_value[0])
self.assertEqual(expected_value, actual_value[1])
expected_value = torch.stack(expected_value, dim=0).reshape(batchdims)
self.assertEqual(actual_value, expected_value)
for matsize, batchdims in itertools.product([3, 5], [(3,), (5, 3)]):
run_test(matsize, batchdims, mat_chars=['sym_pd'])
run_test(matsize, batchdims, mat_chars=['sing'])
run_test(matsize, batchdims, mat_chars=['non_sing'])
run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd'])
run_test(matsize, batchdims, mat_chars=['sing', 'non_sing'])
@dtypes(*floating_and_complex_types())
def test_cholesky_inverse(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
def run_test(shape, batch, upper, contiguous):
A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
if A.numel() > 0 and not contiguous:
self.assertFalse(A.is_contiguous())
L = torch.linalg.cholesky(A)
expected_inverse = torch.inverse(A)
L = L.mH if upper else L
actual_inverse = torch.cholesky_inverse(L, upper)
self.assertEqual(actual_inverse, expected_inverse)
batches = ((), (0,), (3, ), (2, 2))
for shape, batch, upper, contiguous in list(itertools.product(shapes, batches, (True, False), (True, False))):
run_test(shape, batch, upper, contiguous)
A = random_hermitian_pd_matrix(3, 2, dtype=dtype, device=device)
L = torch.linalg.cholesky(A)
out = torch.empty_like(A)
out_t = out.mT.clone(memory_format=torch.contiguous_format)
ans = torch.cholesky_inverse(L, out=out)
self.assertEqual(ans, out)
expected = torch.inverse(A)
self.assertEqual(expected, out)
out = torch.empty_like(A)
ans = torch.cholesky_inverse(L, out=out)
self.assertEqual(ans, out)
expected = torch.inverse(A)
self.assertEqual(expected, out)
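# NOTE: torch.cholesky_inverse consumes a Cholesky factor rather than the matrix itself: given a
# lower-triangular L with A == L @ L.mH it returns inv(A), and with upper=True it expects the
# upper factor U == L.mH instead, which is why run_test passes L.mH when upper is set and
# compares against torch.inverse(A).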
@dtypes(*floating_and_complex_types())
def test_cholesky_inverse_errors_and_warnings(self, device, dtype):
a = torch.randn(2, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, "must have at least 2 dimensions"):
torch.cholesky_inverse(a)
a = torch.randn(2, 3, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"):
torch.cholesky_inverse(a)
a = torch.randn(3, 3, device=device, dtype=dtype)
out = torch.empty(2, 3, device=device, dtype=dtype)
with warnings.catch_warnings(record=True) as w:
torch.cholesky_inverse(a, out=out)
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
out = torch.empty(*a.shape, dtype=torch.int, device=device)
with self.assertRaisesRegex(RuntimeError, "but got result with dtype Int"):
torch.cholesky_inverse(a, out=out)
if torch.cuda.is_available():
wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
out = torch.empty(0, device=wrong_device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
torch.cholesky_inverse(a, out=out)
a = torch.randn(3, 3, device=device, dtype=dtype)
if self.device_type == 'cpu':
with self.assertRaisesRegex(torch.linalg.LinAlgError, r"cholesky_inverse: The diagonal element 2 is zero"):
torch.cholesky_inverse(a)
elif self.device_type == 'cuda':
out = torch.cholesky_inverse(a)
self.assertTrue(out.isinf().any() or out.isnan().any())
def _select_broadcastable_dims(self, dims_full=None):
if dims_full is None:
ndims = random.randint(1, 4)
dims_full = [random.randint(1, 8) for _ in range(ndims)]
ndims = len(dims_full)
smaller_ndims = random.randint(1, ndims)
for i in range(ndims - 1, -1, -1):
j = random.randint(1, 3)
dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
dims_large = [dl] + dims_large
if len(dims_small) < smaller_ndims:
dims_small = [ds] + dims_small
return (dims_small, dims_large, dims_full)
def test_broadcast_fused_matmul(self, device):
fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]
batch_dim = random.randint(1, 8)
n_dim = random.randint(1, 8)
m_dim = random.randint(1, 8)
p_dim = random.randint(1, 8)
def dims_full_for_fn():
return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
elif fn == "addbmm":
return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim])
return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
return ([n_dim], [n_dim, m_dim], [m_dim])
return ([n_dim, m_dim], [n_dim], [m_dim])
raise AssertionError("unknown function")
(t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
(t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)
t0_small = torch.randn(*t0_dims_small, device=device).float()
t1 = torch.randn(*t1_dims, device=device).float()
t2 = torch.randn(*t2_dims, device=device).float()
t0_full = t0_small.expand(*t0_dims_full).to(device)
fntorch = getattr(torch, fn)
r0 = fntorch(t0_small, t1, t2)
r1 = fntorch(t0_full, t1, t2)
self.assertEqual(r0, r1)
@tf32_on_and_off(0.001)
@bf32_on_and_off(0.001)
def test_broadcast_batched_matmul(self, device):
n_dim = random.randint(1, 8)
m_dim = random.randint(1, 8)
p_dim = random.randint(1, 8)
full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))]
(batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims)
def verify_batched_matmul(full_lhs, one_dimensional):
if not one_dimensional:
lhs_dims = [n_dim, m_dim]
rhs_dims = [m_dim, p_dim]
result_dims = [n_dim, p_dim]
lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim]
rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim]
result_dims = [n_dim] if full_lhs else [p_dim]
lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim]
rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1]
full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims
dim0_dims = rhs_dims if full_lhs else lhs_dims
small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims)
small = torch.randn(*(small_dims), device=device).float()
dim0 = torch.randn(*(dim0_dims), device=device).float()
full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float()
if not one_dimensional:
(lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,))
(lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,))
def maybe_squeeze_result(l, r, result):
if len(lhs_dims) == 1 and l.dim() != 1:
return result.squeeze(-2)
elif len(rhs_dims) == 1 and r.dim() != 1:
return result.squeeze(-1)
for lhs in lhsTensors:
lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims)))
lhs_expanded_matmul_fn = lhs_expanded.matmul
for rhs in rhsTensors:
rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)).
expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims))))
truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded))
for l in (lhs, lhs_expanded):
for r in (rhs, rhs_expanded):
l_matmul_fn = l.matmul
result = maybe_squeeze_result(l, r, l_matmul_fn(r))
self.assertEqual(truth, result)
torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
self.assertEqual(truth, torch_result)
out = torch.zeros_like(torch_result)
torch.matmul(l, r, out=out)
self.assertEqual(truth, maybe_squeeze_result(l, r, out))
bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
rhs_expanded.contiguous().view(-1, *rhs_mat_dims)))
self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims))
for indices in itertools.product((True, False), repeat=2):
verify_batched_matmul(*indices)
def lu_solve_test_helper(self, A_dims, b_dims, pivot, device, dtype):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
b = torch.randn(*b_dims, dtype=dtype, device=device)
LU_data, LU_pivots, info = torch.linalg.lu_factor_ex(A)
self.assertEqual(info, torch.zeros_like(info))
return b, A, LU_data, LU_pivots
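# NOTE: torch.linalg.lu_factor_ex returns (LU, pivots, info), where LU packs the L and U factors
# of the pivoted factorization and info == 0 indicates success; the helper asserts an all-zero
# info before the factors are handed to torch.lu_solve(b, LU, pivots), which solves A @ x == b
# from that factorization.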
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve(self, device, dtype):
def sub_test(pivot):
for k, n in zip([2, 3, 5], [3, 5, 7]):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper((n, n), (n, k), pivot, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(b, np.matmul(A.cpu(), x.cpu()))
if self.device_type == 'cuda':
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(*floating_and_complex_types())
@precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
torch.float64: 1e-8, torch.complex128: 1e-8})
def test_lu_solve_batched(self, device, dtype):
def sub_test(pivot):
def lu_solve_batch_test_helper(A_dims, b_dims, pivot):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, pivot, device, dtype)
for i in range(b_dims[0]):
x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i]))
x_exp = torch.stack(x_exp_list)
x_act = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(x_exp, x_act)
Ax = np.matmul(A.cpu(), x_act.cpu())
self.assertEqual(b, Ax)
for batchsize in [1, 3, 4]:
lu_solve_batch_test_helper((batchsize, 5, 5), (batchsize, 5, 10), pivot)
b = torch.randn(3, 0, 3, dtype=dtype, device=device)
A = torch.randn(3, 0, 0, dtype=dtype, device=device)
LU_data, LU_pivots = torch.linalg.lu_factor(A)
self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots))
if self.device_type == 'cuda':
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_many_batches(self, device, dtype):
def run_test(A_dims, b_dims):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
Ax = torch.matmul(A, x)
self.assertEqual(Ax, b.expand_as(Ax))
run_test((65536, 5, 5), (65536, 5, 10))
run_test((262144, 5, 5), (262144, 5, 10))
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(*floating_and_complex_types())
def test_lu_solve_batched_broadcasting(self, device, dtype):
make_fullrank = make_fullrank_matrices_with_distinct_singular_values
make_A = partial(make_fullrank, device=device, dtype=dtype)
def run_test(A_dims, b_dims, pivot=True):
A_matrix_size = A_dims[-1]
A_batch_dims = A_dims[:-2]
A = make_A(*A_batch_dims, A_matrix_size, A_matrix_size)
b = make_tensor(b_dims, dtype=dtype, device=device)
x_exp = np.linalg.solve(A.cpu(), b.cpu())
LU_data, LU_pivots = torch.linalg.lu_factor(A)
x = torch.lu_solve(b, LU_data, LU_pivots)
self.assertEqual(x, x_exp)
run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6))
run_test((2, 1, 3, 4, 4), (4, 6))
run_test((4, 4), (2, 1, 3, 4, 2))
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5))
@dtypes(*floating_and_complex_types())
def test_lu_solve_large_matrices(self, device, dtype):
def run_test(A_dims, b_dims):
b, A, LU_data, LU_pivots = self.lu_solve_test_helper(A_dims, b_dims, True, device, dtype)
x = torch.lu_solve(b, LU_data, LU_pivots)
Ax = torch.matmul(A, x)
self.assertEqual(Ax, b.expand_as(Ax))
run_test((1, 1), (1, 1, 1025))
@skipCUDAIfNoCusolver
def test_pca_lowrank(self, device):
from torch.testing._internal.common_utils import random_lowrank_matrix, random_sparse_matrix
dtype = torch.double
def run_subtest(guess_rank, actual_rank, matrix_size, batches, device, pca, **options):
density = options.pop('density', 1)
use_svd_lowrank = options.pop('use_svd_lowrank', False)
if isinstance(matrix_size, int):
rows = columns = matrix_size
rows, columns = matrix_size
a_input = random_lowrank_matrix(actual_rank, rows, columns, *batches, device=device, dtype=dtype)
a_input = random_sparse_matrix(rows, columns, density, device=device, dtype=dtype)
a = a_input.to_dense()
m = a_input.mean(dim=-2, keepdim=True)
u, s, v = pca(a_input, q=guess_rank, M=m, **options)
u, s, v = pca(a_input, q=guess_rank, **options)
self.assertEqual(s.shape[-1], guess_rank)
self.assertEqual(u.shape[-2], rows)
self.assertEqual(u.shape[-1], guess_rank)
self.assertEqual(v.shape[-1], guess_rank)
self.assertEqual(v.shape[-2], columns)
A1 = u.matmul(s.diag_embed()).matmul(v.mT)
ones_m1 = torch.ones(batches + (rows, 1), dtype=a.dtype, device=device)
c = a.sum(axis=-2) / rows
c = c.reshape(batches + (1, columns))
A2 = a - ones_m1.matmul(c)
self.assertEqual(A1, A2)
detect_rank = (s.abs() > 1e-5).sum(axis=-1)
self.assertEqual(actual_rank * torch.ones(batches, device=device, dtype=torch.int64), detect_rank)
S = torch.linalg.svdvals(A2)
self.assertEqual(s[..., :actual_rank], S[..., :actual_rank])
all_batches = [(), (1,), (3,), (2, 3)]
for actual_rank, size, all_batches in [
(2, (17, 4), all_batches),
(2, (100, 4), all_batches),
(6, (100, 40), all_batches),
(12, (1000, 1000), [()]),
for batches in all_batches:
if guess_rank <= min(*size):
run_subtest(guess_rank, actual_rank, size, batches, device, torch.pca_lowrank)
run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.pca_lowrank)
run_subtest(guess_rank, actual_rank, size, batches, device, torch.svd_lowrank, use_svd_lowrank=True)
run_subtest(guess_rank, actual_rank, size[::-1], batches, device, torch.svd_lowrank, use_svd_lowrank=True)
for guess_rank, size in [
(4, (17, 4)), (4, (4, 17)), (16, (17, 17)),
(21, (100, 40)), (20, (40, 100)), (600, (1000, 1000))]:
for density in [0.005, 0.1]:
run_subtest(guess_rank, None, size, (), device, torch.pca_lowrank, density=density)
jitted = torch.jit.script(torch.pca_lowrank)
guess_rank, actual_rank, size, batches = 2, 2, (17, 4), ()
run_subtest(guess_rank, actual_rank, size, batches, device, jitted)
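# NOTE: torch.pca_lowrank(a, q) returns (U, S, V) with U @ diag(S) @ V.mT approximating the
# column-centered input, hence the A1 == A2 check against a - ones @ mean above; the leading
# singular values should also agree with torch.linalg.svdvals of the centered matrix, which is
# what detect_rank and the final S comparison verify.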
@onlyNativeDeviceTypes
@dtypes(torch.float32, torch.float64)
def test_nuclear_norm_out(self, device, dtype):
((25, 25, 25), (2, 0)),
((25, 25, 25), (0, 1)),
for keepdim in [False, True]:
for input_size, dim in test_cases:
msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}'
x = torch.randn(*input_size, device=device, dtype=dtype)
result_out = torch.empty(0, device=device, dtype=dtype)
result = torch.nuclear_norm(x, keepdim=keepdim)
torch.nuclear_norm(x, keepdim=keepdim, out=result_out)
result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim)
torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out)
self.assertEqual(result, result_out, msg=msg)
@skipCUDAIfNoMagmaAndNoCusolver
@dtypes(*floating_and_complex_types())
def test_geqrf(self, device, dtype):
def run_test(shape):
A = make_tensor(shape, dtype=dtype, device=device)
tau_size = "n" if m > n else "m"
np_dtype = A.cpu().numpy().dtype
ot = [np_dtype, np_dtype]
numpy_geqrf_batched = np.vectorize(
lambda x: np.linalg.qr(x, mode='raw'),
signature=f'(m,n)->(n,m),({tau_size})')
expected = numpy_geqrf_batched(A.cpu())
actual = torch.geqrf(A)
self.assertEqual(expected[0].swapaxes(-2, -1), actual[0])
self.assertEqual(expected[1], actual[1])
batches = [(), (0, ), (2, ), (2, 1)]
for batch, (m, n) in product(batches, product(ns, ns)):
run_test((*batch, m, n))
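# NOTE: torch.geqrf returns the QR factorization in LAPACK's packed form: a single tensor with R
# in the upper triangle and the Householder reflectors below the diagonal, plus the tau scalars.
# np.linalg.qr(..., mode='raw') produces the same data but transposed, hence the
# swapaxes(-2, -1) before the comparison.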
def test_lapack_empty(self, device):
def fn(torchfn, *args):
return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape)
self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape)
self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape)
self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape)
self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0)))
self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0)))
self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
fn(torch.slogdet, (0, 0)))
@tf32_on_and_off(0.005)
@bf32_on_and_off(0.005)
def test_tensordot(self, device):
a = torch.arange(60., device=device).reshape(3, 4, 5)
b = torch.arange(24., device=device).reshape(4, 3, 2)
c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
axes=([1, 0], [0, 1])))
self.assertEqual(c, cn)
cout = torch.zeros((5, 2), device=device)
torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
self.assertEqual(c, cout)
a = torch.randn(2, 3, 4, 5, device=device)
b = torch.randn(4, 5, 6, 7, device=device)
c = torch.tensordot(a, b, dims=2).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
torch.tensordot(a, b, dims=-1)
self.assertEqual(c, cn)
c = torch.tensordot(a, b).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
self.assertEqual(c, cn)
a = torch.tensordot(torch.tensor(0.), torch.tensor(0.), 0)
an = torch.from_numpy(np.tensordot(np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0))
self.assertEqual(a, an)
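# NOTE: torch.tensordot(a, b, dims=([1, 0], [0, 1])) sums products over the paired dimensions
# (a's dims 1 and 0 against b's dims 0 and 1), so the (3, 4, 5) x (4, 3, 2) case above contracts
# to (5, 2); an integer argument dims=k contracts a's last k dims with b's first k, and the
# default is dims=2, matching np.tensordot's conventions.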
@skipCUDAIfNoCusolver
@skipIfTorchDynamo("flaky, needs investigation")
@dtypes(*floating_and_complex_types())
def test_ldl_factor(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
def run_test(shape, batch, hermitian):
A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
actual_factors, actual_pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
actual_L = torch.tril(actual_factors, diagonal=-1)
actual_L.diagonal(0, -2, -1).fill_(1.0)
self.assertTrue((actual_pivots > 0).all())
actual_D = torch.diag_embed(actual_factors.diagonal(0, -2, -1))
return x.mH if hermitian else x.mT
A_reconstructed = actual_L @ actual_D @ T(actual_L)
return A.tril() + A.tril(-1).mT
self.assertEqual(symmetric(A) if not hermitian else A, A_reconstructed)
from scipy.linalg import ldl as scipy_ldl
A_np = A.cpu().numpy()
np_dtype = A_np.dtype
scipy_ldl_batched = np.vectorize(
lambda x: scipy_ldl(x, hermitian=hermitian, lower=True),
otypes=[np_dtype, np_dtype, np.dtype('int64')],
signature='(m,m)->(m,m),(m,m),(m)')
expected = scipy_ldl_batched(A_np)
expected_L, expected_D, expected_pivots = expected
if expected_pivots.ndim > 1:
permuted_expected_L = np.stack(
[expected_L[i][expected_pivots[i], :] for i in range(expected_pivots.shape[0])]
permuted_expected_L = expected_L[expected_pivots, :]
self.assertEqual(actual_L, permuted_expected_L)
self.assertEqual(actual_D, expected_D)
self.assertEqual(actual_factors.shape, A.shape)
self.assertEqual(actual_pivots.shape, A.shape[:-1])
self.assertEqual(info.shape, A.shape[:-2])
magma_254_available = self.device_type == 'cuda' and _get_magma_version() >= (2, 5, 4)
hermitians = (True, False) if dtype.is_complex and (self.device_type == 'cpu' or magma_254_available) else (False,)
batches = ((), (4,),)
for shape, batch, hermitian in itertools.product(shapes, batches, hermitians):
run_test(shape, batch, hermitian)
@skipCUDAIfNoCusolver
@skipCUDAIf(_get_torch_cuda_version() < (11, 4), "not available before CUDA 11.3.1")
@dtypes(*floating_and_complex_types())
def test_ldl_solve(self, device, dtype):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix
def run_test(shape, batch, nrhs, hermitian):
A = random_hermitian_pd_matrix(shape, *batch, dtype=dtype, device=device)
B = make_tensor((*A.shape[:-1], nrhs), dtype=dtype, device=device)
factors, pivots, info = torch.linalg.ldl_factor_ex(A, hermitian=hermitian)
X = torch.linalg.ldl_solve(factors, pivots, B, hermitian=hermitian)
return A.tril() + A.tril(-1).mT
expected_B = symmetric(A) @ X if not hermitian else A @ X
self.assertEqual(B, expected_B)
hermitians = (True, False) if dtype.is_complex and self.device_type == 'cpu' else (False,)
batches = ((), (4,), (2, 2))
for shape, batch, nrhs, hermitian in itertools.product(shapes, batches, nrhss, hermitians):
run_test(shape, batch, nrhs, hermitian)
@skipCUDAIfNoCusolver
@setLinalgBackendsToDefaultFinally
def test_preferred_linalg_library(self):
x = torch.randint(2, 5, (2, 4, 4), device='cuda', dtype=torch.double)
torch.backends.cuda.preferred_linalg_library('cusolver')
out1 = torch.linalg.inv(x)
torch.backends.cuda.preferred_linalg_library('magma')
out2 = torch.linalg.inv(x)
torch.backends.cuda.preferred_linalg_library('default')
out_ref = torch.linalg.inv(x.cpu())
self.assertEqual(out_ref, out1.cpu())
self.assertEqual(out1, out2)
@unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device")
@setBlasBackendsToDefaultFinally
def test_preferred_blas_library(self):
m1 = torch.randint(2, 5, (2048, 2400), device='cuda', dtype=torch.float)
m2 = torch.randint(2, 5, (128, 2400), device='cuda', dtype=torch.float)
torch.backends.cuda.preferred_blas_library('cublaslt')
out1 = torch.nn.functional.linear(m1, m2)
torch.backends.cuda.preferred_blas_library('cublas')
out2 = torch.nn.functional.linear(m1, m2)
out_ref = torch.nn.functional.linear(m1.cpu(), m2.cpu())
self.assertEqual(out1, out2)
self.assertEqual(out_ref, out2.cpu())
def test_permute_matmul(self):
a = torch.ones([2, 5, 24, 24])
b = torch.ones([3, 2, 5, 24, 24])
c = a.permute(0, 1, 3, 2).matmul(b)
self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720])
def test_lower_precision_accumulation_with_ref_path(self):
def check_correctness(fn, dtype, *args):
expected = fn(*args).to(dtype=dtype)
with torch.backends.mkldnn.flags(enabled=False):
lower_args = (arg.to(dtype=dtype) for arg in args)
tmp_result = fn(*lower_args)
assert (torch.all(c == expected)), "Incorrect result with\n" \
f"expected: {expected}\n" \
for dtype in [torch.bfloat16, torch.half]:
for transa in [True, False]:
for transb in [True, False]:
a = torch.ones(300, 300)
b = torch.ones(300, 300)
a = a.transpose(0, 1).contiguous().transpose(0, 1)
b = b.transpose(0, 1).contiguous().transpose(0, 1)
check_correctness(torch.matmul, dtype, a, b)
a = torch.ones(1, 1, 300)
b = torch.ones(1, 300, 1)
check_correctness(torch.bmm, torch.bfloat16, a, b)
check_correctness(torch.bmm, torch.half, a, b)
a = torch.ones(1, 1, 300)
b = torch.ones(1, 300, 1)
c = torch.ones(1, 1, 1)
check_correctness(torch.baddbmm, torch.bfloat16, c, a, b)
check_correctness(torch.baddbmm, torch.half, c, a, b)
for dtype in [torch.bfloat16, torch.half]:
for trans in [True, False]:
c = torch.ones(300) * -300
a = torch.ones(300, 300)
a = a.transpose(0, 1).contiguous().transpose(0, 1)
check_correctness(torch.mv, dtype, a, b)
check_correctness(torch.addmv, dtype, c, a, b)
check_correctness(torch.dot, torch.bfloat16, a, b)
check_correctness(torch.dot, torch.half, a, b)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):
a = make_tensor((8, 1, 64), dtype=dtype, device=device)
a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
b = make_tensor((8, 64, 512), dtype=dtype, device=device)
b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
res = torch.bmm(a_strided, b_strided)
expect = torch.from_numpy(
a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(device=device, dtype=dtype)
self.assertEqual(expect, res)
instantiate_device_type_tests(TestLinalg, globals())
if __name__ == '__main__':
TestCase._default_dtype_check_enabled = True