import math
import random
import warnings
from typing import Dict, List, Sequence
from functools import partial
from itertools import product, combinations, permutations

import numpy as np

import torch
from torch import inf, nan
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
    integral_types_and, floating_and_complex_types_and, all_types_and, all_types,
)
from torch.testing._internal.common_utils import (
    TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
    IS_WINDOWS,
)
from torch.testing._internal.common_device_type import (
    OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
    onlyNativeDeviceTypes, onlyCUDA, largeTensorTest, ops, precisionOverride)
from torch.testing._internal.common_methods_invocations import (
    ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops)

def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # torch.randn does not support bfloat16, so sample in the default
            # dtype and cast afterwards
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # use extremal values
                x[torch.randn(*shape) > 0.5] = float('nan')
                x[torch.randn(*shape) > 0.5] = float('inf')
                x[torch.randn(*shape) > 0.5] = float('-inf')
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex('nan')
                x[torch.randn(*shape) > 0.5] = complex('inf')
                x[torch.randn(*shape) > 0.5] = complex('-inf')
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x

def _rand_shape(dim, min_size, max_size):
    shape = []
    for _ in range(dim):
        shape.append(random.randint(min_size, max_size))
    return shape

def _reduced_shape(shape, dim=None, keepdim=False):
    """Computes the expected reduced shape given dim and keepdim

    Args:
        shape: The shape to reduce
        dim: The dimensions to reduce
        keepdim: If true, reduced dimensions have size 1 in the reduced shape,
            otherwise they are removed from the reduced shape.
    """
    if dim is None:
        return [1] * len(shape) if keepdim else []

    # wrap negative dims
    dim = dim if isinstance(dim, Sequence) else [dim]
    dim = {i if i >= 0 else len(shape) + i for i in dim}

    result = []
    for i, size in enumerate(shape):
        if i not in dim:
            result.append(size)
        elif keepdim:
            result.append(1)

    return result
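
# A quick sanity check of the helper above (an editor's illustrative sketch,
# not part of the original suite): reducing dims 0 and 2 of a (2, 3, 4) shape
# removes them, or pins them to size 1 when keepdim=True.
assert _reduced_shape([2, 3, 4], dim=[0, 2]) == [3]
assert _reduced_shape([2, 3, 4], dim=[0, 2], keepdim=True) == [1, 3, 1]
assert _reduced_shape([2, 3, 4]) == []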

class TestReductions(TestCase):

    def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim):
        """Tests output shape for input with ndim and dim and keepdim kwargs"""
        shape = torch.randint(2, 5, (ndim,)).tolist()
        t = make_tensor(shape, dtype=torch.float, device=device)
        args, kwargs = next(op.generate_args_kwargs(t, **dim_keepdim))
        result = op(t, *args, **dim_keepdim, **kwargs)
        expected_shape = _reduced_shape(shape, **dim_keepdim)
        self.assertEqual(result.shape, expected_shape, f"""
        expected output shape to be {expected_shape} but got {list(result.shape)}
        for input shape {shape} and {dim_keepdim}
        """)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default(self, device, op: ReductionOpInfo):
        """Tests that the default dim reduces all dimensions."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default_keepdim(self, device, op: ReductionOpInfo):
        """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none(self, device, op: ReductionOpInfo):
        """Tests that dim=None reduces all dimensions."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, dim=None)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1."""
        for ndim in range(3):
            self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single(self, device, op: ReductionOpInfo):
        """Tests that dim=i reduces dimension i."""
        self._test_dim_keepdim(op, device, ndim=0, dim=0)
        self._test_dim_keepdim(op, device, ndim=1, dim=0)
        self._test_dim_keepdim(op, device, ndim=2, dim=-1)
        self._test_dim_keepdim(op, device, ndim=3, dim=1)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=i, when keepdim=True, reduces dimension i to size 1."""
        self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True)
        self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty(self, device, op: ReductionOpInfo):
        """Tests that dim=[] is a no-op."""
        self._test_dim_keepdim(op, device, ndim=0, dim=[])
        self._test_dim_keepdim(op, device, ndim=2, dim=[])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=[], when keepdim=True, is a no-op."""
        self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True)
        self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi(self, device, op: ReductionOpInfo):
        """Tests that dim=[i, j, ...] reduces dimensions i, j, ..."""
        self._test_dim_keepdim(op, device, ndim=1, dim=[0])
        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_keepdim(self, device, op: ReductionOpInfo):
        """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, ... to size 1."""
        self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True)
        self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted(self, device, op: ReductionOpInfo):
        """Tests that the operator correctly handles an unsorted dim list."""
        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2])

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo):
        """Tests that the operator correctly handles an unsorted dim list when keepdim=True."""
        self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True)

    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_duplicate(self, device, op: ReductionOpInfo):
        """Tests that an error is raised if dim has duplicate entries."""
        with self.assertRaises(RuntimeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2])

    @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsupported(self, device, op: ReductionOpInfo):
        """Tests that ops claiming not to support multiple dims actually don't."""
        with self.assertRaises(TypeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_offbounds(self, device, op: ReductionOpInfo):
        """Tests that passing an off-bounds dim throws."""
        with self.assertRaises(IndexError):
            self._test_dim_keepdim(op, device, ndim=2, dim=2)

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_ndim_limit(self, device, op: ReductionOpInfo):
        """Tests that an exception is raised when reducing a tensor with more
        than 64 dims along some specific dimensions. dim=None is ok."""
        t = make_tensor([1] * 65, dtype=torch.float, device=device)
        with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
            op(t, dim=0)

    @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported)
    def test_identity(self, device, dtype, op: ReductionOpInfo):
        """Tests that the identity value is an identity for the operator"""
        t = make_tensor((10,), dtype=dtype, device=device)
        t[1::2] = op.identity
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t[::2], *args, **kwargs)
        result_with_identity = op(t, *args, **kwargs)
        self.assertEqual(result, result_with_identity, """
        Adding identity value to the input tensor should not change the result.
        """)
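
    # Editor's illustration (an assumed sketch, not part of the original
    # suite): for torch.sum the identity is 0, so interleaving zeros must not
    # change the result, and it is also what an empty reduction returns.
    def _example_identity_sum(self):
        t = torch.tensor([1., 2., 3.])
        padded = torch.tensor([1., 0., 2., 0., 3.])
        assert torch.equal(t.sum(), padded.sum())        # the identity is neutral
        assert torch.sum(torch.empty(0)).item() == 0.0   # empty reduction returns it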

    @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo):
        """Tests that NaN is propagated to the output by default"""
        t = make_tensor((5,), dtype=dtype, device=device)
        t[2] = torch.nan
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t, *args, **kwargs)
        self.assertTrue(result.isnan())

    @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo):
        """Tests that NaN values do not affect the result."""
        t = make_tensor((10,), dtype=dtype, device=device)
        t[1::2] = torch.nan
        args, kwargs = next(op.generate_args_kwargs(t))
        result = op(t[::2], *args, **kwargs)
        result_with_nan = op(t, *args, **kwargs)
        self.assertEqual(result, result_with_nan)
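
    # Editor's illustration (assumed sketch): 'propagate' vs 'omit' NaN
    # policy, as seen in the sum / nansum pair.
    def _example_nan_policy(self):
        t = torch.tensor([1.0, float('nan'), 3.0])
        assert torch.sum(t).isnan()            # sum propagates the NaN
        assert torch.nansum(t).item() == 4.0   # nansum treats it as omitted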

    @ops(reduction_ops, dtypes=OpDTypes.supported)
    def test_result_dtype(self, device, dtype, op: ReductionOpInfo):
        """Tests that the result has the correct dtype"""
        t = make_tensor((5,), dtype=dtype, device=device)
        args, kwargs = next(op.generate_args_kwargs(t))
        result: torch.Tensor = op(t, *args, **kwargs)
        is_integral = dtype in integral_types_and(torch.bool)
        if op.promotes_int_to_float and is_integral:
            self.assertTrue(torch.is_floating_point(result))
        elif op.promotes_int_to_int64 and is_integral:
            self.assertEqual(result.dtype, torch.int64)
        elif op.result_dtype is not None:
            self.assertEqual(result.dtype, op.result_dtype)
        elif op.complex_to_real:
            _complex_to_real_dtype_map = {
                torch.complex128: torch.float64,
                torch.complex64: torch.float32,
                torch.complex32: torch.float16,
            }
            self.assertEqual(result.dtype, _complex_to_real_dtype_map.get(dtype, dtype))
        else:
            self.assertEqual(result.dtype, dtype)
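
    # Editor's illustration (assumed sketch) of the promotion rules checked
    # above: sum promotes small integer dtypes to int64, while mean requires
    # a floating-point result dtype.
    def _example_result_dtype(self):
        t = torch.ones(4, dtype=torch.int8)
        assert t.sum().dtype == torch.int64                         # int -> int64
        assert t.mean(dtype=torch.float32).dtype == torch.float32   # int -> float
        assert torch.ones(4).sum().dtype == torch.float32           # float stays float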

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo):
        """Tests for consistent behavior when reducing over an empty slice.

        The rules for reducing over an empty slice are as follows:
        - Return the identity value if the operator has one
        - Otherwise, return NaN if the operator promotes integral dtypes to
          floating point dtypes
        - Otherwise, raise an error

        See discussion here https://github.com/pytorch/pytorch/issues/61901
        """
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        for dim in [0] + ([[0, 2]] if op.supports_multiple_dims else []):
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            if op.identity is not None:
                # reducing over an empty slice should return the identity
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, op.identity))
            elif op.promotes_int_to_float:
                # reducing over an empty slice should return NaN
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, torch.nan))
            else:
                # reducing over an empty slice should raise an error
                if isinstance(op, ReductionPythonRefInfo):
                    # Python refs raise RuntimeError instead of IndexError
                    with self.assertRaises(RuntimeError):
                        op(t, *args, dim=dim, **kwargs)
                else:
                    with self.assertRaises(IndexError):
                        op(t, *args, dim=dim, **kwargs)
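
    # Editor's illustration (assumed sketch) of the empty-slice rules above:
    # sum has identity 0, mean promotes to float and yields NaN, and max has
    # neither, so it raises.
    def _example_empty_slice_rules(self):
        t = torch.empty(0, 3)
        assert torch.equal(t.sum(dim=0), torch.zeros(3))  # identity
        assert t.mean(dim=0).isnan().all()                # NaN
        try:
            t.max(dim=0)                                  # no identity: error
        except (IndexError, RuntimeError):
            pass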

    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo):
        """Tests that reducing a nonempty slice of an empty tensor returns an
        empty tensor with the dimensions reduced."""
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        for dim in [1] + ([[1, 2]] if op.supports_multiple_dims else []):
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            result = op(t, *args, dim=dim, **kwargs)
            self.assertEqual(result.shape, _reduced_shape(t.shape, dim))

    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Helper method to test noncontiguous input tensors."""
        assert not t.is_contiguous()

        t_contig = t.contiguous()
        for args, kwargs in op.generate_args_kwargs(t_contig, **reduction_kwargs):
            kwargs.update(reduction_kwargs)
            result = op(t, *args, **kwargs)
            expected = op(t_contig, *args, **kwargs)
            self.assertEqual(result, expected)

    @ops(reduction_ops)
    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing along a noncontiguous innermost dimension."""
        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[:, ::2], dim=1)

    @ops(reduction_ops)
    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing along a noncontiguous outermost dimension."""
        t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[::2, :], dim=0)

    @ops(reduction_ops)
    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing all dimensions of a noncontiguous tensor."""
        t = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t[::2, ::3, 1:-1:2])

    @ops(reduction_ops)
    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing a transposed tensor."""
        t = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t.T)

    @ops(reduction_ops)
    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
        """Tests reducing a tensor with expanded singleton dimensions."""
        t = make_tensor((2, 3), dtype=dtype, device=device, low=-1, high=1)
        self._test_noncontiguous(op, t.unsqueeze(1).expand(-1, 5, -1))

    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Compares op against op.ref for the given input and reduction kwargs"""
        for args, kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
            kwargs.update(reduction_kwargs)
            result = op(t, *args, **kwargs)
            expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs)
            self.assertEqual(result, expected, exact_dtype=False)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for scalar input tensors"""
        self._test_ref(op, make_tensor([], dtype=dtype, device=device))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for small input tensors"""
        t = make_tensor((5, 3, 4, 2), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        self._test_ref(op, t)
        for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
            self._test_ref(op, t, dim=dim)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a large 1D input tensor to check stability"""
        self._test_ref(op, make_tensor((2 ** 20,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a large 2D input tensor to test parallelism"""
        t = make_tensor((32, 2 ** 16), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
        self._test_ref(op, t, dim=1)

    @largeTensorTest("8gb")
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for a very large input tensor that requires 64-bit indexing"""
        self._test_ref(op, make_tensor((275000000,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True))

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for input tensors with duplicate values"""
        t = make_tensor((4, 4), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        t[::2, ::2] = t[1::2, 1::2]
        self._test_ref(op, t)
        self._test_ref(op, t, dim=0)
        self._test_ref(op, t, dim=1)

    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float32, torch.complex64])
    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
        """Compares op against reference for input tensors with extremal values"""
        t = make_tensor((5,), dtype=dtype, device=device, exclude_zero=True)
        extremals = [0, 1, nan, inf, -inf]
        for extremal in extremals:
            t[2] = extremal
            self._test_ref(op, t)

    def test_var_unbiased(self, device):
        tensor = torch.randn(100, device=device)
        self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))
        self.assertEqual(tensor.var(), tensor.var(unbiased=True))
        self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False))

        tensor = torch.tensor([1.0, 2.0], device=device)
        self.assertEqual(tensor.var(unbiased=True), 0.5)
        self.assertEqual(tensor.var(unbiased=False), 0.25)

        tensor = torch.tensor([1.0, 2.0, 3.0], device=device)
        self.assertEqual(tensor.var(unbiased=True), 1.0)
        self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0)

        tensor = torch.randn(100, device=device)
        self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True))
        self.assertEqual(tensor.std(), tensor.std(unbiased=True))
        self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False))
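
    # Editor's illustration (assumed sketch): the unbiased (Bessel-corrected)
    # variance divides the summed squared deviations by n - 1, the biased one
    # by n. For [1., 2.] the deviations are +/-0.5, so the sum of squares is
    # 0.5, giving 0.5 / 1 = 0.5 versus 0.5 / 2 = 0.25.
    def _example_var_bessel(self):
        t = torch.tensor([1.0, 2.0])
        assert t.var(unbiased=True).item() == 0.5
        assert t.var(unbiased=False).item() == 0.25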

    def test_var_stability(self, device):
        tensor = torch.tensor([2281.5, 2281.25], device=device)
        self.assertEqual(tensor.var(dim=0), 0.03125)
        self.assertEqual(tensor.var(), 0.03125)

    def test_sum_dim_reduction_uint8_overflow(self, device):
        # -1 wraps to 255 in uint8; 255 + 2 + 1 + 5 + 3 + 6 = 272, which
        # overflows to 272 - 256 = 16
        example = [[-1, 2, 1], [5, 3, 6]]
        x = torch.tensor(example, dtype=torch.uint8, device=device)
        self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
        self.assertEqual(x.sum(0, dtype=torch.uint8), torch.tensor([4, 5, 7], dtype=torch.uint8, device=device))
        self.assertEqual(x.sum(1, dtype=torch.uint8), torch.tensor([2, 14], dtype=torch.uint8, device=device))
        y = torch.tensor(example, dtype=torch.uint8, device=device)
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0, dtype=torch.uint8), y)

    def test_dim_reduction_less_than_64(self, device):
        sizes = [1] * 65
        x = torch.randn(sizes, device=device)
        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var,
               torch.norm]
        for op in ops:
            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
                op(x, dim=64)
            with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
                op(x, dim=-1)

    @dtypes(torch.float, torch.bfloat16)
    def test_dim_reduction_lastdim(self, device, dtype):
        x = torch.randn(3, 5, 40, device=device, dtype=dtype)
        x = x[:, :, 0:40:2]
        x2 = x.contiguous()
        ops = [torch.norm, torch.argmax, torch.argmin]
        for op in ops:
            y = op(x, dim=-1)
            y2 = op(x2, dim=-1)
            self.assertEqual(y, y2)

    @skipIfNoSciPy
    def test_logsumexp(self, device):
        from scipy.special import logsumexp
        a = torch.randn(5, 4, device=device)
        # include infinities to exercise the edge cases
        a[0, 0] = inf
        a[1, :] = -inf
        actual = a.logsumexp(1)
        expected = logsumexp(a.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)

        # check that out is actually inplace
        b = torch.zeros(5, 2, device=device)
        c = b[:, 0]
        torch.logsumexp(a, 1, out=c)
        self.assertEqual(expected, b[:, 0])

        # check that integral inputs are promoted to floating point
        e = torch.randint(-100, 100, [5, 4], device=device)
        actual = e.logsumexp(1).to(torch.float64)
        expected = logsumexp(e.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)
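
    # Editor's illustration (assumed sketch): logsumexp(x) = log(sum(exp(x)))
    # computed stably by factoring out the max, m + log(sum(exp(x - m))), so
    # large inputs do not overflow the way the naive formula does.
    def _example_logsumexp_stability(self):
        x = torch.tensor([1000.0, 1000.0])
        assert torch.log(torch.exp(x).sum()).isinf()  # naive form overflows
        assert torch.isclose(x.logsumexp(0), torch.tensor(1000.0 + math.log(2)))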

    @skipIfNoSciPy
    @dtypes(torch.complex64, torch.complex128)
    def test_logcumsumexp_complex(self, device, dtype):
        from scipy.special import logsumexp

        def zero_out_neg_inf(t):
            # entries with -inf real part have a numerically meaningless
            # imaginary part, so canonicalize it to zero before comparing
            t = t.clone()
            idx = torch.logical_and(~(torch.isfinite(t)), torch.real(t) < 0)
            t[idx] = torch.real(t[idx]).to(t.dtype)
            return t

        def standardize_phase(t):
            # map the phase into [0, 2*pi) so branch-cut differences do not
            # produce spurious mismatches
            t = torch.real(t) + 1j * (torch.imag(t) % (2 * np.pi))
            return t

        def logcumsumexp_slow(a, dim):
            res_lst = []
            for i in range(a.size(dim)):
                index = [slice(None, None, None) for _ in range(a.ndim)]
                index[dim] = slice(None, i + 1, None)
                a_inp = a[tuple(index)]
                res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
            res = np.concatenate(res_lst, axis=dim)
            return torch.as_tensor(res)

        def compare_logcumsumexp(a, expected=None):
            for i in range(a.ndim):
                actual = torch.logcumsumexp(a, dim=i)
                # if expected is not given, fall back to the slow reference
                if expected is None:
                    expected2 = logcumsumexp_slow(a, dim=i)
                else:
                    expected2 = expected
                actual = standardize_phase(actual)
                expected2 = standardize_phase(expected2)
                actual = zero_out_neg_inf(actual)
                expected2 = zero_out_neg_inf(expected2)
                self.assertEqual(expected2.shape, actual.shape)
                self.assertEqual(expected2, actual)

        # randomly specified values
        a1 = torch.randn((5, 10), dtype=dtype, device=device)
        compare_logcumsumexp(a1)

        # large values to test numerical stability
        a2 = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
        compare_logcumsumexp(a2)

        # hand-constructed inputs containing infinities, a3_input/a3_expected
        # and a4_input/a4_expected (the tensor literals are elided in this
        # excerpt; one expected entry, for reference, is
        # inf + (np.pi / 4) * 1j), checked via:
        #   compare_logcumsumexp(a3_input, a3_expected)
        #   compare_logcumsumexp(a4_input, a4_expected)

    def test_sum_parallel(self, device):
        # Use tensors that are large enough to hit the parallel branches, so
        # this gives signal even when run on a single-core machine.
        def _run_test(size):
            for dim in range(len(size) + 1):
                nv = np.round(np.random.rand(*size))  # 0s and 1s
                tv = torch.from_numpy(nv)
                # parallelism only kicks in above the grain size, so make
                # sure the inputs are big enough
                self.assertTrue(tv.numel() > 32768)
                if dim == len(size):
                    nvs = nv.sum()
                    tvs = tv.sum()
                else:
                    nvs = nv.sum(dim)
                    tvs = tv.sum(dim)
                diff = np.abs(nvs - tvs.numpy()).sum()
                self.assertEqual(diff, 0)

        _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3])
        _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
        _run_test([1, 32 * 8 * 32 * 8])
        _run_test([1, 32770])

    def _testCSelection(self, torchfn, mathfn):
        size = (100, 100)
        a = torch.rand(*size)
        b = torch.rand(*size)
        c = torchfn(a, b)
        expected_c = torch.zeros(*size)
        expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b))
        self.assertEqual(expected_c, c, atol=0, rtol=0)

    def test_max_elementwise(self, device):
        self._testCSelection(torch.max, max)

    def test_min_elementwise(self, device):
        self._testCSelection(torch.min, min)

    def test_all_any(self, device):
        def test(size):
            x = torch.ones(*size, device=device).byte()
            self.assertTrue(x.all())
            self.assertTrue(x.any())

            x[3] = 0
            self.assertFalse(x.all())
            self.assertTrue(x.any())

            x.zero_()
            self.assertFalse(x.all())
            self.assertFalse(x.any())

            x.fill_(2)
            self.assertTrue(x.all())
            self.assertTrue(x.any())

            x = torch.ones(*size, device=device).bool()
            self.assertTrue(x.all())
            self.assertTrue(x.any())

            x[3] = False
            self.assertFalse(x.all())
            self.assertTrue(x.any())

        test((10,))
        test((5, 5))

    def test_all_any_with_dim(self, device):
        def test(x):
            r1 = x.prod(dim=0, keepdim=False).byte()
            r2 = x.all(dim=0, keepdim=False)
            self.assertEqual(r1.shape, r2.shape)
            self.assertTrue((r1 == r2).all())

            r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
            r4 = x.any(dim=1, keepdim=True)
            self.assertEqual(r3.shape, r4.shape)
            self.assertTrue((r3 == r4).all())

        # (additional fixture rows were elided in this excerpt; the two
        # remaining rows still exercise both the all and any paths)
        test(torch.tensor([[0, 0, 0],
                           [1, 1, 1]], device=device, dtype=torch.uint8))

    def test_numpy_named_args(self, device):
        x1 = torch.randn(10, device=device)
        x2 = torch.randn(10, device=device)
        res1 = torch.add(input=x1, other=x2)
        res2 = torch.add(x1=x1, x2=x2)
        self.assertEqual(res1, res2)

        x1 = torch.randn(10, 10, 10, device=device)
        res1 = x1.sum(dim=(0, 2), keepdim=True)
        res2 = x1.sum(axis=(0, 2), keepdims=True)
        self.assertEqual(res1, res2)

    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
        float_types = [torch.double,
                       torch.float]
        int_types = [torch.int64,
                     torch.int32,
                     torch.int16]
        complex_types = [torch.complex64,
                         torch.complex128]

        def make_contiguous(shape, dtype) -> torch.Tensor:
            if dtype in float_types:
                val = torch.randn(shape, dtype=dtype)
                val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
                val = val + ((val_range[1] - val_range[0]) / 2.0)
                val = torch.clamp(val, min=val_range[0], max=val_range[1])
                return val
            result = torch.zeros(shape, dtype=dtype)
            result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
            return result

        def make_non_contiguous(shape, dtype) -> torch.Tensor:
            contig = make_contiguous(shape, dtype)
            non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
            non_contig = non_contig.select(-1, -1)
            non_contig.copy_(contig)
            self.assertFalse(non_contig.is_contiguous())
            return non_contig

        def make_contiguous_slice(size, dtype) -> torch.Tensor:
            contig = make_contiguous((1, size), dtype)
            non_contig = contig[:1, 1:size - 1]
            self.assertTrue(non_contig.is_contiguous())
            return non_contig

        types = []
        if use_floating:
            types += float_types
        if use_integral:
            types += int_types
        if use_complex:
            types += complex_types
        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
        for dtype in types:
            tensors["cont"].append(make_contiguous(shape, dtype))
            tensors["noncont"].append(make_non_contiguous(shape, dtype))
            tensors["slice"].append(make_contiguous_slice(sum(list(shape)), dtype))

        return tensors

    def _assert_matches_numpy(self, t, n):
        self.assertEqual(n.shape, t.shape)
        if t.dtype == torch.float:
            self.assertEqual(n, t, rtol=1e-03, atol=1e-05, equal_nan=True)
        else:
            self.assertEqual(n, t, equal_nan=True)

    def _test_dim_ops(self, pytorch_op, numpy_op,
                      use_floating=True, use_integral=True, use_complex=False):
        def do_one(tensors_dict, dim):
            for category, tensors in tensors_dict.items():
                if category == "slice":
                    dim = 0
                for tensor in tensors:
                    # we have no control over NumPy warnings
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        expected = numpy_op(tensor.cpu().numpy(), dim)
                    actual = pytorch_op(tensor, dim)
                    self._assert_matches_numpy(actual, expected)
                    if torch.cuda.is_available():
                        self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected)

        do_one(self._make_tensors((5, 400000), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 1)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 0)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 1)
        do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 2)
        do_one(self._make_tensors((100000, ), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), -1)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 0)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 1)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), 2)
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), (1, 2))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), (1, -1))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), (0, 2))
        do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
                                  use_integral=use_integral, use_complex=use_complex), (0, 2, 1))

    def test_sum_dim(self, device):
        self._test_dim_ops(
            lambda t, d: t.sum(d),
            lambda n, d: n.sum(d),
            use_floating=True, use_integral=True, use_complex=True)

    def test_mean_dim(self, device):
        self._test_dim_ops(
            lambda t, d: t.mean(d),
            lambda n, d: n.mean(d))

    def test_std_dim(self, device):
        for unbiased in [False, True]:
            self._test_dim_ops(
                lambda t, d: t.std(d, unbiased=unbiased),
                lambda n, d: n.std(d, ddof=1 if unbiased else 0))

    def test_var_dim(self, device):
        for unbiased in [False, True]:
            self._test_dim_ops(
                lambda t, d: t.var(d, unbiased=unbiased),
                lambda n, d: n.var(d, ddof=1 if unbiased else 0))

    @skipIfNoSciPy
    def test_logsumexp_dim(self, device):
        from scipy.special import logsumexp
        self._test_dim_ops(
            lambda t, d: t.logsumexp(d),
            lambda n, d: logsumexp(n, d))

    def test_mean_int_with_optdtype(self, device):
        a = make_tensor((3, 4, 5), dtype=torch.int64, device=device)

        # if the optional dtype is given, the input is internally cast to it
        a_float = a.to(torch.float32)
        self.assertEqual(a_float.mean(), a.mean(dtype=torch.float32))

    def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True):
        shape = (3, 4, 5)
        reduced_shape = fn(torch.ones(shape)).shape

        def _test_out(dtype, other_dtype):
            out = torch.ones(reduced_shape, dtype=dtype)
            result = fn(x, out=out)
            self.assertIs(out.dtype, result.dtype)
            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
            result = fn(x, out=out, dtype=dtype)
            self.assertIs(out.dtype, result.dtype)
            self.assertEqual(fn(x.to(dtype)), result, exact_dtype=False)
            # 'out' is favored over dtype; mismatching them is an error
            self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))

        for dtype in [dtype for dtype in get_all_math_dtypes('cpu') if dtype != torch.float16]:
            x = torch.ones(shape, dtype=dtype)
            expected_dtype = dtype if dtype.is_floating_point or dtype.is_complex else torch.int64
            self.assertIs(expected_dtype, fn(x).dtype)
            self.assertEqual(fn(x.to(expected_dtype)), fn(x))

            if dtype.is_floating_point:
                other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
            elif dtype.is_complex:
                other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128
            else:
                other_dtype = torch.int32 if dtype != torch.int32 else torch.int16
            self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
            self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False)

            # test mixed int/float/complex
            if dtype.is_floating_point:
                mixed_dtypes = [torch.int32, torch.complex64]
            elif dtype.is_complex:
                mixed_dtypes = [torch.int32, torch.float32]
            else:
                mixed_dtypes = [torch.float32, torch.complex64]

            for mixed_dtype in mixed_dtypes:
                self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
                self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False)

                if has_out:
                    _test_out(dtype, other_dtype)
                    _test_out(dtype, mixed_dtype)

    def test_sum_integer_upcast(self, device):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False)
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs))

    def test_prod_integer_upcast(self, device):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False)
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs))

    def test_cumsum_integer_upcast(self, device):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs))

    def test_cumprod_integer_upcast(self, device):
        self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs))

    @dtypes(*all_types())
    def test_mode(self, device, dtype):
        SIZE = 10
        x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE)
        x[:2] = 1
        x[:, :2] = 1
        x0 = x.clone()

        # pre-calculated results
        res1val = torch.ones(SIZE, device=device, dtype=dtype)
        # the indices are the position of the last appearance of the mode element
        res1ind = torch.ones(SIZE, device=device, dtype=torch.long)
        res1ind[0] = SIZE - 1
        res1ind[1] = SIZE - 1

        res2val, res2ind = torch.mode(x, keepdim=False)
        self.assertEqual(res1val, res2val, atol=0, rtol=0)
        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)

        # test use of result tensors
        res2val = torch.tensor((), device=device, dtype=dtype)
        res2ind = torch.tensor((), device=device, dtype=torch.long)
        torch.mode(x, keepdim=False, out=(res2val, res2ind))
        self.assertEqual(res1val, res2val, atol=0, rtol=0)
        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)

        # test non-default dim
        res2val, res2ind = torch.mode(x, 0, False)
        self.assertEqual(res1val, res2val, atol=0, rtol=0)
        self.assertEqual(res1ind, res2ind, atol=0, rtol=0)

        # input was not modified
        self.assertEqual(x, x0, atol=0, rtol=0)
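
    # Editor's illustration (assumed sketch): mode returns the most common
    # value of each slice along a dim, together with an index where it occurs.
    def _example_mode(self):
        t = torch.tensor([3, 1, 3, 2, 3])
        values, indices = t.mode()
        assert values.item() == 3
        assert t[indices.item()].item() == 3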

    def _test_mode_intervals(self, shape, intervals, device, dtype, v=1):
        x = torch.arange(0, shape[1], device=device, dtype=dtype).expand(shape)
        x = x.contiguous()
        x[:, v] = intervals[0][0]

        # set the value of each interval to the mode "v"
        for (beg, end) in intervals:
            x[:, beg:end] = v

        values, indices = torch.mode(x, -1, False)

        # check whether the returned indices correspond to the returned values
        self.assertTrue((x.gather(1, indices.unsqueeze(1)).t() == values).all())
        # check whether the returned values are the mode
        self.assertTrue((values == v).all().item())

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_mode_large(self, device, dtype):
        def testset_for_shape(shape, i):
            d = shape[-1]
            # mode only in the middle
            self._test_mode_intervals(shape, [(i, d - i)], device, dtype)
            # mode in discontiguous parts of the input
            self._test_mode_intervals(shape, [(0, i), (i + 1, d - i - 1), (d - i, d)], device, dtype)

        # shapes chosen to exercise different kernel code paths
        testset_for_shape((65536, 10), 3)
        testset_for_shape((10, 2048), 10)
        testset_for_shape((10, 4096), 10)

    def test_mode_boolean(self, device):
        # (the exact fixture shapes were elided in this excerpt; these cover
        # both small and large slices)
        shapes = [
            (10, 10),
            (4, 2048),
            (1, 4096),
        ]

        for shape in shapes:
            a = torch.zeros(shape, device=device, dtype=torch.bool)

            a[:, (shape[1] - 1) // 2:] = True
            values, indices = a.mode(-1)
            self.assertEqual(values, torch.ones(shape[0], dtype=torch.bool))
            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
            self.assertEqual(values, indexed)

            a.fill_(False)
            a[:, shape[1] // 2 + 1:] = True
            values, indices = a.mode(-1)
            self.assertEqual(values, torch.zeros(shape[0], dtype=torch.bool))
            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
            self.assertEqual(values, indexed)

    @expectedFailureMeta
    @onlyNativeDeviceTypes
    def test_mode_wrong_dtype(self, device):
        def test_for_dtypes(x_ty, v_ty, i_ty, message):
            x = torch.ones(10, device=device, dtype=x_ty)
            v = torch.ones(10, device=device, dtype=v_ty)
            i = torch.ones(10, device=device, dtype=i_ty)

            with self.assertRaisesRegex(RuntimeError, message):
                torch.mode(x, -1, True, out=(v, i))

        err_msg = "expected scalar type .* but got .* for "
        values_err = err_msg + "values"
        indices_err = err_msg + "indices"

        # values dtype mismatches
        test_for_dtypes(torch.uint8, torch.int8, torch.long, values_err)
        test_for_dtypes(torch.int8, torch.int16, torch.long, values_err)
        test_for_dtypes(torch.int32, torch.float32, torch.long, values_err)
        test_for_dtypes(torch.float32, torch.float64, torch.long, values_err)

        # indices dtype mismatches
        test_for_dtypes(torch.uint8, torch.uint8, torch.int8, indices_err)
        test_for_dtypes(torch.int8, torch.int8, torch.int16, indices_err)
        test_for_dtypes(torch.int32, torch.int32, torch.float32, indices_err)
        test_for_dtypes(torch.float32, torch.float32, torch.float64, indices_err)

    @onlyCUDA
    def test_mode_wrong_device(self, device):
        # CPU input tensor
        x = torch.ones(2)

        with self.assertRaisesRegex(RuntimeError,
                                    "expected device .* but got .* for values"):
            values = torch.tensor([], device=device)
            torch.mode(x, -1, True, out=(values, torch.tensor([], dtype=torch.long)))

        with self.assertRaisesRegex(RuntimeError,
                                    "expected device .* but got .* for indices"):
            indices = torch.tensor([], device=device)
            torch.mode(x, -1, True, out=(torch.tensor([]), indices))

    def test_accreal_type(self, device) -> None:
        x = torch.ones(2, 3, 4)
        self.assertIsInstance(x.double().sum().item(), float)
        self.assertIsInstance(x.float().sum().item(), float)
        self.assertIsInstance(x.long().sum().item(), int)
        self.assertIsInstance(x.int().sum().item(), int)
        self.assertIsInstance(x.short().sum().item(), int)
        self.assertIsInstance(x.char().sum().item(), int)
        self.assertIsInstance(x.byte().sum().item(), int)

    def test_var_mean_some_dims(self, device):
        sizes = (4, 6, 7, 5, 3)
        dims = len(sizes)

        x = torch.rand(sizes, device=device)
        for num_of_dims in range(2, dims):
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
            for dim in dim_list:
                for unbiased in [False, True]:
                    for keepdim in [False, True]:
                        var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                        var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
                        mean2 = x.mean(dim=dim, keepdim=keepdim)
                        self.assertEqual(var1, var2)
                        self.assertEqual(mean1, mean2)

    def test_all_any_empty(self, device):
        x = torch.ByteTensor().to(device)
        self.assertTrue(x.all())
        self.assertFalse(x.any())

        x = torch.BoolTensor().to(device)
        self.assertTrue(x.all())
        self.assertFalse(x.any())
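
    # Editor's illustration (assumed sketch): the empty-tensor results above
    # follow the vacuous-truth convention, matching Python's built-ins.
    def _example_all_any_empty(self):
        assert all([]) and not any([])
        assert torch.empty(0, dtype=torch.bool).all()
        assert not torch.empty(0, dtype=torch.bool).any()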

    def test_all_issue117215(self, device):
        info = torch.iinfo(torch.uint8)
        a = torch.randint(info.min, info.max, (73, 11, 3, 17), dtype=torch.uint8)
        b = torch.all(a, dim=0)
        c = a.to(torch.bool).all(dim=0)
        self.assertEqual(torch.ne(b, c).sum(), 0)

    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    def test_max_with_inf(self, device, dtype):
        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
        self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item())
        self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item())
        self.assertTrue(torch.max(a).item() == inf)
        self.assertTrue(torch.amax(a).item() == inf)

    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double)
    def test_min_with_inf(self, device, dtype):
        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
        self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item())
        self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item())
        self.assertTrue(torch.min(a).item() == -inf)
        self.assertTrue(torch.amin(a).item() == -inf)

    def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False):
        def create_input(shape, device, dtype):
            if dtype.is_floating_point:
                return torch.randn(*shape, device=device, dtype=dtype)
            else:
                low = 0 if dtype == torch.bool else -1000
                high = 2 if dtype == torch.bool else 1000
                return torch.randint(low, high, shape, device=device, dtype=dtype)

        x = create_input((100, 100), device, dtype)
        self.compare_with_numpy(torchfn, reffn, x)
        # non-contiguous input
        x = create_input((10, 10, 10), device, dtype)
        x = x[:, 4]
        self.compare_with_numpy(torchfn, reffn, x)

        def get_values(x):
            if isinstance(x, tuple):
                return x[0]
            return x

        # indices
        if not skip_indices:
            size = 5
            x = create_input((size, size), device, dtype)
            inputs = (x, x.t())
            dims = (0, 1)
            for xinp, d in product(inputs, dims):
                self.compare_with_numpy(lambda x: get_values(torchfn(x, d, False)), lambda x: reffn(x, d, keepdims=False), xinp)
                result = torchfn(xinp, d, False)
                if isinstance(result, tuple):
                    v, i = result
                    if d == 1:
                        self.assertEqual(xinp[torch.arange(size), i], v, atol=0, rtol=0)
                    else:
                        self.assertEqual(xinp[i, torch.arange(size)], v, atol=0, rtol=0)

        # NaN propagation
        if dtype.is_floating_point:
            for index in (0, 4, 99):
                x = create_input((100,), device, dtype)
                x[index] = nan
                if not skip_indices:
                    result = torchfn(x, 0)
                    v = get_values(result)
                    self.assertEqual(v, nan)
                    if isinstance(result, tuple):
                        i = result[1]
                        self.assertEqual(i, index)
                self.assertEqual(torchfn(x), nan)

    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
    @dtypes(torch.half, torch.float, torch.double)
    def test_max(self, device, dtype):
        self._test_minmax_helper(torch.max, np.amax, device, dtype)

    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
    @dtypes(torch.half, torch.float, torch.double)
    def test_min(self, device, dtype):
        self._test_minmax_helper(torch.min, np.amin, device, dtype)

    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
    @dtypes(torch.half, torch.float, torch.double)
    def test_amin(self, device, dtype):
        self._test_minmax_helper(torch.amin, np.amin, device, dtype)

    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
    @dtypes(torch.float, torch.double)
    def test_amax(self, device, dtype):
        self._test_minmax_helper(torch.amax, np.amax, device, dtype)

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
    def test_aminmax(self, device, dtype):

        def _amin_wrapper(x, dim=None, keepdims=False):
            with self.assertWarnsOnceRegex(UserWarning, "_aminmax is deprecated"):
                if dim is None:
                    return torch._aminmax(x)[0]
                else:
                    return torch._aminmax(x, dim, keepdims)[0]

        def _amax_wrapper(x, dim=None, keepdims=False):
            with self.assertWarnsOnceRegex(UserWarning, "_aminmax is deprecated"):
                if dim is None:
                    return torch._aminmax(x)[1]
                else:
                    return torch._aminmax(x, dim, keepdims)[1]

        self._test_minmax_helper(_amin_wrapper, np.amin, device, dtype)
        self._test_minmax_helper(_amax_wrapper, np.amax, device, dtype)

    def test_bincount(self, device):
        # negative input throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([1, -1], device=device))
        # n-d input, with n > 1, throws
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
        # floating point input throws
        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
            torch.bincount(torch.tensor([1., 0.3], device=device))
        # negative minlength throws
        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
            torch.bincount(torch.tensor([1, 3], device=device),
                           torch.tensor([.2, .2], device=device),
                           minlength=-1)
        # n-d weights, with n > 1, throws
        with self.assertRaisesRegex(RuntimeError, '1-d'):
            torch.bincount(torch.tensor([1, 0], device=device),
                           torch.tensor([[1., 0.3], [1., 0.3]], device=device))
        # input and weights of different lengths throw
        with self.assertRaisesRegex(RuntimeError, 'same length'):
            torch.bincount(torch.tensor([1, 0], device=device),
                           torch.tensor([1., 0.3, 0.5], device=device))
        # 1-d input with no elements and default minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
                         torch.zeros(0, dtype=torch.long, device=device))
        # 1-d input with no elements and specified minlength
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
                         torch.zeros(10, dtype=torch.long, device=device))

        # test tensor method without weights
        long_counts = torch.tensor(
            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
        self.assertEqual(
            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
            long_counts)
        # uint8 and int16 inputs should give the same counts
        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
        self.assertEqual(count_uint8, count_int16)
        # test minlength functionality
        int_counts = torch.bincount(
            torch.tensor([1, 1, 1, 1], device=device), minlength=5)
        self.assertEqual(
            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
            int_counts)
        # test weights
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device),
            torch.tensor([.1, .2, .3, .4, .5], device=device))
        self.assertEqual(
            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
        byte_counts = torch.bincount(
            torch.tensor([0, 1, 1, 1, 4], device=device),
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
        self.assertEqual(
            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts)
        # test non-contiguous inputs and weights
        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
        for i in [0, 1]:
            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
        # inputs are non-contiguous but weights are contiguous
        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
        # inputs and weights are non-contiguous
        self.assertEqual(
            inputs[:, 1].bincount(weights[:, 1]),
            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
        # weights are non-contiguous but inputs are contiguous
        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))

        # test bincount on non-contiguous slices
        all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))

        all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))

        # test large number of bins - global memory use
        big_exp = torch.zeros(10000000, device=device)
        big_exp[-1] = 50.0
        big_w = torch.tensor([.5] * 100, device=device)
        big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w)
        self.assertEqual(big_exp, big_out)

        # test large input size
        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
        big_exp[1] = 1000000
        big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount()
        self.assertEqual(big_exp, big_out)
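
    # Editor's illustration (assumed sketch): bincount counts occurrences of
    # each non-negative integer; with weights it sums the weights instead.
    def _example_bincount(self):
        x = torch.tensor([0, 1, 1, 3])
        assert x.bincount().tolist() == [1, 2, 0, 1]
        w = torch.tensor([0.5, 1.0, 1.0, 2.0])
        assert x.bincount(w).tolist() == [0.5, 2.0, 0.0, 2.0]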

    def test_var_stability2(self, device):
        tensor = torch.FloatTensor([2281.5, 2281.25]).to(device)

        self.assertEqual(tensor.var(0), 0.03125)
        self.assertEqual(tensor.var(), 0.03125)

        tensor = tensor.unsqueeze(1)
        self.assertEqual(tensor.var(0), 0.03125)

    @dtypes(torch.bool, torch.double)
    def test_sum_all(self, device, dtype) -> None:
        def check_sum_all(tensor: torch.Tensor) -> None:
            pylist = tensor.reshape(-1).tolist()
            self.assertEqual(tensor.sum(), sum(pylist))

        if dtype != torch.bool:
            check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device))
            check_sum_all(torch.randn(200000, dtype=dtype, device=device))
            check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0])
        else:
            check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device))

    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
                                            memory_format, compare_data=True, default_is_preserve=False):

        assert memory_format == torch.channels_last or memory_format == torch.channels_last_3d

        # xc is a channels-last tensor; slicing breaks both contiguity flags
        xc = input_generator_fn(device)
        if memory_format == torch.channels_last:
            xc = xc[..., ::2, ::2]
        else:
            xc = xc[..., ::2, ::2, ::2]

        clone = transformation_fn(xc, memory_format=torch.preserve_format)
        self.assertFalse(clone.is_contiguous())
        self.assertTrue(clone.is_contiguous(memory_format=memory_format))
        self.assertFalse(xc.is_contiguous())
        self.assertFalse(xc.is_contiguous(memory_format=memory_format))
        if compare_data:
            self.assertEqual(xc, clone.to(xc))

        xc = input_generator_fn(device)
        clone = transformation_fn(xc, memory_format=torch.contiguous_format)
        self.assertTrue(clone.is_contiguous())
        self.assertFalse(clone.is_contiguous(memory_format=memory_format))
        if compare_data:
            self.assertEqual(xc, clone.to(xc))

        xc = input_generator_fn(device)
        clone = transformation_fn(xc)

        if default_is_preserve:
            self.assertFalse(clone.is_contiguous())
            self.assertTrue(clone.is_contiguous(memory_format=memory_format))
        else:
            self.assertTrue(clone.is_contiguous())
            self.assertFalse(clone.is_contiguous(memory_format=memory_format))
        if compare_data:
            self.assertEqual(xc, clone.to(xc))

        x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
        for _ in range(10):
            permutation = list(range(len(x.shape)))
            random.shuffle(permutation)
            x = x.permute(permutation)
            self.assertEqual(x.stride(), transformation_fn(x, memory_format=torch.preserve_format).stride())

    @dtypes(torch.double)
    def test_sum_out(self, device, dtype: torch.dtype) -> None:
        x = torch.rand(100, 100, dtype=dtype, device=device)
        res1 = torch.sum(x, 1)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.sum(x, 1, out=res2)
        self.assertEqual(res1, res2)
        x = torch.rand(100, 100, 100, dtype=dtype, device=device)
        res1 = x.sum(2).sum(1)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.sum(x, (2, 1), out=res2)
        self.assertEqual(res1, res2)

    @onlyCUDA
    @dtypes(torch.float16, torch.float32)
    def test_prod_gpu(self, device, dtype):
        x = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)

        # check all combinations of input and output dtype
        for dtype_output in [torch.float16, torch.float32]:
            result_expected = torch.tensor(2592, dtype=dtype_output, device=device)
            output = torch.prod(x, dtype=dtype_output)
            self.assertEqual(output, result_expected)

            output = x.prod(dtype=dtype_output)
            self.assertEqual(output, result_expected)

    @dtypes(torch.float)
    def test_prod(self, device, dtype):
        x = torch.rand(100, 100, dtype=dtype, device=device)
        res1 = torch.prod(x, 1)
        res2 = torch.tensor((), dtype=dtype, device=device)
        torch.prod(x, 1, out=res2)
        self.assertEqual(res1, res2)

    @dtypes(torch.float16, torch.bfloat16)
    def test_prod_lowp(self, device, dtype):
        x = torch.rand(100, 100, dtype=dtype, device=device)
        x_ref = x.float()
        res1 = torch.prod(x, 1)
        res2 = torch.prod(x_ref, 1)
        self.assertEqual(res1, res2.to(dtype=dtype))
        res1 = torch.prod(x, 0)
        res2 = torch.prod(x_ref, 0)
        self.assertEqual(res1, res2.to(dtype=dtype))

    def test_prod_bool(self, device):
        vals = [[True, True], [True, False], [False, False], []]
        for val in vals:
            result = torch.prod(torch.tensor(val, device=device), dtype=torch.bool).item()
            expect = np.prod(np.array(val), dtype=bool)
            self.assertEqual(result, expect)

            result = torch.prod(torch.tensor(val, device=device)).item()
            expect = np.prod(np.array(val))
            self.assertEqual(result, expect)

    @onlyCPU
    def test_max_mixed_devices(self, device):
        a = torch.randn(10, device=device)
        if torch.cuda.is_available():
            values = torch.randn(10).cuda()
            indices = torch.cuda.LongTensor()
            self.assertRaises(RuntimeError,
                              lambda: torch.max(a, 0, out=(values, indices)))
            self.assertRaises(RuntimeError,
                              lambda: torch.amax(a, 0, out=values))

    @onlyCPU
    def test_min_mixed_devices(self, device):
        a = torch.randn(10, device=device)
        if torch.cuda.is_available():
            values = torch.randn(10).cuda()
            indices = torch.cuda.LongTensor()
            self.assertRaises(RuntimeError,
                              lambda: torch.min(a, 0, out=(values, indices)))
            self.assertRaises(RuntimeError,
                              lambda: torch.amin(a, 0, out=values))
def test_bucketization(self, device):
1492
values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device)
1493
values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1496
boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
1497
expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device)
1498
output = torch.empty(2, 2, 3, device=device, dtype=torch.int64)
1499
self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result)
1500
self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result)
1501
expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
1502
self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result)
1503
self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result)
1506
for dtype in [torch.float32, torch.float16]:
1507
values_1d_float = values_1d.to(dtype)
1508
boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=dtype)
1509
expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
1510
self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result)
1511
self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)
1514
boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64)
1515
values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64)
1516
expected_result = values_0_el.to(torch.int64)
1517
self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result)
1518
self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result)
1521
values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64)
1522
boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64)
1523
expected_result = torch.tensor([1, 4, 2, 4], device=device)
1524
self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result)
1525
expected_result = torch.tensor([2, 4, 3, 4], device=device)
1526
        self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result)
        self.assertEqual(torch.searchsorted(boundaries, values_nan, side='right'), expected_result)

        # Non-contiguous input
        values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32)
        boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64)
        expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device)
        if self.device_type != 'xla':
            self.assertWarnsRegex(
                UserWarning, "tensor is non-contiguous",
                lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result))
        else:
            # All tensors in XLA are contiguous
            self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)

        # Scalar type
        boundaries = torch.tensor([1.5, 2.5, 3.5], device=device)
        expected_result = torch.tensor(1, device=device)
        self.assertEqual(torch.searchsorted(boundaries, 2), expected_result)
        self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result)
        expected_result = torch.tensor(3, device=device)
        scalar_tensor_nan = torch.tensor(float('nan'), device=device)
        self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result)
        self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result)

        # Invalid input dimensions
        boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device)
        with self.assertRaisesRegex(
                RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"):
            torch.searchsorted(boundaries, values_3d)
        with self.assertRaisesRegex(
                RuntimeError, "boundaries tensor must be 1 dimension"):
            torch.bucketize(values_3d, boundaries)
        with self.assertRaisesRegex(
                RuntimeError, "only when boundaries tensor dimension is 1"):
            torch.searchsorted(boundaries, 1)

        # Incompatible output tensor dtype
        def test_output_dtype(dtype, is_int32):
            output = values_1d.to(dtype)
            with self.assertRaisesRegex(
                    RuntimeError, "output tensor's dtype is wrong"):
                torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32)

        test_output_dtype(torch.float32, False)
        test_output_dtype(torch.int32, False)
        test_output_dtype(torch.int64, True)

        # Invalid side argument
        with self.assertRaisesRegex(RuntimeError, "side can only be 'left' or 'right'"):
            torch.searchsorted(values_1d, values_1d, side='bad')

        # Invalid sorter argument, wrong size
        with self.assertRaisesRegex(RuntimeError, "boundary and sorter must have the same size"):
            sequence = torch.rand_like(values_1d, dtype=torch.float)
            _, sorted_idx = torch.sort(sequence)
            torch.searchsorted(sequence, values_1d, sorter=sorted_idx[:-1])

        # Invalid sorter argument, is not long dtype
        with self.assertRaisesRegex(RuntimeError, "sorter must be a tensor of long dtype"):
            sequence = torch.rand_like(values_1d, dtype=torch.float)
            _, sorted_idx = torch.sort(sequence)
            torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))

        # Invalid sorter value, out of bound (>= innermost size)
        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))

        # Invalid sorter value, out of bound (< 0)
        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))

        # Scalar type bfloat16
        if self.device_type == 'cpu':
            def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):
                values_1d_float = values_1d.to(torch.float32)
                boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32)
                if values_bf16:
                    values_1d_float = values_1d_float.to(torch.bfloat16)
                if boundaries_bf16:
                    boundaries = boundaries.to(torch.bfloat16)
                expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
                self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)

            test_dtype_bfloat16(True, False)
            test_dtype_bfloat16(False, True)
            test_dtype_bfloat16(True, True)
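
    # Illustrative sketch (not part of the test suite, hypothetical values): the basic
    # searchsorted/bucketize semantics the cases above exercise. side='left' (the
    # default) returns the first valid insertion point, side='right' the last, and
    # bucketize(v, b) is searchsorted(b, v) with the argument order flipped.
    def _demo_searchsorted_semantics():
        import torch
        boundaries = torch.tensor([1., 3., 5., 7., 9.])
        v = torch.tensor([3.])
        assert torch.searchsorted(boundaries, v).item() == 1                 # before the existing 3.
        assert torch.searchsorted(boundaries, v, side='right').item() == 2   # after the existing 3.
        assert torch.bucketize(v, boundaries).item() == 1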

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_nansum(self, device, dtype):
        args = product(
            (True, False),  # noncontiguous
            (0, 1, None),   # dim
        )
        zero = torch.zeros((), device=device, dtype=dtype)

        for noncontiguous, dim in args:
            # Randomly scale the values
            scale = random.randint(10, 100)
            x = make_tensor((17, 17), device=device, dtype=dtype,
                            low=-scale, high=scale, noncontiguous=noncontiguous)

            if dtype.is_floating_point:
                nan_mask = x < 0.2 * scale
                x_nonan = torch.where(nan_mask, zero, x)
                x[nan_mask] = np.nan
            else:
                x_nonan = x.clone()

            dim_kwargs = {} if dim is None else {"dim": dim}
            expect = torch.sum(x_nonan, **dim_kwargs)
            actual = torch.nansum(x, **dim_kwargs)
            self.assertEqual(expect, actual)
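
    # Illustrative sketch (not part of the test suite, hypothetical values): nansum
    # treats NaN entries as zero while sum propagates them, which is exactly the
    # equivalence the loop above checks via the x/x_nonan pair.
    def _demo_nansum():
        import torch
        x = torch.tensor([1.0, float('nan'), 2.0])
        assert torch.isnan(x.sum())        # NaN propagates through sum
        assert x.nansum().item() == 3.0    # nansum skips the NaN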

    def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype,
                                            with_extremal=False, atol=None, rtol=None,
                                            exact_dtype=True, with_keepdim=False):
        # Test 0-d to 3-d tensors.
        for ndims in range(0, 4):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            for n in range(ndims + 1):
                for c in combinations(list(range(ndims)), n):
                    for count_dim in permutations(c):
                        # Generate Input.
                        x = _generate_input(shape, dtype, device, with_extremal)

                        if count_dim == ():
                            # Default `dims=None` case
                            self.compare_with_numpy(torch_func, np_func, x, device=None, dtype=None,
                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)
                        else:
                            # With `dims: tuple of ints` case
                            if with_keepdim:
                                torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim)
                                np_func_partial = partial(np_func, keepdims=True, axis=count_dim)
                            else:
                                torch_func_partial = partial(torch_func, dim=count_dim)
                                np_func_partial = partial(np_func, axis=count_dim)
                            self.compare_with_numpy(torch_func_partial, np_func_partial, x, device=None, dtype=None,
                                                    atol=atol, rtol=rtol, exact_dtype=exact_dtype)

    @dtypes(*all_types_and_complex_and(torch.half))
    def test_count_nonzero(self, device, dtype):
        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype)
        self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True)

    def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False):
        def is_integral(dtype):
            return dtype in integral_types()

        # On Windows CI, NumPy promotes the lower integer dtypes differently than
        # torch does, so the exact-dtype check is skipped there.
        if IS_WINDOWS and is_integral(dtype):
            exact_dtype = False
        else:
            exact_dtype = True

        if dtype == torch.uint8:
            with self.assertRaises(TypeError):
                self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, with_extremal=with_extremal)
        else:
            # Reduced-precision float dtypes need looser tolerances against numpy.
            if dtype == torch.float16:
                atol = 0.4
                rtol = 1e-2
            elif dtype == torch.float32:
                atol = 7e-05
                rtol = 3e-06
            else:
                # Default values
                atol = None
                rtol = None
            self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
                                                     atol=atol, rtol=rtol, exact_dtype=exact_dtype,
                                                     with_keepdim=with_keepdim, with_extremal=with_extremal)

    @onlyNativeDeviceTypes
    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
    def test_sum_vs_numpy(self, device, dtype):
        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype)
        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True)
        self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True)

    @onlyNativeDeviceTypes
    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
    def test_nansum_vs_numpy(self, device, dtype):
        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype)
        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True)
        self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_keepdim=True)

    @dtypes(*complex_types())
    def test_nansum_complex(self, device, dtype):
        x = torch.randn((3, 3, 3), device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, "nansum does not support complex inputs"):
            torch.nansum(x)

    @dtypes(*all_types_and(torch.half))
    def test_nansum_out_dtype(self, device, dtype):
        out_dtype = dtype
        inp_dtypes = all_types_and(torch.half) if out_dtype.is_floating_point else integral_types()
        for inp_dtype in inp_dtypes:
            shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10)
            x = _generate_input(shape, inp_dtype, device, with_extremal=False)
            torch_fn = partial(torch.nansum, dtype=out_dtype)
            np_out_dtype = torch_to_numpy_dtype_dict[out_dtype]
            np_fn = partial(np.nansum, dtype=np_out_dtype)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

    @dtypes(*all_types_and(torch.half))
    def test_argminmax_multiple(self, device, dtype):
        # Case: All Ones
        t = torch.ones(3, 3, device=device, dtype=dtype)
        self.compare_with_numpy(torch.argmax, np.argmax, t)
        self.compare_with_numpy(torch.argmin, np.argmin, t)

        # Case: With single `nan` present.
        if dtype in floating_types_and(torch.half, torch.bfloat16):
            t[2, 2] = float('nan')
            self.compare_with_numpy(torch.argmax, np.argmax, t)
            self.compare_with_numpy(torch.argmin, np.argmin, t)

        # Case: Randomly Generated Tensors
        for ndims in range(1, 5):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            for with_extremal in [False, True]:
                for contiguous in [False, True]:
                    # Generate Input.
                    x = _generate_input(shape, dtype, device, with_extremal)

                    if dtype == torch.half:
                        max_val = torch.max(x.to(torch.float))
                        min_val = torch.min(x.to(torch.float))
                    else:
                        max_val = torch.max(x)
                        min_val = torch.min(x)

                    mask = torch.randn(x.shape) > 0.5
                    x[mask] = torch.tensor(max_val + 1, dtype=dtype)

                    mask = torch.randn(x.shape) > 0.5
                    x[mask] = torch.tensor(min_val - 1, dtype=dtype)

                    if not contiguous:
                        x = x.T

                    self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None)
                    self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None)

                    # Verify indices returned by max and min.
                    if dtype != torch.half:
                        rand_dim = random.randint(0, ndims - 1)
                        self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1],
                                                lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None)
                        self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1],
                                                lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None)

        def verify_against_numpy(t):
            # Argmax
            torch_fn = partial(torch.argmax, dim=1)
            np_fn = partial(np.argmax, axis=1)
            self.compare_with_numpy(torch_fn, np_fn, t)
            # Non-contiguous input
            self.compare_with_numpy(torch_fn, np_fn, t.T)

            # Verify indices returned by max.
            if dtype != torch.half:
                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None)
                self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)

            # Argmin
            torch_fn = partial(torch.argmin, dim=1)
            np_fn = partial(np.argmin, axis=1)
            self.compare_with_numpy(torch_fn, np_fn, t)
            # Non-contiguous input
            self.compare_with_numpy(torch_fn, np_fn, t.T)

            # Verify indices returned by min.
            if dtype != torch.half:
                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None)
                self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)

        # Regression samples with repeated extrema
        t = torch.tensor([[1, 5],
                          [3, 3]], device=device, dtype=dtype)
        verify_against_numpy(t)

        t = torch.tensor([[1, 5],
                          [0, 0]], device=device, dtype=dtype)
        verify_against_numpy(t)
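
    # Illustrative sketch (not part of the test suite, hypothetical values): like
    # NumPy, torch.argmax/argmin report the position of a NaN when one is present,
    # which is what the single-`nan` comparison above relies on.
    def _demo_argminmax_nan():
        import torch
        t = torch.tensor([1.0, float('nan'), 3.0])
        assert t.argmax().item() == 1  # NaN wins the max comparison
        assert t.argmin().item() == 1  # ...and the min comparison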

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
    def test_all_any_vs_numpy(self, device, dtype):
        # Note [all, any uint8 compatibility]: numpy returns uint8 from np.all/np.any
        # for a uint8 input, while torch returns bool for every other dtype, so the
        # dtype comparison is relaxed in the uint8 case.
        exact_dtype = True if dtype != torch.uint8 else False

        def _test_all_any(x):
            self.compare_with_numpy(torch.all, np.all, x)
            self.compare_with_numpy(torch.any, np.any, x)

        def _test_all_any_with_dim(x, dim):
            torch_fn = partial(torch.all, dim=dim)
            np_fn = partial(np.all, axis=dim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

            torch_fn = partial(torch.any, dim=dim)
            np_fn = partial(np.any, axis=dim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

        def _test_out_variant(x, dim):
            out = torch.empty_like(x)
            if dtype == torch.bool or dtype == torch.uint8:
                expected = torch.all(x, dim)
                torch.all(x, dim, out=out)
                self.assertEqual(expected, out)

                expected = torch.any(x, dim)
                torch.any(x, dim, out=out)
                self.assertEqual(expected, out)
            else:
                with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"):
                    torch.all(x, dim, out=out)

                with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"):
                    torch.any(x, dim, out=out)

        def _test_all_any_with_dim_keepdim(x, dim, keepdim):
            torch_fn = partial(torch.all, dim=dim, keepdim=keepdim)
            np_fn = partial(np.all, axis=dim, keepdims=keepdim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

            torch_fn = partial(torch.any, dim=dim, keepdim=keepdim)
            np_fn = partial(np.any, axis=dim, keepdims=keepdim)
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)

        def _test_output_dtype(x):
            # See Note [all, any uint8 compatibility]
            expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool
            self.assertEqual(torch.all(x).dtype, expected_dtype)
            self.assertEqual(torch.any(x).dtype, expected_dtype)

            self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype)
            self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype)

        for ndim in range(5):
            shape = _rand_shape(ndim, 1, 5)
            x = _generate_input(shape, dtype, device, with_extremal=False)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = _generate_input(shape, dtype, device, with_extremal=True)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = torch.zeros_like(x)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])

            x = torch.ones_like(x)
            _test_all_any(x)
            _test_all_any(x.T)
            _test_all_any(x[..., ::2])
            _test_output_dtype(x)
            for dim in range(ndim):
                x = _generate_input(shape, dtype, device, with_extremal=False)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = _generate_input(shape, dtype, device, with_extremal=True)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = torch.zeros_like(x)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)

                x = torch.ones_like(x)
                _test_all_any_with_dim(x, dim)
                _test_all_any_with_dim(x.T, dim)
                _test_all_any_with_dim(x[..., ::2], dim)
                _test_out_variant(x, dim)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
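
    # Illustrative sketch (not part of the test suite, hypothetical values): for
    # backward compatibility all/any on a uint8 tensor return uint8 rather than
    # bool, which is why exact_dtype is relaxed for uint8 above.
    def _demo_all_any_dtype():
        import torch
        assert torch.tensor([True, False]).any().dtype == torch.bool
        assert torch.tensor([1, 0], dtype=torch.uint8).any().dtype == torch.uint8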

    @onlyNativeDeviceTypes
    def test_repeated_dim(self, device):
        ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var,
               torch.norm]
        x = torch.randn(3, 3, 3, 3, device=device)

        error_msg = r'appears multiple times in the list of dims'
        for op in ops:
            for dim in [(0, 0), (0, -4)]:
                with self.assertRaisesRegex(RuntimeError, error_msg):
                    op(x, dim=dim)

    def test_var(self, device):
        cpu_tensor = torch.randn(2, 3, 3)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())
        self.assertEqual(device_tensor.var(1), cpu_tensor.var(1))
        self.assertEqual(device_tensor.var(2), cpu_tensor.var(2))
        self.assertEqual(device_tensor.std(), cpu_tensor.std())
        self.assertEqual(device_tensor.std(1), cpu_tensor.std(1))
        self.assertEqual(device_tensor.std(2), cpu_tensor.std(2))

        cpu_tensor = torch.randn(100)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())

    def test_var_large_input(self, device):
        # Large, not-nice input
        cpu_tensor = torch.randn(2 * 32 * 1024 + 1, 2, 67)
        device_tensor = cpu_tensor.to(device)

        self.assertEqual(cpu_tensor.var(2), device_tensor.var(2))

    @onlyCUDA
    @dtypes(torch.double)
    def test_sum_noncontig(self, device, dtype):
        x = torch.randn(1, 75, 57, 20, dtype=dtype, device=device).permute(0, 3, 1, 2)
        y = x.cpu()
        self.assertEqual(x.sum().cpu(), y.sum())
        self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2)))
        self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3)))

    @onlyCUDA
    def test_min_max_nan(self, device):
        tests = [(lambda x: x.min(), 'min'),
                 (lambda x: x.max(), 'max'),
                 (lambda x: x.amin(), 'amin'),
                 (lambda x: x.amax(), 'amax'),
                 (lambda x: x.min(0).values, 'min_dim'),
                 (lambda x: x.max(0).values, 'max_dim'),
                 (lambda x: x.amin(0), 'amin_dim'),
                 (lambda x: x.amax(0), 'amax_dim')]
        for f, name in tests:
            a = torch.arange(25.0).view(5, 5)
            a[2, 2] = float('nan')
            actual = f(a.to(device)).cpu()
            expected = f(a).cpu()
            self.assertEqual(torch.isnan(actual), torch.isnan(expected), msg=f'nans for {name}')
            self.assertEqual(actual[~torch.isnan(actual)],
                             expected[~torch.isnan(expected)], msg=f'nans for {name}')

    def test_sum_cpu_device_mismatch(self, device):
        x = torch.randn(20, dtype=torch.float32, device=device)
        y = torch.randn(1, dtype=torch.float32)

        err_string = f"Expected out tensor to have device {device}, but got cpu instead"

        with self.assertRaisesRegex(RuntimeError, err_string):
            torch.sum(x, dim=[0], dtype=torch.float32, out=y)

        # tests half to float promotion on CUDA
        if self.device_type == 'cuda':
            x = x.half()
            with self.assertRaisesRegex(RuntimeError, err_string):
                torch.sum(x, dim=[0], dtype=torch.float32, out=y)

    @onlyNativeDeviceTypes
    def test_minmax_illegal_dtype(self, device):
        x = torch.randn(5, 5, dtype=torch.float32, device=device)
        valid_values = torch.empty(5, dtype=torch.float32, device=device)
        valid_indices = torch.empty(5, dtype=torch.long, device=device)
        illegal_values = torch.empty(5, dtype=torch.int, device=device)
        illegal_indices = torch.empty(5, dtype=torch.double, device=device)
        torch.max(x, dim=0, out=(valid_values, valid_indices))
        torch.min(x, dim=0, out=(valid_values, valid_indices))
        torch.amax(x, dim=0, out=valid_values)
        torch.amin(x, dim=0, out=valid_values)
        rmsg = r'scalar type|dtype'
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(illegal_values, valid_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(illegal_values, valid_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(valid_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(valid_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.max(x, dim=0, out=(illegal_values, illegal_indices))
        with self.assertRaisesRegex(RuntimeError, rmsg):
            torch.min(x, dim=0, out=(illegal_values, illegal_indices))

    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_dim_arg_reduction_scalar(self, device, dtype):
        example = 4.0  # any scalar value works; only the (trivial) index matters

        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.argmax().item(), 0)
        self.assertEqual(x.argmax(dim=None).item(), 0)
        self.assertEqual(x.argmax(dim=0).item(), 0)
        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))

        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.argmin().item(), 0)
        self.assertEqual(x.argmin(dim=None).item(), 0)
        self.assertEqual(x.argmin(dim=0).item(), 0)
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))

    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
    def test_dim_reduction(self, device, dtype):
        example = [[-1, 2, 1], [5, 3, 6]]

        sum_dtype = {
            torch.bfloat16: torch.bfloat16,
            torch.double: torch.double,
            torch.float: torch.float,
            torch.half: torch.half,
            torch.int64: torch.int64,
            torch.int32: torch.int64,
            torch.int16: torch.int64,
            torch.int8: torch.int64
        }

        # sum
        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.sum().item(), 16)
        self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=sum_dtype[dtype]))
        self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=sum_dtype[dtype]))
        y = torch.tensor(example, device=device, dtype=sum_dtype[dtype])
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0), y)

        # mean
        if dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
            x = torch.tensor(example, device=device, dtype=dtype)
            self.assertEqual(x.mean().item(), 16.0 / 6)
            self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype))
            self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype))
            self.assertEqual(x.mean(), x.mean((0, 1)))

        prod_dtype = {
            torch.bfloat16: torch.bfloat16,
            torch.double: torch.double,
            torch.float: torch.float,
            torch.float16: torch.float16,
            torch.int64: torch.int64,
            torch.int32: torch.int64,
            torch.int16: torch.int64,
            torch.int8: torch.int64,
        }

        # prod
        if not (self.device_type == 'cpu' and dtype in [torch.float16, torch.bfloat16]):
            x = torch.tensor(example, device=device, dtype=dtype)
            self.assertEqual(x.prod().item(), -180)
            self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=prod_dtype[dtype]))
            self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=prod_dtype[dtype]))

        # min
        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.min().item(), -1)
        self.assertEqual(x.argmin().item(), 0)

        self.assertEqual(x.amin().item(), -1)

        self.assertEqual(x.argmin(dim=None).item(), 0)

        self.assertEqual(x.min(0), (torch.tensor([-1, 2, 1], dtype=dtype),
                                    torch.tensor([0, 0, 0], dtype=torch.int64)))
        self.assertEqual(x.amin(0), torch.tensor([-1, 2, 1], dtype=dtype))
        self.assertEqual(x.argmin(0), torch.tensor([0, 0, 0], dtype=torch.int64))

        self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype),
                                                      torch.tensor([[0, 0, 0]], dtype=torch.int64)))
        self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype))
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64))

        self.assertEqual(x.min(1), (torch.tensor([-1, 3], dtype=dtype),
                                    torch.tensor([0, 1], dtype=torch.int64)))
        self.assertEqual(x.amin(1), torch.tensor([-1, 3], dtype=dtype))
        self.assertEqual(x.argmin(1), torch.tensor([0, 1], dtype=torch.int64))

        self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype),
                                                      torch.tensor([[0], [1]], dtype=torch.int64)))
        self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype))
        self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64))

        # test that non-contiguous tensors work
        self.assertEqual(x[:, :2].min().item(), -1)
        self.assertEqual(x[:, :2].amin().item(), -1)
        self.assertEqual(x[:, :2].argmin().item(), 0)

        # max
        x = torch.tensor(example, device=device, dtype=dtype)
        self.assertEqual(x.max().item(), 6)
        self.assertEqual(x.amax().item(), 6)
        self.assertEqual(x.argmax().item(), 5)

        self.assertEqual(x.max(0), (torch.tensor([5, 3, 6], dtype=dtype),
                                    torch.tensor([1, 1, 1], dtype=torch.int64)))
        self.assertEqual(x.amax(0), torch.tensor([5, 3, 6], dtype=dtype))
        self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64))

        self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype),
                                                      torch.tensor([[1, 1, 1]], dtype=torch.int64)))
        self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype))
        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64))

        self.assertEqual(x.max(1), (torch.tensor([2, 6], dtype=dtype),
                                    torch.tensor([1, 2], dtype=torch.int64)))
        self.assertEqual(x.amax(1), torch.tensor([2, 6], dtype=dtype))
        self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64))

        self.assertEqual(x.max(1, keepdim=True), (torch.tensor([[2], [6]], dtype=dtype),
                                                  torch.tensor([[1], [2]], dtype=torch.int64)))
        self.assertEqual(x.amax(1, keepdim=True), torch.tensor([[2], [6]], dtype=dtype))
        self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64))

        # test that non-contiguous tensors work
        self.assertEqual(x[:, :2].max().item(), 5)
        self.assertEqual(x[:, :2].amax().item(), 5)
        self.assertEqual(x[:, :2].argmax().item(), 2)

        dim_red_fns = [
            "mean", "median", "nanmedian", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min", "amax", "amin"]

        def normfn_attr(t, dim, keepdim=False, out=None):
            attr = torch.norm
            return attr(t, 2, dim, keepdim, out=out)

        for fn_name in dim_red_fns:
            fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr

            def fn(x, dim, keepdim=False, out=None):
                ans = fn_attr(x, dim, keepdim=keepdim, out=out)
                return ans if not isinstance(ans, tuple) else ans[0]

            def fn_tuple(x, dim, keepdim=False, out=None):
                return fn_attr(x, dim, keepdim=keepdim, out=out)

            def test_multidim(x, dim):
                self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
                self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
                self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())

            # general case
            x = torch.randn(3, 4, 5, device=device)
            dim = random.randint(0, 2)
            test_multidim(x, dim)

            # check 1-d behavior
            x = torch.randn(1, device=device)
            dim = 0
            self.assertEqual(fn(x, dim).shape, ())
            self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))

            # check reducing of a singleton dimension
            dims = [3, 4, 5]
            singleton_dim = random.randint(0, 2)
            dims[singleton_dim] = 1
            x = torch.randn(dims, device=device)
            test_multidim(x, singleton_dim)

            # check reducing with output kwargs
            if fn_name in ['median', 'nanmedian', 'mode', 'max', 'min']:
                y = torch.randn(5, 3, device=device)
                values = torch.randn(5, 3, device=device)
                indices = torch.zeros(5, 3, device=device).long() - 1
                fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
                values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
                self.assertEqual(values[:, 1], values_expected,
                                 msg=f'{fn_name} values with out= kwarg')
                self.assertEqual(indices[:, 1], indices_expected,
                                 msg=f'{fn_name} indices with out= kwarg')
                continue

            x = torch.randn(5, 3, device=device)
            y = torch.randn(5, 3, device=device)
            fn(y, 1, keepdim=False, out=x[:, 1])
            expected = fn(y, 1, keepdim=False)
            self.assertEqual(x[:, 1], expected, msg=f'{fn_name} with out= kwarg')

    @onlyCUDA
    @largeTensorTest('10GB')
    def test_reduction_split(self, device):
        # A reduction that is split across multiple kernel launches should match
        # the result of summing the slices directly.
        input_ = torch.randn(5, 14400, 14400, device=device)
        result = input_.sum(dim=0)
        expect = input_[0] + input_[1] + input_[2] + input_[3] + input_[4]
        self.assertEqual(result, expect)

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_input_corner(self, device, dtype):
        # 1D case: sum
        size = 1024 * 1024 * 64 + 3
        shift = 1
        x = torch.zeros(size, dtype=dtype, device=device)
        y = x[shift:]
        for i in range(100):
            x.zero_()
            x[i] = 1
            self.assertEqual(x.sum(), 1.0)
            if i < shift:
                self.assertEqual(y.sum(), 0.0)
            else:
                self.assertEqual(y.sum(), 1.0)
        for i in range(1, 100):
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.sum(), 1.0)
            self.assertEqual(y.sum(), 1.0)
        # 1D case: argmax
        size = 1024 * 1024 * 64 + 3
        shift = 1
        ysize = size - shift
        x = torch.zeros(size, dtype=dtype, device=device)
        y = x[shift:]
        for i in range(100):
            x.zero_()
            x[i] = 1
            self.assertEqual(x.argmax().item(), i)
            if i >= shift:
                self.assertEqual(y.argmax().item(), i - shift)
        for i in range(1, 100):
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.argmax().item(), size - i)
            self.assertEqual(y.argmax().item(), ysize - i)
        # 2D case: sum
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = j
            xs = x.sum(dim=-1)
            for j in range(7):
                self.assertEqual(xs[j].item(), float(j))
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][-i] = j
            xs = x.sum(dim=-1)
            for j in range(7):
                self.assertEqual(xs[j].item(), float(j))
        # 2D case: argmax
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = j + 1
            xs1 = x.argmax(dim=-1)
            xs2 = x.max(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), i)
                self.assertEqual(xs2[j].item(), i)
        for i in range(1, 100):
            x.zero_()
            for j in range(7):
                x[j][-i] = j + 1
            xs1 = x.argmax(dim=-1)
            xs2 = x.max(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), size[1] - i)
                self.assertEqual(xs2[j].item(), size[1] - i)
        # 2D case: argmin
        size = (7, 1024 * 1024 + 3)
        x = torch.zeros(size, dtype=dtype, device=device)
        for i in range(100):
            x.zero_()
            for j in range(7):
                x[j][i] = -(j + 1)
            xs1 = x.argmin(dim=-1)
            xs2 = x.min(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), i)
                self.assertEqual(xs2[j].item(), i)
        for i in range(1, 100):
            x.zero_()
            for j in range(7):
                x[j][-i] = -(j + 1)
            xs1 = x.argmin(dim=-1)
            xs2 = x.min(dim=-1).indices
            for j in range(7):
                self.assertEqual(xs1[j].item(), size[1] - i)
                self.assertEqual(xs2[j].item(), size[1] - i)

    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_output(self, device, dtype):
        def run_test(input_):
            M, N = input_.shape
            input_.zero_()
            for i in range(min(M, N)):
                input_[i][i] = 1
            output1 = input_.argmax(dim=0)
            output2 = input_.sum(dim=0)
            for i in range(min(M, N)):
                self.assertEqual(output1[i], i)
                self.assertEqual(output2[i], 1)
        # vec 4
        run_test(torch.zeros(64, 64, dtype=dtype, device=device))
        # vec 2
        run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64))
        run_test(torch.zeros(64, 62, dtype=dtype, device=device))
        run_test(torch.zeros(64, 2, dtype=dtype, device=device))
        # vec 1
        run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64))
        run_test(torch.zeros(64, 61, dtype=dtype, device=device))
        run_test(torch.zeros(64, 1, dtype=dtype, device=device))

    @onlyCUDA
    def test_argminmax_large_axis(self, device):
        # Regression test for very large axes
        x = torch.zeros(2**31, device=device, dtype=torch.int8)
        x[-1] = 1
        self.assertEqual(x.argmax(0), x.shape[0] - 1)
        self.assertEqual(x.max(0).indices, x.shape[0] - 1)
        x[-1] = -1
        self.assertEqual(x.argmin(0), x.shape[0] - 1)
        self.assertEqual(x.min(0).indices, x.shape[0] - 1)

    def test_argminmax_axis_with_dim_one(self, device):
        # A dim of size 1 is what matters here; the length n is arbitrary.
        n = 32768
        x = torch.zeros(1, n)
        self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))

        self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
        self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))

    @dtypes(torch.int, torch.long, torch.float, torch.double)
    @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double)
    def test_median_real_values(self, device, dtype):
        # Generate random 0-3D sizes
        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
        for size in sizes:
            # Create random input tensor
            t = torch.randn(size, device=device).type(dtype)
            t_numpy = t.cpu().numpy()
            res = t.median()
            self.assertEqual(res, t.nanmedian())
            k = int((t.numel() - 1) / 2)
            self.assertEqual(res, t.view(-1).sort()[0][k])
            if t.numel() % 2 == 1:
                # We can only test against numpy for odd reductions because numpy
                # averages the two middle elements, which torch doesn't do
                self.assertEqual(res.cpu().numpy(), np.median(t_numpy))
            for dim in range(t.ndim):
                res = t.median(dim, True)
                self.assertEqual(res, t.nanmedian(dim, True))
                size = t.size(dim) if t.ndim > 0 else 1
                k = int((size - 1) / 2)
                self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim))
                self.assertEqual(res[0], t.gather(dim, res[1]))
                if size % 2 == 1:
                    # We can only test against numpy for odd reductions because numpy
                    # averages the two middle elements, which torch doesn't do
                    self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False)

    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_median_nan_values(self, device, dtype):
        # Generate random 0-3D sizes
        sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)]
        for size in sizes:
            # Create random input tensor with nan values
            t = torch.rand(size, device=device, dtype=dtype)
            t.masked_fill_(t < 0.1, float('nan'))
            t_numpy = t.cpu().numpy()
            for op in [torch.median, torch.nanmedian]:
                numpy_op = np.median if op == torch.median else np.nanmedian
                res = op(t)
                num_nan = t.isnan().sum()
                if op == torch.median and num_nan > 0:
                    k = t.numel() - 1
                else:
                    k = int((t.numel() - num_nan - 1) / 2)
                self.assertEqual(res, t.view(-1).sort()[0][k])
                if (t.numel() - num_nan) % 2 == 1:
                    # We can only test against numpy for odd reductions because numpy
                    # averages the two middle elements, which torch doesn't do
                    self.assertEqual(res.item(), numpy_op(t.cpu().numpy()))
                for dim in range(t.ndim):
                    res = op(t, dim, True)
                    size = t.size(dim) if t.ndim > 0 else 1
                    num_nan = t.isnan().sum(dim, True)
                    if op == torch.median:
                        k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2))
                    else:
                        k = ((size - num_nan - 1) / 2).type(torch.long)
                    self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k))
                    self.assertEqual(res[0], t.gather(dim, res[1]))
                    # We can only test against numpy for odd reductions because numpy
                    # averages the two middle elements, which torch doesn't do
                    mask = (size - num_nan) % 2 == 1
                    res = res[0].masked_select(mask).cpu()
                    ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()]
                    self.assertEqual(res, torch.from_numpy(ref))

    def test_median_corner_cases(self, device):
        def check(op, a, args, key):
            t = torch.tensor(a, device=device)
            res = op(t, *args)
            if not args:
                key = torch.tensor(key, device=device)
            else:
                if len(key) == 1:
                    key = torch.tensor(key[0], device=device)
                    res = res[0]
                else:
                    key = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device))
            self.assertEqual(res, key)

        nan = float('nan')
        check(torch.median, nan, [], nan)
        check(torch.median, [], [], nan)
        check(torch.nanmedian, nan, [], nan)
        check(torch.median, nan, [0], [nan, 0])
        check(torch.nanmedian, nan, [0], [nan, 0])
        check(torch.median, [nan], [0, True], [[nan], [0]])
        check(torch.nanmedian, [nan], [0, True], [[nan], [0]])

        # Indices are not deterministic here so can only check values
        check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]])
        check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])

        # Discontiguous and dim
        a = torch.arange(12, device=device)
        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))

        a = a.reshape(3, 4)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))

        a = a.reshape(2, 3, 2)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
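
    # Illustrative sketch (not part of the test suite, hypothetical values): median
    # propagates NaN while nanmedian skips it, and for an even number of elements
    # torch returns the lower of the two middle values instead of averaging them the
    # way numpy does -- both behaviors are exercised by the corner cases above.
    def _demo_median_vs_nanmedian():
        import torch
        t = torch.tensor([1.0, float('nan'), 2.0, 3.0])
        assert torch.isnan(t.median())                                 # any NaN makes median NaN
        assert t.nanmedian().item() == 2.0                             # NaN skipped; middle of [1, 2, 3]
        assert torch.tensor([1., 2., 3., 4.]).median().item() == 2.0   # even count -> lower middle value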

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_quantile(self, device, dtype):
        # Compare the quantile ops against numpy on random and edge-case inputs
        ops = ['quantile', 'nanquantile']
        inputs = [tuple(np.random.randint(2, 10, size=i)) for i in range(1, 4)]
        quantiles = [tuple(np.random.rand(i)) for i in range(0, 5)]
        keepdims = [True, False]

        # Add corner cases
        inputs.extend([0.75, (1,), (1, 1), (1, 2, 1)])
        inputs.extend([[float('nan')], [[float('nan'), float('nan')], [1, 2]]])
        inputs.extend([[[float('nan'), float('nan')], [float('nan'), 2]]])
        quantiles.extend([0.5, [0., 1.], np.random.rand(10)])

        # Enumerate all input combinations
        for op, x, q, keepdim in product(ops, inputs, quantiles, keepdims):
            if type(x) is tuple:
                a = torch.randn(x, dtype=dtype, device=device)
                # Make some random elements NaN
                a.masked_fill_(torch.randint_like(a, 20) == 0, float('nan'))
            else:
                a = torch.tensor(x, dtype=dtype, device=device)

            q = torch.tensor(q, dtype=dtype, device=device)

            torch_op = getattr(torch, op)
            numpy_op = getattr(np, op)

            # Compute quantile along every dimension and flattened tensor
            interpolations = ('linear', 'lower', 'higher', 'midpoint', 'nearest')
            for interpolation, dim in product(interpolations,
                                              [None] + list(range(a.ndim))):
                result = torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation)
                expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim,
                                    interpolation=interpolation, keepdims=keepdim)
                self.assertEqual(result.cpu(), torch.from_numpy(np.array(expected)).type(result.type()))

                # Test out variation
                out = torch.empty_like(result)
                torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out)
                self.assertEqual(out.cpu(), result.cpu())

    def test_quantile_backward(self, device):
        def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)):
            for op in ops:
                t = torch.tensor(a, device=device, requires_grad=True)
                op(t, torch.tensor(q, device=device), dim).sum().backward()
                self.assertEqual(t.grad, expected_grad)

        check([1., 2, 3], 0.5, 0, [0, 1, 0])
        check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0])
        check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5])
        check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25])
        check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]])
        check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]])
        check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile])
        check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile])
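
    # Illustrative sketch (not part of the test suite, hypothetical values): with the
    # default linear interpolation the quantile is a convex combination of the two
    # neighboring sorted elements, so the incoming gradient is split with the same
    # weights -- e.g. the 0.5 quantile of four elements sends 0.5 to each middle one,
    # matching the expected_grad values checked above.
    def _demo_quantile_grad():
        import torch
        t = torch.tensor([1., 2., 3., 4.], requires_grad=True)
        torch.quantile(t, 0.5).backward()
        assert t.grad.tolist() == [0., 0.5, 0.5, 0.]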

    def test_quantile_error(self, device):
        def check(a, q, args, kwargs, message):
            with self.assertRaisesRegex(RuntimeError, r'quantile\(\) ' + message):
                at = torch.tensor(a, device=device)
                qt = torch.tensor(q, device=device) if isinstance(q, list) else q
                torch.quantile(at, qt, *args, **kwargs)

        check([], 0.5, [], {}, r'input tensor must be non-empty')
        check([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor')
        check([1], 0.5, [], {}, r'input tensor must be either float or double dtype')
        check([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor')
        check([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1')
        check([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1')
        check([1.], 0.5, [], {'out': torch.empty([], dtype=torch.int32, device=device)},
              r'out tensor must be same dtype as the input tensor')
        check([1.], [1.], [None, False], {'interpolation': 'random_mode'},
              r"interpolation must be one of linear, lower, higher, midpoint or nearest, but got random_mode")

        if self.device_type == "cpu":
            check([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]')

        if self.device_type == "cuda":
            with self.assertRaisesRegex(
                    RuntimeError, r'quantile\(\) q tensor must be on the same device as the input tensor'):
                torch.randn(1, device=device).quantile(torch.tensor(0.5))
            with self.assertRaisesRegex(
                    RuntimeError, r'quantile\(\) out tensor must be on the same device as the input tensor'):
                torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1))

    def test_std_mean(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for dim in range(x.dim()):
            for unbiased in [False, True]:
                for keepdim in [False, True]:
                    std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                    std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
                    self.assertEqual(std1, std2)
                    self.assertEqual(mean1, mean2)

    def test_std_mean_all_dims(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for unbiased in [False, True]:
            std1, mean1 = torch.std_mean(x, unbiased=unbiased)
            std2 = x.std(unbiased=unbiased)
            mean2 = x.mean()
            self.assertEqual(std1, std2)
            self.assertEqual(mean1, mean2)

    def test_var_mean(self, device):
        x = torch.rand(100, 300, 50, device=device)
        for dim in range(x.dim()):
            for unbiased in [False, True]:
                for keepdim in [False, True]:
                    var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                    var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
                    self.assertEqual(var1, var2)
                    self.assertEqual(mean1, mean2)

    def test_var_mean_all_dims(self, device):
        x = torch.rand(100, 50, 20, device=device)
        for unbiased in [False, True]:
            var1, mean1 = torch.var_mean(x, unbiased=unbiased)
            var2 = x.var(unbiased=unbiased)
            mean2 = x.mean()
            self.assertEqual(var1, var2)
            self.assertEqual(mean1, mean2)

    def test_std_mean_some_dims(self, device):
        sizes = (4, 6, 7, 5, 3)
        dims = len(sizes)
        x = torch.rand(sizes, device=device)
        for num_of_dims in range(2, dims):
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
            for dim in dim_list:
                for unbiased in [False, True]:
                    for keepdim in [False, True]:
                        std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                        std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
                        mean2 = x.mean(dim=dim, keepdim=keepdim)
                        self.assertEqual(std1, std2)
                        self.assertEqual(mean1, mean2)

    def _compare_std_var_with_numpy(self, op, device, dtype, input, dim,
                                    keepdim, unbiased, use_out):
        a = input.cpu().numpy() if input.dtype is not torch.bfloat16 else input.float().cpu().numpy()
        numpy_kwargs = {
            'axis' : dim,
            'keepdims' : keepdim,
            'ddof' : 1 if unbiased else 0,
        }

        if dim is None:
            del numpy_kwargs['axis']
            del numpy_kwargs['keepdims']

        if op == 'var':
            torch_op = torch.var
            numpy_op = np.var
        elif op == 'std':
            torch_op = torch.std
            numpy_op = np.std
        else:
            self.fail("Unknown op!")

        numpy_result = numpy_op(a, **numpy_kwargs)

        if dim is None and use_out is False:
            torch_result = torch_op(input, unbiased)
        elif dim is not None and use_out is False:
            torch_result = torch_op(input, dim, unbiased, keepdim)
        elif dim is not None and use_out is True:
            out = torch.empty(0, device=device, dtype=dtype)
            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
        else:
            out = torch.empty(0, device=device, dtype=dtype)
            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)

        exact_dtype = input.dtype not in (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128)
        self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_vs_numpy(self, device, dtype):
        _size = (20, 20)

        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
                                 (None, 0, 1),
                                 (False, True),
                                 (False, True),
                                 (False, True)):
            self._compare_std_var_with_numpy('var', device, dtype, *test_case)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_vs_numpy(self, device, dtype):
        _size = (20, 20)

        for test_case in product((torch.randn(_size, device=device, dtype=dtype),),
                                 (None, 0, 1),
                                 (False, True),
                                 (False, True),
                                 (False, True)):
            self._compare_std_var_with_numpy('std', device, dtype, *test_case)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_correction_vs_numpy(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
            if correction is None:
                # NumPy default is not compatible with torch's default
                numpy_kwargs['ddof'] = 1

            numpy_res = np.asarray(np.var(array, **numpy_kwargs))
            torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_correction_vs_numpy(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
            if correction is None:
                # NumPy default is not compatible with torch's default
                numpy_kwargs['ddof'] = 1

            numpy_res = np.asarray(np.std(array, **numpy_kwargs))
            torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_mean_correction(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            std1 = torch.std(tensor, **kwargs)
            if dim is not None:
                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
            else:
                mean1 = torch.mean(tensor)
                if keepdim:
                    mean1 = mean1.reshape((1,) * tensor.ndim)
            std2, mean2 = torch.std_mean(tensor, **kwargs)

            self.assertEqual(std1, std2)
            self.assertEqual(mean1, mean2)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_mean_correction(self, device, dtype):
        _size = (20, 20)
        test_args = [
            *product(
                # dim
                (None, 0, 1),
                # correction
                (None, 0, 10, 30),
                # keepdim
                (False, True),
            ),
            [None, -100, True],  # negative correction
        ]

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            var1 = torch.var(tensor, **kwargs)
            if dim is not None:
                mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
            else:
                mean1 = torch.mean(tensor)
                if keepdim:
                    mean1 = mean1.reshape((1,) * tensor.ndim)
            var2, mean2 = torch.var_mean(tensor, **kwargs)

            self.assertEqual(var1, var2)
            self.assertEqual(mean1, mean2)
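
    # Illustrative sketch (not part of the test suite, hypothetical values): the
    # correction argument generalizes unbiased -- var divides by (N - correction),
    # so correction=0 matches unbiased=False and correction=1 matches unbiased=True,
    # which is the relationship the correction tests above lean on.
    def _demo_var_correction():
        import torch
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        assert torch.isclose(x.var(correction=0), x.var(unbiased=False))
        assert torch.isclose(x.var(correction=1), x.var(unbiased=True))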

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_warn_invalid_degrees_of_freedom(self, device, dtype):
        def _assert_warning(_func, _tensor, _correction):
            with warnings.catch_warnings(record=True) as w:
                _func(_tensor, dim=-1, correction=_correction)
            self.assertIn('degrees of freedom is <= 0', str(w[0].message))

        correction = 20  # correction equal to the reduced dim size gives dof == 0
        size = (10, correction)
        tensor = make_tensor(size, dtype=dtype, device=device)
        for f in [torch.std, torch.var, torch.var_mean, torch.std_mean]:
            _assert_warning(f, tensor, correction)

    def test_amin_amax_some_dims(self, device):
        sizes = (4, 6, 7, 5, 3)
        dims = len(sizes)
        x = torch.rand(sizes, device=device)
        for num_of_dims in range(2, dims):
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
            for dim in dim_list:
                for keepdim in [False, True]:
                    amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
                    amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
                    amin2 = x
                    amax2 = x
                    for i, d in enumerate(dim):
                        if not keepdim:
                            d -= i
                        amin2 = torch.amin(amin2, dim=d, keepdim=keepdim)
                        amax2 = torch.amax(amax2, dim=d, keepdim=keepdim)
                    self.assertEqual(amin1, amin2)
                    self.assertEqual(amax1, amax2)

    def test_histc(self, device):
        # negative nbins throws
        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
        # empty tensor
        actual = torch.histc(torch.tensor([], device=device), min=0, max=3)
        expected = torch.zeros(100, dtype=torch.float, device=device)
        self.assertEqual(expected, actual)

        # without nbins
        actual = torch.histc(
            torch.tensor([2, 5], dtype=torch.float, device=device))
        expected = torch.zeros(100, dtype=torch.float, device=device)
        expected[0] = 1
        expected[99] = 1
        self.assertEqual(expected, actual)
        # tensor with the same element
        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
        self.assertEqual(
            torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
            actual)
        # no element falls between [min, max]
        actual = torch.histc(
            torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
        self.assertEqual(
            torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
            actual)
        # integral bin size
        actual = torch.histc(
            torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
            bins=5, min=1, max=5)
        self.assertEqual(
            torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
            actual)
        # non-integral bin size
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)
        # double input
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
            actual)
        self.assertEqual(actual.dtype, torch.double)
        # float input
        actual = torch.histc(
            torch.tensor([1., 2, 1], dtype=torch.float, device=device),
            bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
            actual)
        self.assertEqual(actual.dtype, torch.float)
        # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar
        actual = torch.histc(
            torch.tensor(0, dtype=torch.float, device=device),
            bins=1, min=0, max=3)
        self.assertEqual(
            torch.tensor([1], dtype=torch.float, device=device),
            actual)
        # tensors with inf; min, max not provided
        with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'):
            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device))
        with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'):
            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device))
        # tensors with inf; min, max provided
        self.assertEqual(
            torch.histc(torch.tensor([float("inf")], dtype=torch.float, device=device),
                        bins=1, min=0, max=3),
            torch.tensor([0], dtype=torch.float, device=device))
        self.assertEqual(
            torch.histc(torch.tensor([1., 2., float("inf")], dtype=torch.float, device=device),
                        bins=4, min=0, max=3),
            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
        # tensor with nan; min, max not provided
        with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
            torch.histc(torch.tensor([float("nan")], dtype=torch.float, device=device))
        # tensor with nan; min, max provided
        self.assertEqual(
            torch.histc(torch.tensor([1., 2., float("nan")], dtype=torch.float, device=device),
                        bins=4, min=0, max=3),
            torch.tensor([0, 1, 1, 0], dtype=torch.float, device=device))
        # tensors with min > max
        with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
            torch.histc(torch.tensor([1., 2., 3.], dtype=torch.float, device=device),
                        bins=4, min=5, max=1)

        # test against numpy.histogram()
        def test_against_np(tensor, bins=100, min=0, max=0):
            if min == 0 and max == 0:
                min = tensor.min().item()
                max = tensor.max().item()
            nparr = tensor.cpu().numpy()
            actual = torch.histc(tensor, bins=bins, min=min, max=max)
            expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0])
            actual_cpu = actual.cpu()
            # NB: numpy returns int64 counts, so compare after converting
            self.assertEqual(actual, expected.to(actual_cpu))

        test_against_np(torch.tensor([1., 2, 1], device=device))
        test_against_np(torch.randn(5000, device=device))

        # Test bins arg
        test_against_np(torch.randn(301, device=device), bins=10)

        # Test truncated range
        test_against_np(torch.randn(201, device=device), min=0.1, max=1)

        noncontig = torch.randn(100, 3, device=device)[:, 2]
        test_against_np(noncontig)

        multidim = torch.randn(3, 5, 7, 2, device=device)
        test_against_np(multidim)

        expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
        test_against_np(expanded)

        linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device)
        test_against_np(linear, bins=20, min=0, max=0.99)

    @onlyCPU
    @dtypes(torch.bfloat16, torch.half)
    def test_histc_lowp(self, device, dtype):
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=dtype, device=device), bins=4, min=0, max=3)
        self.assertEqual(
            torch.tensor([0, 2, 1, 0], dtype=dtype, device=device),
            actual)
        self.assertEqual(actual.dtype, dtype)

    """
    Runs torch.histogram and numpy.histogram on the specified input parameters
    and asserts that their output is equal.
    """
    def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
        def to_np(t):
            if not torch.is_tensor(t):
                return t
            return t.cpu().numpy()

        # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
        def reference_histogram(self, t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
            (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
            return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))

        # Doesn't pass a 'range' kwarg unless necessary because the override of
        # histogram with Tensor bins doesn't accept one.
        if bin_range:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density)
        else:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density)

        (expected_hist, expected_bin_edges) = reference_histogram(self, t, bins, bin_range, weights, density, actual_hist.dtype)

        """
        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        When bin edges are not explicitly defined, histogram uses the linspace operator internally
        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        from numpy.linspace output.
        Issue: https://github.com/pytorch/pytorch/issues/58758
        """
        if not torch.is_tensor(bins):
            self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
            # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument.
            (expected_hist, expected_bin_edges) = reference_histogram(
                self, t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)

        self.assertEqual(actual_hist, expected_hist)
        self.assertEqual(actual_bin_edges, expected_bin_edges)

        # Test passing non-contiguous output tensors
        hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
                               noncontiguous=True)
        bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
                                    noncontiguous=True)

        # Doesn't pass a 'range' kwarg unless necessary because the override of
        # histogram with Tensor bins doesn't accept one.
        if bin_range:
            torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
        else:
            torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))

        self.assertEqual(hist_out, expected_hist)
        self.assertEqual(bin_edges_out, expected_bin_edges)
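
    # Illustrative sketch (not part of the test suite, hypothetical values) of the
    # correspondence the helper above verifies: torch.histogram mirrors
    # numpy.histogram, returning (hist, bin_edges) with bins + 1 edges.
    def _demo_histogram():
        import torch
        t = torch.tensor([0.5, 1.5, 1.6, 3.2])
        hist, edges = torch.histogram(t, bins=4, range=(0., 4.))
        assert hist.tolist() == [1., 2., 0., 1.]
        assert edges.tolist() == [0., 1., 2., 3., 4.]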

    @dtypes(torch.float32)
    def test_histogram(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5))

        for contig, bins_contig, bin_ct, weighted, density, shape in \
                product([True, False], [True, False], range(1, 10), [True, False], [True, False], shapes):
            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
            weights = make_tensor(shape, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig) if weighted else None

            # Tests passing just the bin_ct
            self._test_histogram_numpy(values, bin_ct, None, weights, density)

            # Tests with caller-specified histogram range
            bin_range = sorted((random.uniform(-9, 9), random.uniform(-9, 9)))
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with range min=max
            bin_range[1] = bin_range[0]
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with caller-specified bin edges
            bin_edges = make_tensor(bin_ct + 1, dtype=dtype, device=device, low=-9, high=9).msort()
            # Copy through a (possibly non-contiguous) tensor so the bins tensor
            # matches the requested contiguity
            bin_edges_noncontig = make_tensor(bin_ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
            bin_edges_noncontig.copy_(bin_edges)
            bin_edges = bin_edges_noncontig
            self.assertEqual(bin_edges.is_contiguous(), bins_contig)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input tensor in which all elements are equal
            elt = random.uniform(-9, 9)
            values = make_tensor(shape, dtype=dtype, device=device, low=elt, high=elt, noncontiguous=not contig)
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input equal to bin_edges
            weights = (
                make_tensor(bin_ct + 1, dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
                if weighted
                else None
            )
            self._test_histogram_numpy(bin_edges, bin_edges, None, weights, density)

        # Tests values of default args
        for bin_ct, shape in product(range(1, 10), shapes):
            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9)
            (actual_hist, actual_bin_edges) = torch.histogram(values, bin_ct)
            (expected_hist, expected_bin_edges) = torch.histogram(
                values, bin_ct, range=None, weight=None, density=False)
            self.assertEqual(actual_hist, expected_hist)
            self.assertEqual(actual_bin_edges, expected_bin_edges)

    """
    Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
    and asserts that their output is equal.
    """
    def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
        def to_np(t):
            if type(t) is list:
                return list(map(to_np, t))
            if not torch.is_tensor(t):
                return t
            return t.cpu().numpy()

        # Wrapper around numpy.histogramdd performing conversions between torch tensors and numpy arrays.
        def reference_histogramdd(t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])

            # numpy.histogramdd accepts only (N, D) shapes
            D = np_t.shape[-1]
            N = np.prod(np_t.shape[:-1])
            reshaped_t = np.reshape(np_t, (N, D))
            reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None

            # numpy.histogramdd throws an error for D = 0
            if D == 0:
                return (torch.tensor(float('nan') if density else 0.), [])

            # numpy.histogramdd expects range to be specified as a sequence of D (lo, hi) tuples
            reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)]

            (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins,
                                                     range=reshaped_range, weights=reshaped_wt, density=density)

            return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(t).to(dtype) for t in np_bin_edges])

        (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density)
        (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype)

        D = len(actual_bin_edges)
        self.assertEqual(D, len(expected_bin_edges))

        """
        Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        When bin edges are not explicitly defined, histogram uses the linspace operator internally
        to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        from numpy.linspace output.
        Issue: https://github.com/pytorch/pytorch/issues/58758
        """
        if not torch.is_tensor(bins):
            for dim in range(D):
                self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
            # Calls numpy.histogramdd again, passing torch's actual_bin_edges as the bins argument.
            (expected_hist, expected_bin_edges) = reference_histogramdd(
                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
            self.assertEqual(D, len(expected_bin_edges))

        self.assertEqual(actual_hist, expected_hist)
        for dim in range(D):
            self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])

    @dtypes(torch.float32)
    def test_histogramdd(self, device, dtype):
        shapes = (
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5),
            (7, 7, 7, 7),
            (16, 8, 4, 2),
            (10, 10, 10),
            (7, 0, 3),
            (5, 0),)

        for contig, bins_contig, weighted, density, shape in \
                product([True, False], [True, False], [True, False], [True, False], shapes):
            D = shape[-1]

            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=not contig)
            weights = (
                make_tensor(shape[:-1], dtype=dtype, device=device, low=0, high=9, noncontiguous=not contig)
                if weighted
                else None
            )

            # Tests passing a single bin count
            bin_ct = random.randint(1, 5)
            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)

            # Tests passing a bin count for each dimension
            bin_ct = [random.randint(1, 5) for dim in range(D)]
            self._test_histogramdd_numpy(values, bin_ct, None, weights, density)

            # Tests with caller-specified histogram range
            bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)]
            bin_range = [elt for t in bin_range_tuples for elt in t]
            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with range min=max
            for dim in range(D):
                bin_range[2 * dim + 1] = bin_range[2 * dim]
            self._test_histogramdd_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with caller-specified bin edges
            bin_edges = [make_tensor(ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() for ct in bin_ct]
            # Copy through (possibly non-contiguous) tensors so the bins tensors
            # match the requested contiguity
            bin_edges_noncontig = [
                make_tensor(ct + 1, dtype=dtype, device=device, noncontiguous=not bins_contig)
                for ct in bin_ct]
            for dim in range(D):
                bin_edges_noncontig[dim].copy_(bin_edges[dim])
            bin_edges = bin_edges_noncontig
            for dim in range(D):
                self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig)
            self._test_histogramdd_numpy(values, bin_edges, None, weights, density)
    @dtypes(torch.float32)
    def test_histogram_error_handling(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, 'not implemented for'):
            values = make_tensor((), dtype=torch.int32, device=device)
            torch.histogram(values, 1)

        inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64

        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            bins = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, bins)

        with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            weight = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, 1, weight=weight)

        with self.assertRaisesRegex(RuntimeError, 'input tensor and hist tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            hist = make_tensor((), dtype=inconsistent_dtype, device=device)
            bin_edges = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, 1, out=(hist, bin_edges))

        with self.assertRaisesRegex(RuntimeError, 'input tensor and bin_edges tensor should have the same dtype'):
            values = make_tensor((), dtype=dtype, device=device)
            hist = make_tensor((), dtype=dtype, device=device)
            bin_edges = make_tensor((), dtype=inconsistent_dtype, device=device)
            torch.histogram(values, 1, out=(hist, bin_edges))

        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'):
            t = make_tensor((2, 2), dtype=dtype, device=device)
            torch.histogram(t, t)

        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have at least 1 element'):
            t = make_tensor((0,), dtype=dtype, device=device)
            torch.histogram(t, t)

        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            values = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, -1)

        with self.assertRaisesRegex(RuntimeError, 'if weight tensor is provided it should have the same shape \
as the input tensor excluding its innermost dimension'):
            values = make_tensor((2, 2), dtype=dtype, device=device)
            weight = make_tensor((1,), dtype=dtype, device=device)
            torch.histogram(values, 1, weight=weight)

        with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'):
            values = make_tensor((), dtype=dtype, device=device)
            bin_edges = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, bin_edges, range=(0, 1))

        with self.assertRaisesRegex(RuntimeError, 'min should not exceed max'):
            values = make_tensor((), dtype=dtype, device=device)
            torch.histogram(values, 2, range=(1, 0))

        with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'):
            values = torch.tensor([float("nan")], device=device, dtype=dtype)
            torch.histogram(values, 2)
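
    # Compares value-returning comparison reductions (amax/amin/max/min/median) with
    # numpy on a tensor with a zero-sized dim: reducing the other dims yields empty
    # results, while reducing the zero-sized dim raises since these ops have no identity.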
    def test_tensor_compare_ops_empty(self, device):
        shape = (2, 0, 4)  # reducing dim 1 (size 0) hits the empty-slice paths below
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('amax', torch.amax, np.amax),
            ('amin', torch.amin, np.amin),
            ('max', lambda *args, **kwargs: torch.max(*args, **kwargs).values, np.max),
            ('min', lambda *args, **kwargs: torch.min(*args, **kwargs).values, np.min),
            ('median', lambda *args, **kwargs: torch.median(*args, **kwargs).values, np.median),
        ]

        for name, fn, np_function in test_functions:
            error_msg = f"test function: {name}"
            # Reducing over the size-4 trailing dim leaves an empty (2, 0) result; numpy agrees
            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2),
                             fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1),
                             fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2, keepdims=True),
                             fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1, keepdims=True),
                             fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)

            # Reducing across the zero-sized dim itself has no identity value, so it raises
            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
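
    # The same empty-slice checks for index/kth-value reductions. kthvalue reduces its
    # last dim by default, so only argmax/argmin are expected to raise without a dim.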
    def test_tensor_compare_ops_argmax_argmin_kthvalue_dim_empty(self, device):
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),
            ('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin),
            ('kthvalue', lambda *args, k=1, **kwargs: torch.kthvalue(*args, k=1, **kwargs).values,
             {}, lambda *args, k=1, axis=None, **kwargs: np.partition(*args, k, **kwargs).take(k - 1, axis=axis)),
        ]

        for name, fn, dtype, np_function in test_functions:
            error_msg = f"test function: {name}"
            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(
                np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False
            )

            self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(
                np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False
            )

            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)

            self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
            if name != 'kthvalue':
                self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input))
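
    # Reductions with an identity element (prod -> 1, sum -> 0, norm -> 0) return a tensor
    # full of that identity when reducing across the zero-sized dim; mean/var/std produce
    # nan and logsumexp -inf, matching numpy/scipy.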
    @skipIfNoSciPy
    def test_tensor_reduce_ops_empty(self, device):
        from scipy.special import logsumexp
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        test_functions = [
            ('prod', torch.prod, 1., np.prod),
            ('sum', torch.sum, 0., np.sum),
            ('norm', torch.norm, 0., np.linalg.norm),
            ('mean', torch.mean, nan, np.mean),
            ('var', torch.var, nan, np.var),
            ('std', torch.std, nan, np.std),
            ('logsumexp', torch.logsumexp, -inf, logsumexp),
        ]

        for name, fn, return_value, np_function in test_functions:
            error_msg = f"test function: {name}"
            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg,
                             exact_dtype=False)

            self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg,
                             exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True),
                             msg=error_msg, exact_dtype=False)

            self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True),
                             msg=error_msg, exact_dtype=False)

            # Reducing across the empty dim yields the reduction's identity (or nan/-inf)
            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
            self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True),
                             msg=error_msg)
            self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True),
                             msg=error_msg)

            if name != 'logsumexp':
                # np_input is float64, so cast numpy's results to float32 to match torch
                self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=1, keepdims=True)),
                                 fn(master_input, dim=1, keepdim=True).cpu().numpy(),
                                 msg=error_msg)
                self.assertEqual(np.float32(np_function(np_input, axis=-2, keepdims=True)),
                                 fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
                                 msg=error_msg)

                self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
            else:
                # torch.logsumexp requires a dim argument
                self.assertRaises(TypeError, lambda: fn(master_input))
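
    # any/all over empty slices fall back to their identities: False for any, True for
    # all. uint8 inputs keep uint8 results; all other dtypes return bool.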
    def test_reduction_empty_any_all(self, device):
        shape = (2, 0, 4)
        x = torch.randn(shape, device=device)

        for dtype in all_types_and_complex_and(torch.half, torch.bool):
            # any/all return uint8 for uint8 inputs and bool for everything else
            if dtype == torch.uint8:
                out_dtype = torch.uint8
            else:
                out_dtype = torch.bool

            xb = x.to(dtype)

            # any: identity is False
            self.assertEqual((2, 0), xb.any(2).shape)
            self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape)
            self.assertEqual(torch.zeros((2, 4), device=device, dtype=out_dtype), xb.any(1))
            self.assertEqual(torch.zeros((2, 1, 4), device=device, dtype=out_dtype), xb.any(1, keepdim=True))
            self.assertEqual(torch.zeros((), device=device, dtype=out_dtype), xb.any())

            # all: identity is True
            self.assertEqual((2, 0), xb.all(2).shape)
            self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape)
            self.assertEqual(torch.ones((2, 4), device=device, dtype=out_dtype), xb.all(1))
            self.assertEqual(torch.ones((2, 1, 4), device=device, dtype=out_dtype), xb.all(1, keepdim=True))
            self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all())
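
    # Gradients of reductions must be unaffected by an explicit dtype= (or by upcasting
    # the input): they must match the plain float computation and keep the input's dtype.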
    def test_reduce_dtype(self, device):
        def test_reduction(op, has_no_dim, takes_dtype=True):
            x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device)

            if has_no_dim:
                grad1, = torch.autograd.grad([op(x)], [x])
                grad2, = torch.autograd.grad([op(x, dtype=torch.double)], [x])
                self.assertEqual(grad1, grad2)
                self.assertEqual(grad2.dtype, torch.float)

            gi = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device)
            grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi)
            if takes_dtype:
                grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double())
            else:
                grad2, = torch.autograd.grad([op(x.double(), dim=0)], [x], gi.double())
            self.assertEqual(grad1, grad2)
            self.assertEqual(grad2.dtype, torch.float)

        test_reduction(torch.sum, True)
        test_reduction(torch.prod, True)
        test_reduction(torch.cumsum, False)
        test_reduction(torch.cumprod, False)
        test_reduction(torch.logcumsumexp, False, takes_dtype=False)
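
    # Masked reductions are checked against op.ref (a numpy-based reference), with the
    # op's reduction identity passed explicitly; bfloat16 and int-to-float promotion
    # cases relax exact dtype matching.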
    @ops(reference_masked_ops)
    def test_reference_masked(self, device, dtype, op):
        """Test masked reduction operations on strided-only tensors using
        numpy reductions as reference.
        """

        def to_numpy(input):
            if input.dtype is torch.bfloat16:
                # numpy has no bfloat16; round-trip through float32
                return input.cpu().to(torch.float32).numpy()
            else:
                return input.cpu().numpy()

        samples = op.sample_inputs_func(op, device, dtype, requires_grad=False)
        for sample_input in samples:
            t = sample_input.input
            actual = op(t, *sample_input.args, **sample_input.kwargs)
            exact_dtype = not (t.dtype is torch.bfloat16
                               or (op.promotes_int_to_float and not torch.is_floating_point(t)))
            expected = op.ref(to_numpy(t), *sample_input.args,
                              **dict(
                                  # `identity` is not None for masked reductions
                                  identity=torch.masked._reduction_identity(op.name, t),
                                  **sample_input.kwargs))

            expected = np.asarray(expected)
            if expected.dtype in [np.uint64, np.uint32]:
                # torch has no unsigned 32/64-bit dtypes; compare via int64
                expected = expected.astype(np.int64)

            msg = ("Failed to produce expected results! Input tensor was"
                   f" {t}, torch result is {actual}, and reference result is"
                   f" {expected}.") if t.numel() < 10 else None

            self.assertEqual(actual, expected, msg, exact_dtype=exact_dtype)
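
    # Sanity-checks reductions on tensors whose element count (2**31) exceeds 32-bit
    # signed indexing, for 16-bit dtypes: sum must stay exact, and mean must either
    # match or raise for ComplexHalf.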
    @largeTensorTest("8GB")
    @dtypes(torch.half, torch.chalf, torch.bfloat16)
    def test_reductions_large_half_tensors(self, device, dtype):
        t = torch.ones(2**31, device=device, dtype=dtype)
        # assumption: negate the second half so the exact sum is 0
        t[2**30:] = -1
        expected = torch.tensor(0, device=device, dtype=dtype)
        self.assertEqual(torch.sum(t), expected)

        # mean is not implemented for ComplexHalf
        err_msg = "not implemented for 'ComplexHalf'"
        ctx = self.assertRaisesRegex(
            RuntimeError, err_msg) if dtype is torch.chalf else contextlib.nullcontext()
        with ctx:
            self.assertEqual(torch.mean(t), expected)


instantiate_device_type_tests(TestReductions, globals())

if __name__ == '__main__':
    run_tests()