pytorch

Форк
0
/
test_reductions.py 
3569 строк · 166.6 Кб
1
# Owner(s): ["module: tests"]
2

3
import contextlib
4
import torch
5
import numpy as np
6

7
import math
8
from typing import Dict, List, Sequence
9
import random
10
from functools import partial
11
from itertools import product, combinations, permutations
12
import warnings
13

14
from torch import inf, nan
15
from torch.testing import make_tensor
16
from torch.testing._internal.common_dtype import (
17
    all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and,
18
    integral_types_and, floating_and_complex_types_and, all_types_and, all_types,
19
)
20
from torch.testing._internal.common_utils import (
21
    TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict,
22
    IS_WINDOWS)
23
from torch.testing._internal.common_device_type import (
24
    OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU,
25
    onlyNativeDeviceTypes, onlyCUDA, largeTensorTest, ops, precisionOverride)
26
from torch.testing._internal.common_methods_invocations import (
27
    ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops)
28

29
# TODO: replace with make_tensor
30
def _generate_input(shape, dtype, device, with_extremal):
31
    if shape == ():
32
        x = torch.tensor((), dtype=dtype, device=device)
33
    else:
34
        if dtype.is_floating_point or dtype.is_complex:
35
            # work around torch.randn not being implemented for bfloat16
36
            if dtype == torch.bfloat16:
37
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
38
                x = x.to(torch.bfloat16)
39
            else:
40
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
41
            x[torch.randn(*shape) > 0.5] = 0
42
            if with_extremal and dtype.is_floating_point:
43
                # Use extremal values
44
                x[torch.randn(*shape) > 0.5] = float('nan')
45
                x[torch.randn(*shape) > 0.5] = float('inf')
46
                x[torch.randn(*shape) > 0.5] = float('-inf')
47
            elif with_extremal and dtype.is_complex:
48
                x[torch.randn(*shape) > 0.5] = complex('nan')
49
                x[torch.randn(*shape) > 0.5] = complex('inf')
50
                x[torch.randn(*shape) > 0.5] = complex('-inf')
51
        elif dtype == torch.bool:
52
            x = torch.zeros(shape, dtype=dtype, device=device)
53
            x[torch.randn(*shape) > 0.5] = True
54
        else:
55
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)
56

57
    return x
58

59
# TODO: replace with make_tensor
60
def _rand_shape(dim, min_size, max_size):
61
    shape = []
62
    for i in range(dim):
63
        shape.append(random.randint(min_size, max_size))
64
    return tuple(shape)
65

66
def _reduced_shape(shape, dim=None, keepdim=False):
67
    """Computes the expected reduced shape given dim and keepdim
68

69
    Args:
70
        shape: The shape to reduce
71
        dim : The dimensions to reduce
72
        keepdim: If true, reduced dimensions have size 1 in the reduced shape,
73
            otherwise they are removed from the reduced shape.
74

75
    Returns:
76
        The reduced shape
77
    """
78
    if dim is None:
79
        return [1] * len(shape) if keepdim else []
80

81
    # Wrap negative dims
82
    dim = dim if isinstance(dim, Sequence) else [dim]
83
    dim = {i if i >= 0 else len(shape) + i for i in dim}
84

85
    result = []
86
    for i, size in enumerate(shape):
87
        if i not in dim:
88
            result.append(size)
89
        elif keepdim:
90
            result.append(1)
91

92
    return result
93

94
class TestReductions(TestCase):
95

96
    ###########################################################################
97
    # ReductionOpInfo unit tests
98
    ###########################################################################
99

100
    def _test_dim_keepdim(self, op: ReductionOpInfo, device, *, ndim, **dim_keepdim):
        """Checks the output shape of ``op`` for a random ``ndim``-dimensional
        float input, forwarding ``dim_keepdim`` (dim and/or keepdim) to the op.

        Each input dimension gets a random size drawn from [2, 5).
        """
        shape = torch.randint(2, 5, (ndim,)).tolist()
        inp = make_tensor(shape, dtype=torch.float, device=device)
        # Only the first generated set of extra arguments is exercised
        extra_args, extra_kwargs = next(op.generate_args_kwargs(inp, **dim_keepdim))
        out = op(inp, *extra_args, **dim_keepdim, **extra_kwargs)
        expected_shape = _reduced_shape(shape, **dim_keepdim)
        failure_msg = f"""
        expected output shape to be {expected_shape} but got {list(out.shape)}
        for input shape {shape} and {dim_keepdim}
        """
        self.assertEqual(out.shape, expected_shape, failure_msg)
111

112
    # TODO(@heitorschueroff) combine cases with and without keepdim once
113
    # there's support for a @parametrize decorator.
114

115
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default(self, device, op: ReductionOpInfo):
        """Checks that omitting dim reduces over every dimension (0-D to 2-D inputs)."""
        for rank in (0, 1, 2):
            self._test_dim_keepdim(op, device, ndim=rank)
120

121
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_default_keepdim(self, device, op: ReductionOpInfo):
        """Checks that omitting dim with keepdim=True shrinks every dimension to size 1."""
        for rank in (0, 1, 2):
            self._test_dim_keepdim(op, device, ndim=rank, keepdim=True)
126

127
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none(self, device, op: ReductionOpInfo):
        """Checks that an explicit dim=None reduces over every dimension."""
        for rank in (0, 1, 2):
            self._test_dim_keepdim(op, device, ndim=rank, dim=None)
132

133
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_none_keepdim(self, device, op: ReductionOpInfo):
        """Checks that dim=None with keepdim=True shrinks every dimension to size 1."""
        for rank in (0, 1, 2):
            self._test_dim_keepdim(op, device, ndim=rank, dim=None, keepdim=True)
138

139
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single(self, device, op: ReductionOpInfo):
        """Checks that an integer dim removes exactly that dimension."""
        # (input rank, dim to reduce)
        cases = ((0, 0), (1, 0), (2, -1), (3, 1))
        for rank, reduce_dim in cases:
            self._test_dim_keepdim(op, device, ndim=rank, dim=reduce_dim)
146

147
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_single_keepdim(self, device, op: ReductionOpInfo):
        """Checks that an integer dim with keepdim=True shrinks that dimension to size 1."""
        # (input rank, dim to reduce)
        cases = ((0, 0), (1, 0), (2, -1), (3, 1))
        for rank, reduce_dim in cases:
            self._test_dim_keepdim(op, device, ndim=rank, dim=reduce_dim, keepdim=True)
154

155
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty(self, device, op: ReductionOpInfo):
        """Checks that an empty dim list leaves the shape untouched."""
        for rank in (0, 2):
            self._test_dim_keepdim(op, device, ndim=rank, dim=[])
160

161
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_empty_keepdim(self, device, op: ReductionOpInfo):
        """Checks that an empty dim list with keepdim=True leaves the shape untouched."""
        for rank in (0, 2):
            self._test_dim_keepdim(op, device, ndim=rank, dim=[], keepdim=True)
166

167
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi(self, device, op: ReductionOpInfo):
        """Checks that a list of dims removes each listed dimension."""
        # (input rank, dims to reduce)
        cases = ((1, [0]), (3, [0, 2]))
        for rank, reduce_dims in cases:
            self._test_dim_keepdim(op, device, ndim=rank, dim=reduce_dims)
172

173
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_keepdim(self, device, op: ReductionOpInfo):
        """Checks that a list of dims with keepdim=True shrinks each listed dimension to size 1."""
        # (input rank, dims to reduce)
        cases = ((1, [0]), (3, [0, 2]))
        for rank, reduce_dims in cases:
            self._test_dim_keepdim(op, device, ndim=rank, dim=reduce_dims, keepdim=True)
178

179
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted(self, device, op: ReductionOpInfo):
        """Checks that the order of entries in the dim list does not matter."""
        unsorted_dims = [3, 0, 2]
        self._test_dim_keepdim(op, device, ndim=4, dim=unsorted_dims)
183

184
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo):
        """Checks that the order of entries in the dim list does not matter when keepdim=True."""
        unsorted_dims = [3, 0, 2]
        self._test_dim_keepdim(op, device, ndim=4, dim=unsorted_dims, keepdim=True)
188

189
    @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_duplicate(self, device, op: ReductionOpInfo):
        """Checks that listing the same dimension twice raises a RuntimeError."""
        dims_with_repeat = [0, 1, 1, 2]
        with self.assertRaises(RuntimeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=dims_with_repeat)
194

195
    @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none)
    def test_dim_multi_unsupported(self, device, op: ReductionOpInfo):
        """Checks that ops declared as single-dim only reject a list of dims with a TypeError."""
        several_dims = [0, 2]
        with self.assertRaises(TypeError):
            self._test_dim_keepdim(op, device, ndim=3, dim=several_dims)
200

201
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_offbounds(self, device, op: ReductionOpInfo):
        """Checks that a dim outside the valid range raises an IndexError."""
        rank = 2
        # dim == rank is one past the last valid dimension
        with self.assertRaises(IndexError):
            self._test_dim_keepdim(op, device, ndim=rank, dim=2)
206

207
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_dim_ndim_limit(self, device, op: ReductionOpInfo):
        """Checks that reducing a 65-dimensional tensor along a specific
        dimension raises the "up to 64 dims" RuntimeError. Reducing with
        dim=None is allowed and is not exercised here."""
        too_many_dims = [1] * 65
        inp = make_tensor(too_many_dims, dtype=torch.float, device=device)
        expected_error = "only tensors with up to 64 dims are supported"
        with self.assertRaisesRegex(RuntimeError, expected_error):
            op(inp, dim=0)
214

215
    @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported)
    def test_identity(self, device, dtype, op: ReductionOpInfo):
        """Checks that op.identity really behaves as an identity element:
        interleaving it with the data must not change the reduction."""
        padded = make_tensor((10,), dtype=dtype, device=device)
        # Odd positions hold the identity, even positions keep the data
        padded[1::2] = op.identity
        extra_args, extra_kwargs = next(op.generate_args_kwargs(padded))
        data_only = padded[::2]
        result = op(data_only, *extra_args, **extra_kwargs)
        result_with_identity = op(padded, *extra_args, **extra_kwargs)
        self.assertEqual(result, result_with_identity, """
        Adding identity value to the input tensor should not change the result.
        """)
226

227
    # TODO(@heitorschueroff) Update these to use the nan_policy kwarg once
228
    # it is added to reduction operators.
229

230
    @ops(filter(lambda op: op.nan_policy == 'propagate', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_propagate(self, device, dtype, op: ReductionOpInfo):
        """Checks that a single NaN in the input makes the result NaN."""
        inp = make_tensor((5,), dtype=dtype, device=device)
        inp[2] = torch.nan
        extra_args, extra_kwargs = next(op.generate_args_kwargs(inp))
        out = op(inp, *extra_args, **extra_kwargs)
        self.assertTrue(out.isnan())
239

240
    @ops(filter(lambda op: op.nan_policy == 'omit', reduction_ops), dtypes=OpDTypes.supported,
         allowed_dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16))
    def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo):
        """Checks that interleaving NaNs with the data leaves the result unchanged."""
        padded = make_tensor((10,), dtype=dtype, device=device)
        # Odd positions become NaN, even positions keep the data
        padded[1::2] = torch.nan
        extra_args, extra_kwargs = next(op.generate_args_kwargs(padded))
        data_only = padded[::2]
        result = op(data_only, *extra_args, **extra_kwargs)
        result_with_nan = op(padded, *extra_args, **extra_kwargs)
        self.assertEqual(result, result_with_nan)
250

251
    @ops(reduction_ops, dtypes=OpDTypes.supported)
    def test_result_dtype(self, device, dtype, op: ReductionOpInfo):
        """Checks the dtype of the result against the op's declared promotion rules.

        The rules are tried in this order: integral-to-float promotion,
        integral-to-int64 promotion, a fixed result dtype, complex-to-real
        conversion, and finally "same dtype as the input".
        """
        inp = make_tensor((5,), dtype=dtype, device=device)
        extra_args, extra_kwargs = next(op.generate_args_kwargs(inp))
        out: torch.Tensor = op(inp, *extra_args, **extra_kwargs)
        # bool counts as an integral dtype for promotion purposes
        is_integral = dtype in integral_types_and(torch.bool)

        if op.promotes_int_to_float and is_integral:
            self.assertTrue(torch.is_floating_point(out))
            return

        if op.promotes_int_to_int64 and is_integral:
            self.assertEqual(out.dtype, torch.int64)
            return

        if op.result_dtype is not None:
            self.assertEqual(out.dtype, op.result_dtype)
            return

        if op.complex_to_real:
            real_counterpart = {
                torch.complex128: torch.float64,
                torch.complex64: torch.float32,
                torch.complex32: torch.float16,
            }
            # Non-complex inputs keep their dtype
            self.assertEqual(out.dtype, real_counterpart.get(dtype, dtype))
            return

        self.assertEqual(out.dtype, dtype)
273

274
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo):
        """Tests for consistent behavior when reducing over an empty slice.

        The rules for reducing over an empty slice are as follows:
            - Return the identity value if the operator has one
            - Otherwise, return NaN if the operator promotes integral dtype to
              floating point dtypes.
            - Otherwise, raise an error

        dim=0 is tested for every operator; dim=[0, 2] is added only for
        operators that support multiple dims.

        See discussion here https://github.com/pytorch/pytorch/issues/61901
        """
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        # The conditional expression must be parenthesized: without the
        # parentheses it binds looser than `+`, so the whole list became []
        # for single-dim operators and they were never tested.
        dims = [0] + ([[0, 2]] if op.supports_multiple_dims else [])
        for dim in dims:
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            if op.identity is not None:
                # Reducing along empty slice should return identity
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, op.identity))
            elif op.promotes_int_to_float:
                # Reducing along empty slice should return NaN
                result = op(t, *args, dim=dim, **kwargs)
                self.assertEqual(result, torch.full_like(result, torch.nan))
            else:
                # Reducing along empty slice should raise an error;
                # ref reductions throw RuntimeError, the others IndexError
                expected_error = RuntimeError if isinstance(op, ReductionPythonRefInfo) else IndexError
                with self.assertRaises(expected_error):
                    op(t, *args, dim=dim, **kwargs)
306

307
    @ops(reduction_ops, dtypes=OpDTypes.none)
    def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo):
        """Tests that reducing a nonempty slice of an empty tensor returns an
        empty tensor with the dimensions reduced.

        dim=1 is tested for every operator; dim=[1, 2] is added only for
        operators that support multiple dims.
        """
        t = make_tensor((0, 2, 3), dtype=torch.float, device=device)
        # The conditional expression must be parenthesized: without the
        # parentheses it binds looser than `+`, so the whole list became []
        # for single-dim operators and they were never tested.
        dims = [1] + ([[1, 2]] if op.supports_multiple_dims else [])
        for dim in dims:
            args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
            result = op(t, *args, dim=dim, **kwargs)
            self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
316

317
    def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Compares ``op`` on the noncontiguous tensor ``t`` with ``op`` on a
        contiguous copy of it, for every generated set of extra arguments."""
        assert not t.is_contiguous()

        contiguous_copy = t.contiguous()
        for extra_args, call_kwargs in op.generate_args_kwargs(contiguous_copy, **reduction_kwargs):
            # The reduction kwargs (e.g. dim) take precedence over generated ones
            call_kwargs.update(reduction_kwargs)
            actual = op(t, *extra_args, **call_kwargs)
            reference = op(contiguous_copy, *extra_args, **call_kwargs)
            self.assertEqual(actual, reference)
327

328
    @ops(reduction_ops)
    def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo):
        """Reduces along the last dimension of a view that strides over it."""
        base = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        every_other_column = base[:, ::2]
        self._test_noncontiguous(op, every_other_column, dim=1)
333

334
    @ops(reduction_ops)
    def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo):
        """Reduces along the first dimension of a view that strides over it."""
        base = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1)
        every_other_row = base[::2, :]
        self._test_noncontiguous(op, every_other_row, dim=0)
339

340
    @ops(reduction_ops)
    def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo):
        """Reduces over all dimensions of a view that is strided in every dimension."""
        base = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1)
        strided_view = base[::2, ::3, 1:-1:2]
        self._test_noncontiguous(op, strided_view)
345

346
    @ops(reduction_ops)
    def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo):
        """Reduces over the transpose of a square matrix."""
        base = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1)
        transposed = base.T
        self._test_noncontiguous(op, transposed)
351

352
    @ops(reduction_ops)
    def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo):
        """Reduces a tensor whose inserted singleton dimension was expanded to size 5."""
        base = make_tensor((2, 3), dtype=dtype, device=device, low=-1, high=1)
        # (2, 3) -> (2, 1, 3) -> (2, 5, 3)
        expanded = base.unsqueeze(1).expand(-1, 5, -1)
        self._test_noncontiguous(op, expanded)
357

358
    # NumPy does not support BFloat16 so we don't test that against reference
359
    # implementations. We also don't compare dtypes or test for different
360
    # keepdim because we already have other tests covering those.
361
    # The test_reference_testing in test_ops.py only uses the samples from
362
    # sample_inputs_func which do not test as exhaustively as these tests.
363

364
    def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs):
        """Checks ``op`` against its reference ``op.ref`` on ``t``.

        The reference receives the input as a NumPy array; result dtypes are
        not required to match.
        """
        for extra_args, call_kwargs in op.generate_args_kwargs(t, **reduction_kwargs):
            # The reduction kwargs (e.g. dim) take precedence over generated ones
            call_kwargs.update(reduction_kwargs)
            actual = op(t, *extra_args, **call_kwargs)
            numpy_input = t.detach().cpu().numpy()
            reference = op.ref(numpy_input, *extra_args, **call_kwargs)
            self.assertEqual(actual, reference, exact_dtype=False)
371

372
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on a 0-dimensional input."""
        scalar = make_tensor([], dtype=dtype, device=device)
        self._test_ref(op, scalar)
377

378
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_small_input(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on a small 4-D input, first over
        all dimensions and then over a selection of single (and, where
        supported, multiple) dims."""
        inp = make_tensor((5, 3, 4, 2), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        self._test_ref(op, inp)

        dims_to_check = [0, 1, 3]
        if op.supports_multiple_dims:
            dims_to_check += [[0, 2], [1, 3]]
        for reduce_dim in dims_to_check:
            self._test_ref(op, inp, dim=reduce_dim)
386

387
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_1D(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on a 2**20 element 1-D input to check stability."""
        numel = 2 ** 20
        inp = make_tensor((numel,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
        self._test_ref(op, inp)
392

393
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_2D(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on a (32, 2**16) input reduced
        along the long dimension, to test parallelism."""
        rows, cols = 32, 2 ** 16
        inp = make_tensor((rows, cols), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
        self._test_ref(op, inp, dim=1)
399

400
    @largeTensorTest("8gb")
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float64])
    def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on a very large 1-D input that requires 64 bit indexing."""
        numel = 275000000
        inp = make_tensor((numel,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)
        self._test_ref(op, inp)
406

407
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool))
    def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference on an input that contains repeated values."""
        inp = make_tensor((4, 4), dtype=dtype, device=device, low=-2, high=2, exclude_zero=True)
        # Copy the odd-row/odd-column entries onto the even-row/even-column ones
        inp[::2, ::2] = inp[1::2, 1::2]
        for reduction_kwargs in ({}, {"dim": 0}, {"dim": 1}):
            self._test_ref(op, inp, **reduction_kwargs)
416

417
    @ops(filter(lambda op: op.ref is not None, reduction_ops),
         allowed_dtypes=[torch.float32, torch.complex64])
    def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo):
        """Checks op against its reference when one element is 0, 1, nan, inf or -inf."""
        inp = make_tensor((5,), dtype=dtype, device=device, exclude_zero=True)
        # The same slot is overwritten with each special value in turn
        for special_value in (0, 1, nan, inf, -inf):
            inp[2] = special_value
            self._test_ref(op, inp)
426

427
    ###########################################################################
428
    # TODO: Legacy tests - port to ReductionOpInfo
429
    ###########################################################################
430

431
    def test_var_unbiased(self, device):
        """Checks that var/std default to the unbiased estimator, that the
        dim and no-dim overloads agree on 1-D input, and pins var for two
        small hand-computed examples."""
        samples = torch.randn(100, device=device)
        self.assertEqual(samples.var(0), samples.var(0, unbiased=True))
        self.assertEqual(samples.var(), samples.var(unbiased=True))
        self.assertEqual(samples.var(unbiased=False), samples.var(0, unbiased=False))

        # (data, expected unbiased variance, expected biased variance)
        known_cases = (
            ([1.0, 2.0], 0.5, 0.25),
            ([1.0, 2.0, 3.0], 1.0, 2.0 / 3.0),
        )
        for data, unbiased_var, biased_var in known_cases:
            values = torch.tensor(data, device=device)
            self.assertEqual(values.var(unbiased=True), unbiased_var)
            self.assertEqual(values.var(unbiased=False), biased_var)

        samples = torch.randn(100, device=device)
        self.assertEqual(samples.std(0), samples.std(0, unbiased=True))
        self.assertEqual(samples.std(), samples.std(unbiased=True))
        self.assertEqual(samples.std(unbiased=False), samples.std(0, unbiased=False))
449

450
    def test_var_stability(self, device):
        """Checks var on two large, nearly equal values (difference 0.25),
        both with and without an explicit dim."""
        close_values = torch.tensor([2281.5, 2281.25], device=device)
        expected_var = 0.03125
        self.assertEqual(close_values.var(dim=0), expected_var)
        self.assertEqual(close_values.var(), expected_var)
454

455
    def test_sum_dim_reduction_uint8_overflow(self, device):
        """Checks that sums accumulated in uint8 wrap around modulo 256.

        The expected values below correspond to the -1 entry being stored as 255.
        """
        as_uint8 = partial(torch.tensor, dtype=torch.uint8, device=device)
        example = [[-1, 2, 1], [5, 3, 6]]
        x = as_uint8(example)

        # Full reduction: 255 + 2 + 1 + 5 + 3 + 6 = 272 -> 16
        self.assertEqual(x.sum(dtype=torch.uint8).item(), 16)
        # Column sums: [260, 5, 7] -> [4, 5, 7]
        self.assertEqual(x.sum(0, dtype=torch.uint8), as_uint8([4, 5, 7]))
        # Row sums: [258, 14] -> [2, 14]
        self.assertEqual(x.sum(1, dtype=torch.uint8), as_uint8([2, 14]))

        # The out= variant must agree with the dtype= variant
        out = as_uint8(example)
        torch.sum(x, 0, out=out)
        self.assertEqual(x.sum(0, dtype=torch.uint8), out)
464

465
    def test_dim_reduction_less_than_64(self, device):
        """Checks that reducing a 65-dimensional tensor along a single
        dimension raises the "up to 64 dims" RuntimeError, for both the
        positive (64) and negative (-1) spelling of the last dimension."""
        sizes = [1] * 65
        x = torch.randn(sizes, device=device)
        # Each reduction is listed once (torch.std used to appear twice).
        # The list is not called `ops` so it does not shadow the `ops`
        # decorator imported at module level.
        reductions = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp,
                      torch.var, torch.norm]
        for reduction in reductions:
            for dim in (64, -1):
                with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"):
                    reduction(x, dim=dim)
475

476
    @onlyCPU
    @dtypes(torch.float, torch.bfloat16)
    def test_dim_reduction_lastdim(self, device, dtype):
        """Checks that reducing the last dimension of a strided view gives
        the same result as reducing a contiguous copy of it."""
        base = torch.randn(3, 5, 40, device=device, dtype=dtype)
        strided = base[:, :, 0:40:2]
        contiguous = strided.contiguous()
        for reduction in (torch.norm, torch.argmax, torch.argmin):
            from_strided = reduction(strided, dim=-1)
            from_contiguous = reduction(contiguous, dim=-1)
            self.assertEqual(from_strided, from_contiguous)
487

488
    @skipIfNoSciPy
    def test_logsumexp(self, device):
        """Checks torch.logsumexp against scipy.special.logsumexp for float
        input containing infinities, for the out= variant, and for integral input."""
        from scipy.special import logsumexp as scipy_logsumexp

        values = torch.randn(5, 4, device=device)
        values[0, 0] = inf
        values[1, :] = -inf
        actual = values.logsumexp(1)
        expected = scipy_logsumexp(values.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)

        # check that out is actually inplace: write through a column view
        # and read the result back from the backing tensor
        backing = torch.zeros(5, 2, device=device)
        first_column = backing[:, 0]
        torch.logsumexp(values, 1, out=first_column)
        self.assertEqual(expected, backing[:, 0])

        # check integral inputs is promoted to floating point
        int_values = torch.randint(-100, 100, [5, 4], device=device)
        actual = int_values.logsumexp(1).to(torch.float64)
        expected = scipy_logsumexp(int_values.cpu().numpy(), 1)
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected, actual)
511

512
    @skipIfNoSciPy
    @dtypes(torch.complex64, torch.complex128)
    def test_logcumsumexp_complex(self, device, dtype):
        """Checks torch.logcumsumexp on complex input against a slow
        prefix-by-prefix scipy reference and against hand-written expected
        values for inputs containing infinities and NaNs."""
        # logcumsumexp is a more precise way to compute than ``log(cumsum(exp(a)))``
        # and faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]``
        # the for-loop above should produce similar precision as logcumsumexp (it's just slower),
        # so it can be used as the expected values to check our computation

        # using logsumexp from scipy because by the time of writing this test code,
        # torch.logsumexp has not been implemented for complex numbers
        from scipy.special import logsumexp

        def zero_out_neg_inf(t):
            # For non-finite entries with a negative real part, keep only the real part
            cleaned = t.clone()
            non_finite = ~(torch.isfinite(cleaned))
            negative_real = torch.real(cleaned) < 0
            idx = torch.logical_and(non_finite, negative_real)
            cleaned[idx] = torch.real(cleaned[idx]).to(cleaned.dtype)
            return cleaned

        def standardize_phase(t):
            # Wrap the imaginary part with a modulo of 2 * pi
            wrapped_imag = torch.imag(t) % (2 * np.pi)
            return torch.real(t) + 1j * wrapped_imag

        def logcumsumexp_slow(a, dim):
            # logsumexp over every prefix a[..., :stop, ...] along dim
            prefix_results = []
            for stop in range(1, a.size(dim) + 1):
                index = [slice(None, None, None)] * a.ndim
                index[dim] = slice(None, stop, None)
                prefix = a[tuple(index)]
                prefix_results.append(logsumexp(prefix.cpu().numpy(), axis=dim, keepdims=True))
            return torch.as_tensor(np.concatenate(prefix_results, axis=dim))

        def compare_logcumsumexp(a, expected=None):
            for i in range(a.ndim):
                actual = torch.logcumsumexp(a, dim=i)
                # if the expected is not given, then revert to scipy's logsumexp
                reference = logcumsumexp_slow(a, dim=i) if expected is None else expected

                # move the imaginary values to (0, 2 * pi)
                actual = standardize_phase(actual)
                reference = standardize_phase(reference)

                # zeroing the imaginary part of the element if the real part is -inf
                # as the imaginary part cannot be determined exactly and it does not
                # really matter if we take the exp of the output
                actual = zero_out_neg_inf(actual)
                reference = zero_out_neg_inf(reference)
                self.assertEqual(reference.shape, actual.shape)
                self.assertEqual(reference, actual)

        # randomly specified values
        # in this case, scipy.logsumexp should be enough
        random_input = torch.randn((5, 10), dtype=dtype, device=device)
        compare_logcumsumexp(random_input)

        # test with some non-normal values
        wide_range_input = torch.tensor([1e3 + 0j, 1e-18 + 1e4j, 1e2 + 1e-8j], dtype=dtype, device=device)
        compare_logcumsumexp(wide_range_input)

        # handle special case involving infinites and nans
        # here we don't use scipy.logsumexp as it gives confusing answer on
        # some inf cases
        # NOTE(review): the tensors below are built without dtype/device, so
        # they do not follow the parametrized dtype and device - confirm
        # whether that is intended.
        inf = float('inf')
        nan = float('nan')
        a3_input = torch.tensor([
            -inf + 4j,
            -inf + 1j,
            1.2 + 2.1j,
            1e10 + 1e20j,
            inf + 0j,
            inf + 1j,
            inf + 3j,
            nan + 2j,
        ])
        a3_expected = torch.tensor([
            -inf + 0j,
            -inf + 0j,
            1.2 + 2.1j,
            1e10 + 1e20j,
            inf + 0j,  # scipy's logsumexp gives (inf + 0.7853982j) here, unclear why
            inf + (np.pi / 4) * 1j,  # the imaginary part thanks to some weird behaviour of log(inf + infj)
            complex(inf, nan),
            complex(nan, nan),
        ])

        a4_input = torch.tensor([
            complex(-inf, inf),
            complex(-inf, inf),
            -inf + 1j,
            1.2 + 2.1j,
            complex(2.4, inf),
        ])
        a4_expected = torch.tensor([
            -inf + 0j,
            -inf + 0j,
            -inf + 0j,
            1.2 + 2.1j,
            complex(nan, nan),
        ])

        # windows give strange results on the second-to-last results where it gives inf + pi/4 j
        # instead of inf + nan j
        if not IS_WINDOWS:
            compare_logcumsumexp(a3_input, a3_expected)

        if not IS_WINDOWS:
            compare_logcumsumexp(a4_input, a4_expected)
621

622
    @onlyCPU
623
    def test_sum_parallel(self, device):
        # The parallel reduction branches are only taken for relatively large
        # tensors. Even on a single-core machine this still gives signal on
        # the correctness of the sums.
        parallel_grain_size = 32768

        def _check_against_numpy(shape):
            rank = len(shape)
            # axis == rank stands for the full reduction over every dim
            for axis in range(rank + 1):
                reference = np.round(np.random.rand(*shape))  # 0s and 1s
                candidate = torch.from_numpy(reference)
                # Parallelism is only used if numel is larger than the
                # grain size defined in Parallel.h
                self.assertTrue(candidate.numel() > parallel_grain_size)
                if axis == rank:
                    np_total, torch_total = reference.sum(), candidate.sum()
                else:
                    np_total, torch_total = reference.sum(axis), candidate.sum(axis)
                mismatch = np.abs(np_total - torch_total.numpy()).sum()
                self.assertEqual(mismatch, 0)

        shapes = (
            [2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3],
            [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
            [1, 32 * 8 * 32 * 8],
            [1, 32770],
        )
        for shape in shapes:
            _check_against_numpy(shape)
649

650
    # TODO: kill map2_ (and similar) uses and update to compare with NumPy
651
    # only works on CPU since this uses map2_, which is only supported on CPU
652
    def _testCSelection(self, torchfn, mathfn):
        """Check a binary elementwise op against a Python scalar function.

        ``torchfn`` is applied to two random (100, 100) tensors; the result
        must match exactly what ``mathfn`` yields pairwise through ``map2_``.
        """
        shape = (100, 100)
        lhs = torch.rand(*shape)
        rhs = torch.rand(*shape)
        actual = torchfn(lhs, rhs)

        def _pairwise(_, left, right):
            # map2_ also passes the destination element first; it is ignored
            return mathfn(left, right)

        reference = torch.zeros(*shape)
        reference.map2_(lhs, rhs, _pairwise)
        self.assertEqual(reference, actual, atol=0, rtol=0)
661

662
    @onlyCPU
663
    def test_max_elementwise(self, device):
        # Binary torch.max is compared elementwise with Python's built-in max.
        self._testCSelection(torchfn=torch.max, mathfn=max)
665

666
    @onlyCPU
667
    def test_min_elementwise(self, device):
        # Binary torch.min is compared elementwise with Python's built-in min.
        self._testCSelection(torchfn=torch.min, mathfn=min)
669

670
    def test_all_any(self, device):
        def _expect(tensor, *, all_true, any_true):
            # Assert the truthiness of tensor.all() and tensor.any()
            (self.assertTrue if all_true else self.assertFalse)(tensor.all())
            (self.assertTrue if any_true else self.assertFalse)(tensor.any())

        def _check(shape):
            # uint8 mask
            mask = torch.ones(*shape, device=device).byte()
            _expect(mask, all_true=True, any_true=True)

            mask[3] = 0
            _expect(mask, all_true=False, any_true=True)

            mask.zero_()
            _expect(mask, all_true=False, any_true=False)

            # a non-zero value other than 1 still counts as true
            mask.fill_(2)
            _expect(mask, all_true=True, any_true=True)

            # bool mask
            mask = torch.ones(*shape, device=device).bool()
            _expect(mask, all_true=True, any_true=True)

            mask[3] = False
            _expect(mask, all_true=False, any_true=True)

        for shape in ((10,), (5, 5)):
            _check(shape)
698

699
    def test_all_any_with_dim(self, device):
        def _check(mask):
            # all(dim=0) is compared with the product of the mask along dim 0
            via_prod = mask.prod(dim=0, keepdim=False).byte()
            via_all = mask.all(dim=0, keepdim=False)
            self.assertEqual(via_prod.shape, via_all.shape)
            self.assertTrue((via_prod == via_all).all())

            # any(dim=1) is compared with the sum along dim 1 clamped to [0, 1]
            via_sum = mask.sum(dim=1, keepdim=True).clamp(0, 1).byte()
            via_any = mask.any(dim=1, keepdim=True)
            self.assertEqual(via_sum.shape, via_any.shape)
            self.assertTrue((via_sum == via_any).all())

        rows = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 1],
            [1, 1, 1],
        ]
        _check(torch.tensor(rows, device=device, dtype=torch.uint8))
715

716
    def test_numpy_named_args(self, device):
        # torch.add called with x1/x2 must match the call with input/other
        lhs = torch.randn(10, device=device)
        rhs = torch.randn(10, device=device)
        torch_style = torch.add(input=lhs, other=rhs)
        numpy_style = torch.add(x1=lhs, x2=rhs)
        self.assertEqual(torch_style, numpy_style)

        # sum called with axis/keepdims must match the call with dim/keepdim
        cube = torch.randn(10, 10, 10, device=device)
        torch_style = cube.sum(dim=(0, 2), keepdim=True)
        numpy_style = cube.sum(axis=(0, 2), keepdims=True)
        self.assertEqual(torch_style, numpy_style)
727

728
    # TODO: kill this and replace with common creation ops
729
    def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True,
                      use_complex=False) -> Dict[str, List[torch.Tensor]]:
        """Create random tensors of several dtypes in three memory layouts.

        Args:
            shape: tuple giving the shape of the "cont" and "noncont" tensors.
            val_range: (low, high) bounds used when drawing values.
            use_floating: include torch.double and torch.float.
            use_integral: include torch.int64, torch.int32 and torch.int16.
            use_complex: include torch.complex64 and torch.complex128.

        Returns:
            A dict holding one tensor per selected dtype under each of the
            keys "cont" (contiguous), "noncont" (non-contiguous view) and
            "slice" (a contiguous slice of a ``(1, sum(shape))`` tensor).
        """
        float_types = [torch.double,
                       torch.float]
        int_types = [torch.int64,
                     torch.int32,
                     torch.int16]

        complex_types = [torch.complex64,
                         torch.complex128]

        def make_contiguous(shape, dtype) -> torch.Tensor:
            if dtype in float_types:
                # scale and shift standard-normal samples, then clamp them
                # into val_range
                val = torch.randn(shape, dtype=dtype)
                val = val * ((val_range[1] - val_range[0]) / (math.pi * 2.0))
                val = val + ((val_range[1] - val_range[0]) / 2.0)
                val = torch.clamp(val, min=val_range[0], max=val_range[1])
                return val
            # every other dtype (integral and complex) is filled elementwise
            # with random integers drawn from val_range
            result = torch.zeros(shape, dtype=dtype)
            result.apply_(lambda x: random.randint(val_range[0], val_range[1]))
            return result

        def make_non_contiguous(shape, dtype) -> torch.Tensor:
            contig = make_contiguous(shape, dtype)
            # selecting out of two trailing size-2 dims yields a strided view
            # of the requested shape
            non_contig = torch.empty(shape + (2, 2), dtype=dtype)[..., 0]
            non_contig = non_contig.select(-1, -1)
            non_contig.copy_(contig)
            self.assertFalse(non_contig.is_contiguous())
            return non_contig

        def make_contiguous_slice(size, dtype) -> torch.Tensor:
            contig = make_contiguous((1, size), dtype)
            # dropping the first and last column of the single row
            sliced = contig[:1, 1:size - 1]
            self.assertTrue(sliced.is_contiguous())
            # Return the slice itself: returning `contig` would make the
            # "slice" category just another plain contiguous tensor.
            return sliced

        types = []
        if use_floating:
            types += float_types
        if use_integral:
            types += int_types
        if use_complex:
            types += complex_types
        tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []}
        for dtype in types:
            tensors["cont"].append(make_contiguous(shape, dtype))
            tensors["noncont"].append(make_non_contiguous(shape, dtype))
            tensors["slice"].append(make_contiguous_slice(sum(shape), dtype))

        return tensors
779

780
    # TODO: refactor this to use comparators from common_utils
781
    def _assert_matches_numpy(self, t, n):
        """Assert that ``t`` and ``n`` agree in shape and in value.

        NaNs compare equal. When ``t`` is float32 an explicit rtol/atol pair
        is used; every other dtype relies on ``assertEqual``'s defaults.
        """
        self.assertEqual(n.shape, t.shape)
        tolerances = {"rtol": 1e-03, "atol": 1e-05} if t.dtype == torch.float else {}
        self.assertEqual(n, t, equal_nan=True, **tolerances)
787

788
    # TODO: update this and tests that use it to use the device argument properly
789
    def _test_dim_ops(self, pytorch_op, numpy_op,
                      use_floating=True, use_integral=True, use_complex=False):
        """Compare the dim-reduction ``pytorch_op`` with ``numpy_op``.

        Both callables take ``(data, dim)``. They are run over tensors from
        ``_make_tensors`` for a fixed list of shapes and dims; when CUDA is
        available the PyTorch op is additionally run on a CUDA copy.
        """
        dtype_flags = dict(use_floating=use_floating, use_integral=use_integral,
                           use_complex=use_complex)

        def _compare(tensors_by_layout, dim):
            for layout, candidates in tensors_by_layout.items():
                if layout == "slice":
                    dim = 0
                for candidate in candidates:
                    # we have no control over NumPy warnings...
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        reference = numpy_op(candidate.cpu().numpy(), dim)
                    self._assert_matches_numpy(pytorch_op(candidate, dim), reference)
                    if torch.cuda.is_available():
                        self._assert_matches_numpy(pytorch_op(candidate.cuda(), dim).cpu(), reference)

        cases = (
            ((5, 400000), 1),
            ((3, 5, 7), 0),
            ((3, 5, 7), 1),
            ((3, 5, 7), 2),
            ((100000, ), -1),
            ((50, 50, 50), 0),
            ((50, 50, 50), 1),
            ((50, 50, 50), 2),
            ((50, 50, 50), (1, 2)),
            ((50, 50, 50), (1, -1)),
            ((50, 50, 50), (0, 2)),
            ((50, 50, 50), (0, 2, 1)),
        )
        for shape, dim in cases:
            _compare(self._make_tensors(shape, **dtype_flags), dim)
828

829
    @slowTest
830
    @onlyCPU
831
    def test_sum_dim(self, device):
        # sum over dim(s) vs. ndarray.sum for floating, integral and complex inputs
        def torch_sum(tensor, dim):
            return tensor.sum(dim)

        def numpy_sum(array, dim):
            return array.sum(dim)

        self._test_dim_ops(torch_sum, numpy_sum,
                           use_floating=True, use_integral=True, use_complex=True)
836

837
    @onlyCPU
838
    def test_mean_dim(self, device):
        # mean over dim(s) vs. ndarray.mean; integral inputs are excluded
        def torch_mean(tensor, dim):
            return tensor.mean(dim)

        def numpy_mean(array, dim):
            return array.mean(dim)

        self._test_dim_ops(torch_mean, numpy_mean, use_integral=False, use_complex=True)
844

845
    @onlyCPU
846
    def test_std_dim(self, device):
        for unbiased in (False, True):
            # the NumPy reference uses ddof=1 for the unbiased estimate, 0 otherwise
            ddof = 1 if unbiased else 0

            def torch_std(tensor, dim):
                return tensor.std(dim, unbiased=unbiased)

            def numpy_std(array, dim):
                return array.std(dim, ddof=ddof)

            self._test_dim_ops(torch_std, numpy_std, use_integral=False)
852

853
    @onlyCPU
854
    def test_var_dim(self, device):
        for unbiased in (False, True):
            # the NumPy reference uses ddof=1 for the unbiased estimate, 0 otherwise
            ddof = 1 if unbiased else 0

            def torch_var(tensor, dim):
                return tensor.var(dim, unbiased=unbiased)

            def numpy_var(array, dim):
                return array.var(dim, ddof=ddof)

            self._test_dim_ops(torch_var, numpy_var, use_integral=False)
860

861
    @onlyCPU
862
    @skipIfNoSciPy
863
    def test_logsumexp_dim(self, device):
        # Tensor.logsumexp over dim(s) vs. scipy.special.logsumexp; integral
        # inputs are excluded
        from scipy.special import logsumexp

        def torch_logsumexp(tensor, dim):
            return tensor.logsumexp(dim)

        def scipy_logsumexp(array, dim):
            return logsumexp(array, dim)

        self._test_dim_ops(torch_logsumexp, scipy_logsumexp, use_integral=False)
869

870
    @onlyCPU
871
    def test_mean_int_with_optdtype(self, device):
        # If the optional desired output type is given, the input is
        # internally cast, so the result must equal the mean of a tensor
        # that was cast up front.
        ints = make_tensor((3, 4, 5), dtype=torch.int64, device=device)
        expected = ints.to(torch.float32).mean()
        self.assertEqual(expected, ints.mean(dtype=torch.float32))
878

879
    # TODO: update this and tests that use it to handle device properly
880
    def _test_reduce_integer_upcast(self, fn, has_out=True, test_complex=True):
        """Check the result dtype of reduction ``fn`` over the CPU math dtypes.

        By default a floating or complex input must keep its dtype and any
        other input must produce int64. An explicit ``dtype=`` must be
        honoured. When ``has_out`` is true, ``out=`` is exercised as well and
        combining it with a different ``dtype=`` must raise RuntimeError.
        ``test_complex`` is accepted but not read by this helper.
        """
        shape = (3, 4, 5)
        reduced_shape = fn(torch.ones(shape)).shape

        def _check_out(source, out_dtype, conflicting_dtype):
            out = torch.ones(reduced_shape, dtype=out_dtype)
            # first without dtype=, then with a dtype= that agrees with out
            for extra in ({}, {"dtype": out_dtype}):
                result = fn(source, out=out, **extra)
                self.assertIs(out.dtype, result.dtype)
                self.assertEqual(fn(source.to(out_dtype)), result, exact_dtype=False)
            # 'out' is favored over dtype, check error
            with self.assertRaises(RuntimeError):
                fn(source, out=out, dtype=conflicting_dtype)

        math_dtypes = [candidate for candidate in get_all_math_dtypes('cpu')
                       if candidate != torch.float16]
        for dtype in math_dtypes:
            x = torch.ones(shape, dtype=dtype)
            keeps_dtype = dtype.is_floating_point or dtype.is_complex
            expected_dtype = dtype if keeps_dtype else torch.int64
            self.assertIs(expected_dtype, fn(x).dtype)
            self.assertEqual(fn(x.to(expected_dtype)), fn(x))

            # other_dtype: a different dtype of the same kind;
            # mixed_dtypes: dtypes of the two other kinds (int/float/complex)
            if dtype.is_floating_point:
                other_dtype = torch.float32 if dtype == torch.float64 else torch.float64
                mixed_dtypes = [torch.int32, torch.complex64]
            elif dtype.is_complex:
                other_dtype = torch.complex64 if dtype == torch.complex128 else torch.complex128
                mixed_dtypes = [torch.int32, torch.float32]
            else:
                other_dtype = torch.int16 if dtype == torch.int32 else torch.int32
                mixed_dtypes = [torch.float32, torch.complex64]

            self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype)
            self.assertEqual(fn(x.to(other_dtype)), fn(x, dtype=other_dtype), exact_dtype=False)

            # test mixed int/float/complex
            for mixed_dtype in mixed_dtypes:
                self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype)
                self.assertEqual(fn(x.to(mixed_dtype)), fn(x, dtype=mixed_dtype), exact_dtype=False)

                if has_out:
                    _check_out(x, dtype, other_dtype)
                    _check_out(x, dtype, mixed_dtype)
925

926
    @onlyCPU
927
    def test_sum_integer_upcast(self, device):
        def sum_all(x, **kwargs):
            return torch.sum(x, **kwargs)

        def sum_dim0(x, **kwargs):
            return torch.sum(x, 0, **kwargs)

        # the out= checks are skipped for the full reduction
        self._test_reduce_integer_upcast(sum_all, has_out=False)
        self._test_reduce_integer_upcast(sum_dim0)
930

931
    @onlyCPU
932
    def test_prod_integer_upcast(self, device):
        def prod_all(x, **kwargs):
            return torch.prod(x, **kwargs)

        def prod_dim0(x, **kwargs):
            return torch.prod(x, 0, **kwargs)

        # the out= checks are skipped for the full reduction
        self._test_reduce_integer_upcast(prod_all, has_out=False)
        self._test_reduce_integer_upcast(prod_dim0)
935

936
    @onlyCPU
937
    def test_cumsum_integer_upcast(self, device):
        def cumsum_dim0(x, **kwargs):
            return torch.cumsum(x, 0, **kwargs)

        self._test_reduce_integer_upcast(cumsum_dim0)
939

940
    @onlyCPU
941
    def test_cumprod_integer_upcast(self, device):
        def cumprod_dim0(x, **kwargs):
            return torch.cumprod(x, 0, **kwargs)

        self._test_reduce_integer_upcast(cumprod_dim0)
943

944
    @dtypes(*all_types())
945
    def test_mode(self, device, dtype):
        side = 10
        x = torch.arange(1., side * side + 1, device=device, dtype=dtype).clone().resize_(side, side)
        # the first two rows and the first two columns are all ones
        x[:2] = 1
        x[:, :2] = 1
        snapshot = x.clone()

        # Pre-calculated results.
        expected_values = torch.ones(side, device=device, dtype=dtype)
        # The indices are the position of the last appearance of the mode element.
        expected_indices = torch.ones(side, device=device, dtype=torch.long)
        expected_indices[:2] = side - 1

        def _assert_mode(values, indices):
            self.assertEqual(expected_values, values, atol=0, rtol=0)
            self.assertEqual(expected_indices, indices, atol=0, rtol=0)

        _assert_mode(*torch.mode(x, keepdim=False))

        # Test use of result tensor
        out_values = torch.tensor((), device=device, dtype=dtype)
        out_indices = torch.tensor((), device=device, dtype=torch.long)
        torch.mode(x, keepdim=False, out=(out_values, out_indices))
        _assert_mode(out_values, out_indices)

        # Test non-default dim
        _assert_mode(*torch.mode(x, 0, False))

        # input unchanged
        self.assertEqual(x, snapshot, atol=0, rtol=0)
977

978
    def _test_mode_intervals(self, shape, intervals, device, dtype, v=1):
        """Fill column intervals with ``v`` and check ``torch.mode`` reports it.

        Each row starts as ``arange(shape[1])``. Column ``v`` is overwritten
        with the start of the first interval, then every ``(beg, end)``
        column range in ``intervals`` is set to ``v``.
        """
        base_row = torch.arange(0, shape[1], device=device, dtype=dtype)
        x = base_row.expand(shape).contiguous()
        x[:, v] = intervals[0][0]

        # Set the value of each interval to the mode "v"
        for start, stop in intervals:
            x[:, start:stop] = v

        mode_values, mode_indices = torch.mode(x, -1, False)

        # Check whether the returned indices correspond to the returned values
        gathered = x.gather(1, mode_indices.unsqueeze(1)).t()
        self.assertTrue((gathered == mode_values).all())
        # Check whether the returned values are the mode
        self.assertTrue((mode_values == v).all().item())
993

994
    @onlyCUDA
995
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
996
    def test_mode_large(self, device, dtype):
        # i should be less than (d - 2) / 2
        def _run_for_shape(shape, i):
            d = shape[-1]
            interval_sets = (
                # Mode only in the middle.
                [(i, d - i)],
                # Mode in discontiguous parts of the input.
                [(0, i), (i + 1, d - i - 1), (d - i, d)],
            )
            for intervals in interval_sets:
                self._test_mode_intervals(shape, intervals, device, dtype)

        cases = (
            # More than one line of (65535) thread blocks
            ((65536, 10), 3),
            # Max slice size (2048)
            ((10, 2048), 10),
            # Naive kernel for big slice sizes (> 2048)
            ((10, 4096), 10),
        )
        for shape, i in cases:
            _run_for_shape(shape, i)
1013

1014
    def test_mode_boolean(self, device):
        # torch.mode on bool tensors: the value held by the majority of each
        # row must be returned, together with an index that points at it.
        shapes = [
            (10, 10),
            (4, 2048),
            (1, 4096),
        ]

        def _check(a, expected_value):
            values, indices = a.mode(-1)
            expected = torch.full((a.shape[0],), expected_value, dtype=torch.bool, device=device)
            self.assertEqual(values, expected)
            # the returned index must select an element equal to the mode
            indexed = a.gather(1, indices.unsqueeze(1)).squeeze(1)
            self.assertEqual(values, indexed)

        for shape in shapes:
            a = torch.zeros(shape, device=device, dtype=torch.bool)

            # True from column (n - 1) // 2 onwards: True is the majority
            a[:, (shape[1] - 1) // 2:] = True
            _check(a, True)

            # True from column n // 2 + 1 onwards: False is the majority
            a.fill_(False)
            a[:, shape[1] // 2 + 1:] = True
            _check(a, False)
1038

1039

1040
    @expectedFailureMeta  # mode only supports CPU and CUDA device type
1041
    @onlyNativeDeviceTypes
1042
    def test_mode_wrong_dtype(self, device):
        def _expect_dtype_error(x_ty, v_ty, i_ty, message):
            # input, values-out and indices-out tensors of the given dtypes
            x, v, i = (torch.ones(10, device=device, dtype=ty) for ty in (x_ty, v_ty, i_ty))
            with self.assertRaisesRegex(RuntimeError, message):
                torch.mode(x, -1, True, out=(v, i))

        err_msg = "expected scalar type .* but got .* for "
        values_err = err_msg + "values"
        indices_err = err_msg + "indices"

        # values-out dtype differs from the input dtype: (input, values)
        mismatched_values = (
            (torch.uint8, torch.int8),
            (torch.int8, torch.int16),
            (torch.int32, torch.float32),
            (torch.float32, torch.float64),
        )
        for x_ty, v_ty in mismatched_values:
            _expect_dtype_error(x_ty, v_ty, torch.long, values_err)

        # indices-out dtype is not torch.long: (input and values, indices)
        mismatched_indices = (
            (torch.uint8, torch.int8),
            (torch.int8, torch.int16),
            (torch.int32, torch.float32),
            (torch.float32, torch.float64),
        )
        for x_ty, i_ty in mismatched_indices:
            _expect_dtype_error(x_ty, x_ty, i_ty, indices_err)
1064

1065
    @onlyCUDA
1066
    def test_mode_wrong_device(self, device):
        # CPU Input Tensor
        cpu_input = torch.ones(2)

        # values-out lives on `device` while the input is on the CPU
        values = torch.tensor([], device=device)
        with self.assertRaisesRegex(RuntimeError,
                                    "expected device .* but got .* for values"):
            torch.mode(cpu_input, -1, True, out=(values, torch.tensor([], dtype=torch.long)))

        # indices-out lives on `device` while the input is on the CPU
        indices = torch.tensor([], device=device)
        with self.assertRaisesRegex(RuntimeError,
                                    "expected device .* but got .* for indices"):
            torch.mode(cpu_input, -1, True, out=(torch.tensor([]), indices))
1079

1080
    # TODO: make work on CUDA, too
1081
    @onlyCPU
1082
    def test_accreal_type(self, device) -> None:
        # .item() on a sum must give a Python float for floating dtypes and a
        # Python int for integral dtypes
        x = torch.ones(2, 3, 4)
        expected_item_types = (
            (torch.Tensor.double, float),
            (torch.Tensor.float, float),
            (torch.Tensor.long, int),
            (torch.Tensor.int, int),
            (torch.Tensor.short, int),
            (torch.Tensor.char, int),
            (torch.Tensor.byte, int),
        )
        for convert, py_type in expected_item_types:
            self.assertIsInstance(convert(x).sum().item(), py_type)
1091

1092
    def test_var_mean_some_dims(self, device):
        # the fused var_mean must agree with separate var and mean calls for
        # every combination of 2..ndim-1 reduced dims
        sizes = (4, 6, 7, 5, 3)
        ndim = len(sizes)

        x = torch.rand(sizes, device=device)
        for reduce_count in range(2, ndim):
            for dim in combinations(range(ndim), reduce_count):
                for unbiased, keepdim in product((False, True), repeat=2):
                    fused_var, fused_mean = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
                    separate_var = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
                    separate_mean = x.mean(dim=dim, keepdim=keepdim)
                    self.assertEqual(fused_var, separate_var)
                    self.assertEqual(fused_mean, separate_mean)
1107

1108
    # TODO: this should be a generic opinfo test
1109
    def test_all_any_empty(self, device):
        # on a tensor built from an argument-less legacy constructor,
        # all() must be true and any() must be false
        for legacy_ctor in (torch.ByteTensor, torch.BoolTensor):
            empty = legacy_ctor().to(device)
            self.assertTrue(empty.all())
            self.assertFalse(empty.any())
1117

1118
    def test_all_issue117215(self, device):
        # all(dim) on a uint8 tensor must agree with the same reduction
        # performed on the tensor converted to bool.
        info = torch.iinfo(torch.uint8)
        # randint's upper bound is exclusive, so max + 1 is needed to also
        # sample 255; the input is created on the device under test rather
        # than always on the default device.
        a = torch.randint(info.min, info.max + 1, (73, 11, 3, 17), dtype=torch.uint8, device=device)
        b = torch.all(a, dim=0)
        c = a.to(torch.bool).all(dim=0)
        self.assertEqual(torch.ne(b, c).sum(), 0)
1124

1125
    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1126
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
1127
    def test_max_with_inf(self, device, dtype):
        # every row contains +inf, so each maximum must be +inf
        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
        row_max = torch.max(a, dim=1).values
        self.assertTrue(torch.all(row_max == inf).item())
        row_amax = torch.amax(a, dim=1)
        self.assertTrue(torch.all(row_amax == inf).item())
        # full reductions
        self.assertTrue(torch.max(a).item() == inf)
        self.assertTrue(torch.amax(a).item() == inf)
1133

1134
    @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
1135
    @dtypes(torch.half, torch.float, torch.bfloat16, torch.double)
1136
    def test_min_with_inf(self, device, dtype):
        # every row contains -inf, so each minimum must be -inf
        a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device)
        row_min = torch.min(a, dim=1).values
        self.assertTrue(torch.all(row_min == (-inf)).item())
        row_amin = torch.amin(a, dim=1)
        self.assertTrue(torch.all(row_amin == (-inf)).item())
        # full reductions
        self.assertTrue(torch.min(a).item() == -inf)
        self.assertTrue(torch.amin(a).item() == -inf)
1142

1143
    def _test_minmax_helper(self, torchfn, reffn, device, dtype, skip_indices=False):
        """Compare the min/max-style ``torchfn`` with the NumPy ``reffn``.

        Covers a contiguous and a non-contiguous input and, unless
        ``skip_indices`` is set, the ``(input, dim, keepdim)`` form, including
        a check that returned indices select the returned values. For floating
        dtypes a planted NaN must be returned as the result (and its position
        as the index when indices are returned).
        """
        def _random_input(shape):
            if dtype.is_floating_point:
                return torch.randn(*shape, device=device, dtype=dtype)
            # bool draws from {0, 1}; other dtypes draw from [-1000, 1000)
            low, high = (0, 2) if dtype == torch.bool else (-1000, 1000)
            return torch.randint(low, high, shape, device=device, dtype=dtype)

        def _values_of(result):
            # a tuple result contributes only its first element
            return result[0] if isinstance(result, tuple) else result

        self.compare_with_numpy(torchfn, reffn, _random_input((100, 100)))
        # non contiguous
        self.compare_with_numpy(torchfn, reffn, _random_input((10, 10, 10))[:, 4])

        # indices
        if not skip_indices:
            size = 5
            square = _random_input((size, size))
            for xinp, d in product((square, square.t()), (0, 1)):
                self.compare_with_numpy(lambda t: _values_of(torchfn(t, d, False)),
                                        lambda arr: reffn(arr, d, keepdims=False),
                                        xinp)
                result = torchfn(xinp, d, False)
                if isinstance(result, tuple):
                    v, i = result
                    if d == 1:
                        picked = xinp[torch.arange(size), i]
                    else:
                        picked = xinp[i, torch.arange(size)]
                    self.assertEqual(picked, v, atol=0, rtol=0)

        # nan
        if dtype.is_floating_point:
            for index in (0, 4, 99):
                x = _random_input((100,))
                x[index] = nan
                if not skip_indices:
                    result = torchfn(x, 0)
                    self.assertEqual(_values_of(result), nan)
                    if isinstance(result, tuple):
                        self.assertEqual(result[1], index)
                self.assertEqual(torchfn(x), nan)
1191

1192
    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1193
    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1194
    @dtypes(torch.half, torch.float, torch.double)
1195
    def test_max(self, device, dtype):
        # torch.max checked against np.amax
        self._test_minmax_helper(torchfn=torch.max, reffn=np.amax, device=device, dtype=dtype)
1197

1198
    @dtypesIfCPU(torch.float, torch.double, torch.long, torch.bool, torch.half)
1199
    @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool)
1200
    @dtypes(torch.half, torch.float, torch.double)
1201
    def test_min(self, device, dtype):
        # torch.min checked against np.amin
        self._test_minmax_helper(torchfn=torch.min, reffn=np.amin, device=device, dtype=dtype)
1203

1204
    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1205
    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1206
    @dtypes(torch.half, torch.float, torch.double)
1207
    def test_amin(self, device, dtype):
        # torch.amin checked against np.amin
        self._test_minmax_helper(torchfn=torch.amin, reffn=np.amin, device=device, dtype=dtype)
1209

1210
    @dtypesIfCPU(torch.half, torch.float, torch.double, torch.int, torch.long, torch.bool)
1211
    @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool)
1212
    @dtypes(torch.float, torch.double)
1213
    def test_amax(self, device, dtype):
        # torch.amax checked against np.amax
        self._test_minmax_helper(torchfn=torch.amax, reffn=np.amax, device=device, dtype=dtype)
1215

1216
    @onlyNativeDeviceTypes
1217
    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
1218
    @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
1219
    def test_aminmax(self, device, dtype):
        # Element 0 of torch._aminmax's result is checked against np.amin and
        # element 1 against np.amax; every call must emit the deprecation warning.
        def _deprecated_component(position):
            def _wrapper(x, dim=None, keepdims=False):
                with self.assertWarnsOnceRegex(UserWarning, "_aminmax is deprecated"):
                    if dim is None:
                        pair = torch._aminmax(x)
                    else:
                        pair = torch._aminmax(x, dim, keepdims)
                    return pair[position]
            return _wrapper

        self._test_minmax_helper(_deprecated_component(0), np.amin, device, dtype)
        self._test_minmax_helper(_deprecated_component(1), np.amax, device, dtype)
1237

1238
    # TODO: bincount isn't a classic reduction -- maybe this test suite is
1239
    #   reductions and summary ops?
1240
    def test_bincount(self, device):
1241
        # negative input throws
1242
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1243
            torch.bincount(torch.tensor([1, -1], device=device))
1244
        # n-d input, with n > 1 throws
1245
        with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'):
1246
            torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device))
1247
        # floating input type throws
1248
        with self.assertRaisesRegex(RuntimeError, 'not implemented'):
1249
            torch.bincount(torch.tensor([1., 0.3], device=device))
1250
        # minlength < 0 throws
1251
        with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'):
1252
            torch.bincount(torch.tensor([1, 3], device=device),
1253
                           torch.tensor([.2, .2], device=device),
1254
                           minlength=-1)
1255
        # n-d weights, with n > 1 throws
1256
        with self.assertRaisesRegex(RuntimeError, '1-d'):
1257
            torch.bincount(torch.tensor([1, 0], device=device),
1258
                           torch.tensor([[1., 0.3], [1., 0.3]], device=device))
1259
        # input and weights dim mismatch
1260
        with self.assertRaisesRegex(RuntimeError, 'same length'):
1261
            torch.bincount(torch.tensor([1, 0], device=device),
1262
                           torch.tensor([1., 0.3, 0.5], device=device))
1263
        # 1-d input with no elements and default minlength
1264
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)),
1265
                         torch.zeros(0, dtype=torch.long, device=device))
1266
        # 1-d input with no elements and specified minlength
1267
        self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10),
1268
                         torch.zeros(10, dtype=torch.long, device=device))
1269

1270
        # test tensor method without weights
1271
        long_counts = torch.tensor(
1272
            [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount()
1273
        self.assertEqual(
1274
            torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device),
1275
            long_counts)
1276
        # test avoiding overflow for uint8 (#76979)
1277
        count_uint8 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.uint8, device=device).bincount()
1278
        count_int16 = torch.tensor([0, 1, 2, 3, 255], dtype=torch.int16, device=device).bincount()
1279
        self.assertEqual(count_uint8, count_int16)
1280
        # test minlength functionality
1281
        int_counts = torch.bincount(
1282
            torch.tensor([1, 1, 1, 1], device=device), minlength=5)
1283
        self.assertEqual(
1284
            torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device),
1285
            int_counts)
1286
        # test weights
1287
        byte_counts = torch.bincount(
1288
            torch.tensor([0, 1, 1, 1, 4], device=device),
1289
            torch.tensor([.1, .2, .3, .4, .5], device=device))
1290
        self.assertEqual(
1291
            torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts)
1292
        byte_counts = torch.bincount(
1293
            torch.tensor([0, 1, 1, 1, 4], device=device),
1294
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device))
1295
        self.assertEqual(
1296
            torch.tensor([1, 9, 0, 0, 5], device=device, dtype=torch.float64), byte_counts)
1297
        # test non-contiguous inputs and weights
1298
        inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device)
1299
        weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device)
1300
        for i in [0, 1]:
1301
            assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous"
1302
            assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous"
1303
        # inputs are non-contiguous but weights are contiguous
1304
        self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2]))
1305
        # inputs and weights are non-contiguous
1306
        self.assertEqual(
1307
            inputs[:, 1].bincount(weights[:, 1]),
1308
            torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1309
        # weights are non-contiguous but inputs are contiguous
1310
        self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]),
1311
                         torch.tensor([1, 9, 0, 0, 5], dtype=torch.float32))
1312

1313
        # test bincount on non-contiguous slices
1314
        all0s = torch.zeros((32, 2), dtype=torch.int64, device=device)
1315
        self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32]))
1316

1317
        all1s = torch.ones((32, 2), dtype=torch.int64, device=device)
1318
        self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32]))
1319

1320
        # test large number of bins - global memory use
1321
        big_exp = torch.zeros(10000000, device=device)
1322
        big_exp[-1] = 50.0
1323
        big_w = torch.tensor([.5] * 100, device=device)
1324
        big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w)
1325
        self.assertEqual(big_exp, big_out)
1326
        # test large input size
1327
        big_exp = torch.zeros(2, device=device, dtype=torch.int64)
1328
        big_exp[1] = 1000000
1329
        big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount()
1330
        self.assertEqual(big_exp, big_out)
1331

1332
    # TODO: how many var stability tests are there?
1333
    def test_var_stability2(self, device):
        """Check that ``var`` is exact for two large, nearly equal float values."""
        expected_var = 0.03125
        row = torch.FloatTensor([2281.5, 2281.25]).to(device)

        # Reduce over the only (innermost) dimension, then over everything.
        self.assertEqual(row.var(0), expected_var)
        self.assertEqual(row.var(), expected_var)

        # After unsqueeze(1) the reduced dimension 0 is an outer dimension.
        column = row.unsqueeze(1)
        self.assertEqual(column.var(0), expected_var)
1345

1346
    @onlyCPU
    @dtypes(torch.bool, torch.double)
    def test_sum_all(self, device, dtype) -> None:
        """Compare a full ``Tensor.sum()`` with Python's builtin ``sum`` over the flattened values."""
        def check_sum_all(tensor: torch.Tensor) -> None:
            flat_values = tensor.reshape(-1).tolist()
            self.assertEqual(tensor.sum(), sum(flat_values))

        if dtype == torch.bool:
            check_sum_all(torch.tensor([True, False, True], dtype=torch.bool, device=device))
            return

        # Small fixed input, a large contiguous input, then a strided column view.
        check_sum_all(torch.tensor([1, 2, 3, 4, 5], dtype=dtype, device=device))
        check_sum_all(torch.randn(200000, dtype=dtype, device=device))
        check_sum_all(torch.randn(2000, 2, dtype=dtype, device=device)[:, 0])
1359

1360
    def _test_memory_format_transformations(self, device, input_generator_fn, transformation_fn,
                                            memory_format, compare_data=True, default_is_preserve=False):
        """Check how ``transformation_fn`` honours the ``memory_format`` argument.

        Args:
            device: forwarded to ``input_generator_fn`` and used for the
                tensor of the final stride-preservation check.
            input_generator_fn: callable taking ``device`` and returning the
                input tensor (expected to be laid out in ``memory_format``).
            transformation_fn: callable invoked as ``fn(tensor)`` or
                ``fn(tensor, memory_format=...)``.
            memory_format: ``torch.channels_last`` or ``torch.channels_last_3d``.
            compare_data: when true, the result (converted with ``.to(input)``)
                is also compared against the input.
            default_is_preserve: whether calling ``transformation_fn`` without
                ``memory_format`` is expected to keep the channels-last layout
                instead of producing a contiguous result.
        """
        assert memory_format in (torch.channels_last, torch.channels_last_3d)

        def check_data(source, result):
            if compare_data:
                self.assertEqual(source, result.to(source))

        # preserve_format on an input that is not memory dense but still
        # looks like channels last (every other element of the trailing dims).
        strided_input = input_generator_fn(device)
        if memory_format == torch.channels_last:
            strided_input = strided_input[..., ::2, ::2]
        else:
            strided_input = strided_input[..., ::2, ::2, ::2]

        preserved = transformation_fn(strided_input, memory_format=torch.preserve_format)
        self.assertFalse(preserved.is_contiguous())
        self.assertTrue(preserved.is_contiguous(memory_format=memory_format))
        self.assertFalse(strided_input.is_contiguous())
        self.assertFalse(strided_input.is_contiguous(memory_format=memory_format))
        check_data(strided_input, preserved)

        # contiguous_format requested explicitly.
        dense_input = input_generator_fn(device)
        made_contiguous = transformation_fn(dense_input, memory_format=torch.contiguous_format)
        self.assertTrue(made_contiguous.is_contiguous())
        self.assertFalse(made_contiguous.is_contiguous(memory_format=memory_format))
        check_data(dense_input, made_contiguous)

        # No memory_format argument: the expected layout depends on default_is_preserve.
        dense_input = input_generator_fn(device)
        defaulted = transformation_fn(dense_input)
        if default_is_preserve:
            self.assertFalse(defaulted.is_contiguous())
            self.assertTrue(defaulted.is_contiguous(memory_format=memory_format))
        else:
            self.assertTrue(defaulted.is_contiguous())
            self.assertFalse(defaulted.is_contiguous(memory_format=memory_format))
        check_data(dense_input, defaulted)

        # preserve_format must keep the strides of repeatedly permuted tensors.
        permuted = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device)
        for _ in range(10):
            dim_order = list(range(len(permuted.shape)))
            random.shuffle(dim_order)
            permuted = permuted.permute(dim_order)
            self.assertEqual(permuted.stride(),
                             transformation_fn(permuted, memory_format=torch.preserve_format).stride())
1406

1407
    @onlyCPU
    @dtypes(torch.double)
    def test_sum_out(self, device, dtype: torch.dtype) -> None:
        """``torch.sum(..., out=)`` into an initially empty tensor must match the functional result."""
        # One reduced dimension.
        matrix = torch.rand(100, 100, dtype=dtype, device=device)
        expected = torch.sum(matrix, 1)
        out = torch.tensor((), dtype=dtype, device=device)
        torch.sum(matrix, 1, out=out)
        self.assertEqual(expected, out)

        # Two reduced dimensions passed as a tuple, checked against chained sums.
        cube = torch.rand(100, 100, 100, dtype=dtype, device=device)
        expected = cube.sum(2).sum(1)
        out = torch.tensor((), dtype=dtype, device=device)
        torch.sum(cube, (2, 1), out=out)
        self.assertEqual(expected, out)
1420

1421
    @onlyCUDA
    @dtypes(torch.float16, torch.float32)
    def test_prod_gpu(self, device, dtype):
        """Check ``prod`` for every fp16/fp32 pairing of input and output dtype."""
        factors = torch.tensor([2, 3, 6, 9, 8], dtype=dtype, device=device)

        # The decorator varies the input dtype; this loop varies the requested
        # output dtype, so all four combinations are covered.
        for dtype_output in (torch.float16, torch.float32):
            result_expected = torch.tensor(2592, dtype=dtype_output, device=device)

            # Function form.
            self.assertEqual(torch.prod(factors, dtype=dtype_output), result_expected)
            # Method form.
            self.assertEqual(factors.prod(dtype=dtype_output), result_expected)
1435

1436
    @onlyCPU
    @dtypes(torch.float)
    def test_prod(self, device, dtype):
        """``torch.prod(..., out=)`` into an initially empty tensor must match the functional result."""
        matrix = torch.rand(100, 100, dtype=dtype, device=device)
        expected = torch.prod(matrix, 1)
        out = torch.tensor((), dtype=dtype, device=device)
        torch.prod(matrix, 1, out=out)
        self.assertEqual(expected, out)
1444

1445
    @onlyCPU
    @dtypes(torch.float16, torch.bfloat16)
    def test_prod_lowp(self, device, dtype):
        """Compare low-precision ``prod`` with a float32 reference cast back to ``dtype``."""
        lowp_input = torch.rand(100, 100, dtype=dtype, device=device)
        float_reference = lowp_input.float()

        # Reduce along the inner dimension first, then the outer one.
        for dim in (1, 0):
            actual = torch.prod(lowp_input, dim)
            expected = torch.prod(float_reference, dim)
            self.assertEqual(actual, expected.to(dtype=dtype))
1456

1457
    def test_prod_bool(self, device):
        """Compare ``torch.prod`` on boolean lists (including an empty one) with ``np.prod``."""
        for values in ([True, True], [True, False], [False, False], []):
            # Explicit bool output dtype.
            torch_result = torch.prod(torch.tensor(values, device=device), dtype=torch.bool).item()
            numpy_result = np.prod(np.array(values), dtype=bool)
            self.assertEqual(torch_result, numpy_result)

            # Default output dtype.
            torch_result = torch.prod(torch.tensor(values, device=device)).item()
            numpy_result = np.prod(np.array(values))
            self.assertEqual(torch_result, numpy_result)
1467

1468
    @onlyCPU
    def test_max_mixed_devices(self, device):
        """``max``/``amax`` on a CPU input must reject CUDA ``out`` tensors."""
        cpu_input = torch.randn(10, device=device)
        if not torch.cuda.is_available():
            # Nothing to check without a CUDA device to hold the outputs.
            return

        cuda_values = torch.randn(10).cuda()
        cuda_indices = torch.cuda.LongTensor()
        with self.assertRaises(RuntimeError):
            torch.max(cpu_input, 0, out=(cuda_values, cuda_indices))
        with self.assertRaises(RuntimeError):
            torch.amax(cpu_input, 0, out=cuda_values)
1478

1479
    @onlyCPU
    def test_min_mixed_devices(self, device):
        """``min``/``amin`` on a CPU input must reject CUDA ``out`` tensors."""
        cpu_input = torch.randn(10, device=device)
        if not torch.cuda.is_available():
            # Nothing to check without a CUDA device to hold the outputs.
            return

        cuda_values = torch.randn(10).cuda()
        cuda_indices = torch.cuda.LongTensor()
        with self.assertRaises(RuntimeError):
            torch.min(cpu_input, 0, out=(cuda_values, cuda_indices))
        with self.assertRaises(RuntimeError):
            torch.amin(cpu_input, 0, out=cuda_values)
1489

1490
    # TODO: consider refactoring with bincount test
1491
    def test_bucketization(self, device):
        """Exercise ``torch.bucketize`` and ``torch.searchsorted``.

        Covers plain results and ``out=`` variants, ``right=True`` /
        ``side='right'``, ``out_int32``, empty and NaN inputs, type promotion
        with non-contiguous tensors, scalar inputs, the error messages raised
        for invalid arguments, and (on CPU only) bfloat16 inputs.
        """
        values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device)
        values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)

        # simple 1d boundary and 3d input value
        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device)
        expected_result = torch.tensor([[[0, 2, 4], [1, 3, 5]], [[0, 1, 2], [3, 4, 5]]], device=device)
        # `output` is reused as the out= tensor for both the left and right variants below
        output = torch.empty(2, 2, 3, device=device, dtype=torch.int64)
        self.assertEqual(torch.bucketize(values_3d, boundaries), expected_result)
        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output), expected_result)
        expected_result = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device)
        self.assertEqual(torch.bucketize(values_3d, boundaries, right=True), expected_result)
        self.assertEqual(torch.bucketize(values_3d, boundaries, out=output, right=True), expected_result)

        # simple float 1d boundary and 1d input with output int32 type
        for dtype in [torch.float32, torch.float16]:
            values_1d_float = values_1d.to(dtype)
            boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=dtype)
            expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
            self.assertEqual(torch.searchsorted(boundaries, values_1d_float, out_int32=True), expected_result)
            self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)

        # multiple dimension input with 0 elements
        boundaries = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=torch.int64)
        values_0_el = torch.tensor([[[]]], device=device, dtype=torch.int64)
        expected_result = values_0_el.to(torch.int64)
        self.assertEqual(torch.searchsorted(boundaries, values_0_el), expected_result)
        self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result)

        # nan input
        values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64)
        boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64)
        # the expected values place every nan at index 4, i.e. past the last boundary
        expected_result = torch.tensor([1, 4, 2, 4], device=device)
        self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result)
        expected_result = torch.tensor([2, 4, 3, 4], device=device)
        # right=True and side='right' are checked against the same expected result
        self.assertEqual(torch.searchsorted(boundaries, values_nan, right=True), expected_result)
        self.assertEqual(torch.searchsorted(boundaries, values_nan, side='right'), expected_result)

        # type promotion and non contiguous tensors
        values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32)
        boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64)
        expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device)
        if self.device_type != 'xla':
            self.assertWarnsRegex(
                UserWarning, "tensor is non-contiguous",
                lambda: self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result))
        else:
            # All tensors in XLA are contiguous even after permute, so no warning message is generated on XLA
            self.assertEqual(torch.searchsorted(boundaries_permute, values_3d_permute), expected_result)

        # scalar type
        boundaries = torch.tensor([1.5, 2.5, 3.5], device=device)
        expected_result = torch.tensor(1, device=device)
        self.assertEqual(torch.searchsorted(boundaries, 2), expected_result)
        self.assertEqual(torch.bucketize(torch.tensor(2, device=device), boundaries), expected_result)
        expected_result = torch.tensor(3, device=device)
        scalar_tensor_nan = torch.tensor(float('nan'), device=device)
        self.assertEqual(torch.searchsorted(boundaries, scalar_tensor_nan), expected_result)
        self.assertEqual(torch.bucketize(float('nan'), boundaries, right=True), expected_result)

        # invalid input dimensions
        boundaries = torch.tensor([[1, 2, 3], [4, 5, 6]], device=device)
        with self.assertRaisesRegex(
                RuntimeError, "first N-1 dimensions of boundaries tensor and input value tensor must match"):
            torch.searchsorted(boundaries, values_3d)
        with self.assertRaisesRegex(
                RuntimeError, "boundaries tensor must be 1 dimension"):
            torch.bucketize(values_3d, boundaries)
        with self.assertRaisesRegex(
                RuntimeError, "only when boundaries tensor dimension is 1"):
            torch.searchsorted(boundaries, 1)

        # incompatible output tensor's dtype
        def test_output_dtype(dtype, is_int32):
            # The out= tensor is values_1d cast to `dtype`; each call below pairs
            # it with an out_int32 flag that is expected to be rejected.
            output = values_1d.to(dtype)
            with self.assertRaisesRegex(
                    RuntimeError, "output tensor's dtype is wrong"):
                torch.searchsorted(values_1d, values_1d, out=output, out_int32=is_int32)

        test_output_dtype(torch.float32, False)
        test_output_dtype(torch.int32, False)
        test_output_dtype(torch.int64, True)

        # invalid side argument
        with self.assertRaisesRegex(RuntimeError, "side can only be 'left' or 'right'"):
            torch.searchsorted(values_1d, values_1d, side='bad')

        # invalid sorter argument, wrong size
        with self.assertRaisesRegex(RuntimeError, "boundary and sorter must have the same size"):
            sequence = torch.rand_like(values_1d, dtype=torch.float)
            _, sorted_idx = torch.sort(sequence)
            torch.searchsorted(sequence, values_1d, sorter=sorted_idx[:-1])

        # invalid sorter argument, is not dtype long
        with self.assertRaisesRegex(RuntimeError, "sorter must be a tensor of long dtype"):
            sequence = torch.rand_like(values_1d, dtype=torch.float)
            _, sorted_idx = torch.sort(sequence)
            torch.searchsorted(sequence, values_1d, sorter=sorted_idx.to(torch.float32))

        # NOTE(review): the two out-of-range sorter checks below build their
        # tensors without device=device, so they do not use the device under
        # test -- confirm this is intended.
        # invalid sorter value, out of bound (>= innermost size)
        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([0, 1, 3]))

        # invalid sorter value, out of bound (< 0)
        with self.assertRaisesRegex(RuntimeError, "sorter index out of range"):
            torch.searchsorted(torch.tensor([1, 2, 3]), 2.5, sorter=torch.tensor([-1, 1, 2]))

        # scalar type bfloat16
        if self.device_type == 'cpu':
            def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False):
                # Start from float32 values/boundaries and cast either or both to bfloat16.
                values_1d_float = values_1d.to(torch.float32)
                boundaries = torch.tensor([0.9, 1, 2, 2, 3, 3, 4, 4.1, 9, 9], device=device, dtype=torch.float32)
                if values_bf16:
                    values_1d_float = values_1d_float.to(torch.bfloat16)
                if boundaries_bf16:
                    boundaries = boundaries.to(torch.bfloat16)
                expected_result = torch.tensor([1, 2, 4, 6, 8, 8, 8, 8, 8], device=device, dtype=torch.int32)
                self.assertEqual(torch.bucketize(values_1d_float, boundaries, out_int32=True), expected_result)

            test_dtype_bfloat16(True, False)
            test_dtype_bfloat16(False, True)
            test_dtype_bfloat16(True, True)
1613

1614
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_nansum(self, device, dtype):
        """``nansum`` must equal ``sum`` over the same data with NaNs replaced by zero."""
        zero = torch.zeros((), device=device, dtype=dtype)

        # Every (noncontiguous, dim) pairing; dim=None means a full reduction.
        for noncontiguous, dim in product((True, False), (0, 1, None)):
            # Randomly scale the values
            scale = random.randint(10, 100)
            x = make_tensor((17, 17), device=device, dtype=dtype,
                            low=-scale, high=scale, noncontiguous=noncontiguous)

            if not dtype.is_floating_point:
                # Integer dtypes cannot hold NaN, so the input is its own reference.
                x_nonan = x
            else:
                nan_mask = x < 0.2 * scale
                # Build the NaN-free reference before overwriting x in place.
                x_nonan = torch.where(nan_mask, zero, x)
                x[nan_mask] = np.nan

            if dim is None:
                expect = torch.sum(x_nonan)
                actual = torch.nansum(x)
            else:
                expect = torch.sum(x_nonan, dim=dim)
                actual = torch.nansum(x, dim=dim)
            self.assertEqual(expect, actual)
1639

1640
    def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype,
                                            with_extremal=False, atol=None, rtol=None,
                                            exact_dtype=True, with_keepdim=False):
        """Compare a torch reduction with its NumPy counterpart on random inputs.

        Inputs of 0 to 3 dimensions are generated, and for each one every
        ordered selection of its dimensions is used as the reduction dims.
        The empty selection calls both functions without a dim/axis argument;
        ``with_keepdim`` is only applied when explicit dims are passed.
        """
        # Test 0-d to 3-d tensors.
        for ndims in range(4):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            all_dims = list(range(ndims))
            for n in range(ndims + 1):
                for subset in combinations(all_dims, n):
                    for count_dim in permutations(subset):
                        # Generate Input.
                        x = _generate_input(shape, dtype, device, with_extremal)

                        if not count_dim:
                            # Default `dims=None` case
                            torch_callable, np_callable = torch_func, np_func
                        else:
                            # With `dims: tuple of ints` case
                            torch_kwargs = {"dim": count_dim}
                            np_kwargs = {"axis": count_dim}
                            if with_keepdim:
                                torch_kwargs["keepdim"] = True
                                np_kwargs["keepdims"] = True
                            torch_callable = partial(torch_func, **torch_kwargs)
                            np_callable = partial(np_func, **np_kwargs)

                        self.compare_with_numpy(torch_callable, np_callable, x, device=None, dtype=None,
                                                atol=atol, rtol=rtol, exact_dtype=exact_dtype)
1666

1667
    @dtypes(*all_types_and_complex_and(torch.half))
    def test_count_nonzero(self, device, dtype):
        """Compare ``torch.count_nonzero`` with ``np.count_nonzero``, without and with extremal values."""
        for with_extremal in (False, True):
            self._test_reduction_function_with_numpy(
                torch.count_nonzero, np.count_nonzero, device, dtype, with_extremal)
1671

1672
    def _test_sum_reduction_vs_numpy(self, torch_fn, np_fn, device, dtype, with_keepdim=False, with_extremal=False):
        """Run the NumPy comparison for a sum-like reduction with dtype-specific settings.

        Chooses whether the result dtype must match exactly and which
        tolerances to use, then delegates to
        ``_test_reduction_function_with_numpy``.
        """
        # On Windows CI, the current version of `numpy` promotes all lower integers
        # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking
        # the exact dtype.
        # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580
        # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370
        windows_integral = IS_WINDOWS and dtype in integral_types()
        # For uint8, numpy promotes to uint64 while torch promotes to int64.
        # So we must skip this as well.
        exact_dtype = not (windows_integral or dtype == torch.uint8)

        # TODO: Investigate why the output is not close to numpy.
        # Maps dtype -> (atol, rtol); any other dtype uses the default tolerances.
        tolerances = {
            torch.float16: (0.4, 1e-2),
            torch.float32: (7e-05, 3e-06),
        }
        atol, rtol = tolerances.get(dtype, (None, None))

        self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
                                                 atol=atol, rtol=rtol, exact_dtype=exact_dtype,
                                                 with_keepdim=with_keepdim, with_extremal=with_extremal)
1703

1704
    @onlyNativeDeviceTypes
    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
    def test_sum_vs_numpy(self, device, dtype):
        """Compare ``torch.sum`` with ``np.sum``: default, with extremal values, and with keepdim."""
        for options in ({}, {"with_extremal": True}, {"with_keepdim": True}):
            self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, **options)
1710

1711
    @onlyNativeDeviceTypes
    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
    def test_nansum_vs_numpy(self, device, dtype):
        """Compare ``torch.nansum`` with ``np.nansum``: default, with extremal values, and with keepdim."""
        for options in ({}, {"with_extremal": True}, {"with_keepdim": True}):
            self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, **options)
1717

1718
    @onlyCPU
    @dtypes(*complex_types())
    def test_nansum_complex(self, device, dtype):
        """``torch.nansum`` must raise for complex inputs."""
        complex_input = torch.randn((3, 3, 3), device=device, dtype=dtype)
        self.assertRaisesRegex(
            RuntimeError, "nansum does not support complex inputs",
            torch.nansum, complex_input)
1724

1725
    @dtypes(*all_types_and(torch.half))
    def test_nansum_out_dtype(self, device, dtype):
        """Compare ``torch.nansum(dtype=...)`` with ``np.nansum(dtype=...)`` across input dtypes."""
        out_dtype = dtype
        # A floating output dtype is tried with every input dtype; an integral
        # output dtype only with integral inputs.
        if out_dtype.is_floating_point:
            inp_dtypes = all_types_and(torch.half)
        else:
            inp_dtypes = integral_types()

        # Both reductions depend only on the output dtype, so build them once.
        torch_fn = partial(torch.nansum, dtype=out_dtype)
        np_fn = partial(np.nansum, dtype=torch_to_numpy_dtype_dict[out_dtype])

        for inp_dtype in inp_dtypes:
            shape = _rand_shape(random.randint(2, 5), min_size=5, max_size=10)
            x = _generate_input(shape, inp_dtype, device, with_extremal=False)
            self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)
1736

1737
    @dtypes(*all_types_and(torch.half))
    def test_argminmax_multiple(self, device, dtype):
        """Compare argmax/argmin with NumPy when the extreme value occurs more than once.

        Covers an all-ones tensor, a tensor with a single NaN, randomly
        generated tensors into which duplicated maxima/minima are injected,
        and the fixed samples from issue 41998. For every dtype except half,
        the indices returned by ``torch.max``/``torch.min`` along a dimension
        are also compared with ``np.argmax``/``np.argmin``.
        """
        # Case: All Ones
        t = torch.ones(3, 3, device=device, dtype=dtype)
        self.compare_with_numpy(torch.argmax, np.argmax, t)
        self.compare_with_numpy(torch.argmin, np.argmin, t)

        # Case: With single `nan` present.
        if dtype in floating_types_and(torch.half, torch.bfloat16):
            t[2, 2] = float('nan')
            self.compare_with_numpy(torch.argmax, np.argmax, t)
            self.compare_with_numpy(torch.argmin, np.argmin, t)

        # Case: Randomly Generated Tensors
        for ndims in range(1, 5):
            shape = _rand_shape(ndims, min_size=5, max_size=10)
            for with_extremal in [False, True]:
                for contiguous in [False, True]:
                    # Generate Input.
                    x = _generate_input(shape, dtype, device, with_extremal)

                    # For half, the overall max/min are computed on a float copy.
                    if dtype == torch.half:
                        max_val = torch.max(x.to(torch.float))
                        min_val = torch.min(x.to(torch.float))
                    else:
                        max_val = torch.max(x)
                        min_val = torch.min(x)

                    # Write a new maximum into a random subset of positions so
                    # that it occurs several times ...
                    mask = torch.randn(x.shape) > 0.5
                    x[mask] = torch.tensor(max_val + 1, dtype=dtype)

                    # ... and likewise a new, duplicated minimum.
                    mask = torch.randn(x.shape) > 0.5
                    x[mask] = torch.tensor(min_val - 1, dtype=dtype)

                    if not contiguous:
                        x = x.T

                    self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None)
                    self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None)

                    # Verify indices returned by max and min.
                    if dtype != torch.half:
                        rand_dim = random.randint(0, ndims - 1)
                        self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1],
                                                lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None)
                        self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1],
                                                lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None)

        def verify_against_numpy(t):
            # Every check in this helper must use its own argument `t`. The
            # enclosing scope still holds `x` from the random loop above, and
            # picking that up by accident would leave the fixed samples
            # untested for the max/min index checks.

            # Argmax
            torch_fn = partial(torch.argmax, dim=1)
            np_fn = partial(np.argmax, axis=1)
            self.compare_with_numpy(torch_fn, np_fn, t)
            # Non-contiguous input
            self.compare_with_numpy(torch_fn, np_fn, t.T)

            # Verify indices returned by max.
            if dtype != torch.half:
                self.compare_with_numpy(lambda inp: torch.max(inp, dim=1)[1], np_fn, t, device=None, dtype=None)
                self.compare_with_numpy(lambda inp: torch.max(inp, dim=1)[1], np_fn, t.T, device=None, dtype=None)

            # Argmin
            torch_fn = partial(torch.argmin, dim=1)
            np_fn = partial(np.argmin, axis=1)
            self.compare_with_numpy(torch_fn, np_fn, t)
            # Non-contiguous input
            self.compare_with_numpy(torch_fn, np_fn, t.T)

            # Verify indices returned by min.
            if dtype != torch.half:
                self.compare_with_numpy(lambda inp: torch.min(inp, dim=1)[1], np_fn, t, device=None, dtype=None)
                self.compare_with_numpy(lambda inp: torch.min(inp, dim=1)[1], np_fn, t.T, device=None, dtype=None)

        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
        t = torch.tensor([[1, 5],
                          [2, 10],
                          [3, 3]], device=device, dtype=dtype)
        verify_against_numpy(t)

        # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998
        t = torch.tensor([[1, 5],
                          [2, 10],
                          [0, 0]], device=device, dtype=dtype)
        verify_against_numpy(t)
1821

1822
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
1823
    def test_all_any_vs_numpy(self, device, dtype):
1824
        # Note [all, any uint8 compatibility]: However for compatibility reason,
1825
        # for `uint8`, they return Tensor of same dtype `uint8`.
1826
        # Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
1827
        exact_dtype = True if dtype != torch.uint8 else False
1828

1829
        def _test_all_any(x):
1830
            self.compare_with_numpy(torch.all, np.all, x)
1831
            self.compare_with_numpy(torch.any, np.any, x)
1832

1833
        def _test_all_any_with_dim(x, dim):
1834
            torch_fn = partial(torch.all, dim=dim)
1835
            np_fn = partial(np.all, axis=dim)
1836
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1837

1838
            torch_fn = partial(torch.any, dim=dim)
1839
            np_fn = partial(np.any, axis=dim)
1840
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1841

1842
        def _test_out_variant(x, dim):
1843
            out = torch.empty_like(x)
1844
            if dtype == torch.bool or dtype == torch.uint8:
1845
                expected = torch.all(x, dim)
1846
                torch.all(x, dim, out=out)
1847
                self.assertEqual(expected, out)
1848

1849
                expected = torch.any(x, dim)
1850
                torch.any(x, dim, out=out)
1851
                self.assertEqual(expected, out)
1852
            else:
1853
                with self.assertRaisesRegex(RuntimeError, "all only supports bool tensor for result, got"):
1854
                    torch.all(x, dim, out=out)
1855

1856
                with self.assertRaisesRegex(RuntimeError, "any only supports bool tensor for result, got"):
1857
                    torch.any(x, dim, out=out)
1858

1859
        def _test_all_any_with_dim_keepdim(x, dim, keepdim):
1860
            torch_fn = partial(torch.all, dim=dim, keepdim=keepdim)
1861
            np_fn = partial(np.all, axis=dim, keepdims=keepdim)
1862
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1863

1864
            torch_fn = partial(torch.any, dim=dim, keepdim=keepdim)
1865
            np_fn = partial(np.any, axis=dim, keepdims=keepdim)
1866
            self.compare_with_numpy(torch_fn, np_fn, x, exact_dtype=exact_dtype)
1867

1868
        def _test_output_dtype(x):
1869
            # This test will fail once the functions return bool output
1870
            # for uint8 input.
1871
            expected_dtype = torch.uint8 if dtype == torch.uint8 else torch.bool
1872
            self.assertEqual(torch.all(x).dtype, expected_dtype)
1873
            self.assertEqual(torch.any(x).dtype, expected_dtype)
1874

1875
            self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype)
1876
            self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype)
1877

1878
        for ndim in range(5):
1879
            shape = _rand_shape(ndim, 1, 5)
1880
            x = _generate_input(shape, dtype, device, with_extremal=False)
1881
            _test_all_any(x)
1882
            _test_all_any(x.T)
1883
            _test_all_any(x[..., ::2])
1884

1885
            x = _generate_input(shape, dtype, device, with_extremal=True)
1886
            _test_all_any(x)
1887
            _test_all_any(x.T)
1888
            _test_all_any(x[..., ::2])
1889

1890
            x = torch.zeros_like(x)
1891
            _test_all_any(x)
1892
            _test_all_any(x.T)
1893
            _test_all_any(x[..., ::2])
1894

1895
            x = torch.ones_like(x)
1896
            _test_all_any(x)
1897
            _test_all_any(x.T)
1898
            _test_all_any(x[..., ::2])
1899
            _test_output_dtype(x)
1900
            for dim in range(ndim):
1901
                x = _generate_input(shape, dtype, device, with_extremal=False)
1902
                _test_all_any_with_dim(x, dim)
1903
                _test_all_any_with_dim(x.T, dim)
1904
                _test_all_any_with_dim(x[..., ::2], dim)
1905
                _test_out_variant(x, dim)
1906
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1907
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1908

1909
                x = _generate_input(shape, dtype, device, with_extremal=True)
1910
                _test_all_any_with_dim(x, dim)
1911
                _test_all_any_with_dim(x.T, dim)
1912
                _test_all_any_with_dim(x[..., ::2], dim)
1913
                _test_out_variant(x, dim)
1914
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1915
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1916

1917
                x = torch.zeros_like(x)
1918
                _test_all_any_with_dim(x, dim)
1919
                _test_all_any_with_dim(x.T, dim)
1920
                _test_all_any_with_dim(x[..., ::2], dim)
1921
                _test_out_variant(x, dim)
1922
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1923
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1924

1925
                x = torch.ones_like(x)
1926
                _test_all_any_with_dim(x, dim)
1927
                _test_all_any_with_dim(x.T, dim)
1928
                _test_all_any_with_dim(x[..., ::2], dim)
1929
                _test_out_variant(x, dim)
1930
                _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1931
                _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1932

1933
    # TODO: part of this test covers torch.norm, which should be covered by test_linalg
1934
    @onlyNativeDeviceTypes
    def test_repeated_dim(self, device):
        # Every listed reduction must raise when the `dim` tuple names the same
        # dimension more than once.
        #
        # The list is called `reduction_fns` (not `ops`) so it does not shadow the
        # `ops` decorator imported at the top of this file, and each function
        # appears exactly once (torch.std used to be listed twice).
        reduction_fns = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp,
                         torch.var, torch.norm]
        x = torch.randn(3, 3, 3, 3, device=device)

        error_msg = r'appears multiple times in the list of dims'
        # On a 4-d tensor, -4 refers to the same dimension as 0.
        repeated_dims = [(0, 0), (0, -4)]
        for reduction in reduction_fns:
            for dim in repeated_dims:
                with self.assertRaisesRegex(RuntimeError, error_msg):
                    reduction(x, dim=dim)
1945

1946
    # TODO: update this test to compare against NumPy
1947
    @onlyCUDA
    def test_var(self, device):
        # var() and std() computed on `device` must agree with the CPU result,
        # both for the full reduction and for reductions over dim 1 and dim 2.
        cpu_tensor = torch.randn(2, 3, 3)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())
        self.assertEqual(device_tensor.var(1), cpu_tensor.var(1))
        self.assertEqual(device_tensor.var(2), cpu_tensor.var(2))
        self.assertEqual(device_tensor.std(), cpu_tensor.std())
        self.assertEqual(device_tensor.std(1), cpu_tensor.std(1))
        # This line previously re-checked var(2), leaving std(2) untested.
        self.assertEqual(device_tensor.std(2), cpu_tensor.std(2))

        # 1-d input, full reduction only
        cpu_tensor = torch.randn(100)
        device_tensor = cpu_tensor.to(device)
        self.assertEqual(device_tensor.var(), cpu_tensor.var())
1961

1962
    # TODO: update this test to compare against NumPy
1963
    @onlyCUDA
    def test_var_large_input(self, device):
        # Variance over the last dimension of a large tensor must match between
        # CPU and `device`.
        # Large, not-nice input: the leading extent is 2 * 32 * 1024 plus one.
        shape = (2 * 32 * 1024 + 1, 2, 67)
        host = torch.randn(*shape)
        on_device = host.to(device)

        expected = host.var(2)
        actual = on_device.var(2)
        self.assertEqual(expected, actual)
1970

1971
    # TODO: update this to compare against NumPy instead of CPU
1972
    @onlyCUDA
    @dtypes(torch.double)
    def test_sum_noncontig(self, device, dtype):
        # Sums of a permuted (non-contiguous) tensor on `device` must match the
        # sums of its CPU copy, for the full reduction and for two dim pairs.
        base = torch.randn(1, 75, 57, 20, dtype=dtype, device=device)
        x = base.permute(0, 3, 1, 2)
        y = x.cpu()
        self.assertEqual(x.sum().cpu(), y.sum())
        for dims in ((-1, -2), (1, 3)):
            self.assertEqual(x.sum(dim=dims).cpu(), y.sum(dim=dims))
1980

1981
    # TODO: update this to compare against NumPy instead of CPU
1982
    @onlyCUDA
    def test_min_max_nan(self, device):
        # min/max/amin/amax (full and along dim 0) on an input holding one NaN:
        # the result on `device` must have NaNs in the same places as the CPU
        # result and equal values everywhere else.
        reducers = {
            'min': lambda t: t.min(),
            'max': lambda t: t.max(),
            'amin': lambda t: t.amin(),
            'amax': lambda t: t.amax(),
            'min_dim': lambda t: t.min(0).values,
            'max_dim': lambda t: t.max(0).values,
            'amin_dim': lambda t: t.amin(0),
            'amax_dim': lambda t: t.amax(0),
        }
        for name, reducer in reducers.items():
            a = torch.arange(25.0).view(5, 5)
            a[2, 2] = nan
            actual = reducer(a.to(device)).cpu()
            expected = reducer(a).cpu()

            actual_is_nan = torch.isnan(actual)
            expected_is_nan = torch.isnan(expected)
            self.assertEqual(actual_is_nan, expected_is_nan, msg=f'nans for {name}')
            self.assertEqual(actual[~actual_is_nan],
                             expected[~expected_is_nan], msg=f'nans for {name}')
2000

2001
    # TODO: make this test generic using OpInfos
2002
    @onlyCUDA
    def test_sum_cpu_device_mismatch(self, device):
        # torch.sum must raise when the input lives on `device` but the `out=`
        # tensor is a CPU tensor.
        x = torch.randn(20, dtype=torch.float32, device=device)
        y = torch.randn(1, dtype=torch.float32)

        err_string = f"Expected out tensor to have device {device}, but got cpu instead"

        def expect_device_error(source):
            with self.assertRaisesRegex(RuntimeError, err_string):
                torch.sum(source, dim=[0], dtype=torch.float32, out=y)

        expect_device_error(x)

        # tests half to float promotion
        if self.device_type == 'cuda':
            x = x.half()
            expect_device_error(x)
2017

2018
    # Assert for illegal dtype would not be raised on XLA
2019
    @onlyNativeDeviceTypes
    def test_minmax_illegal_dtype(self, device):
        # max/min/amax/amin along a dim accept `out=` tensors of matching dtype
        # and raise when the values and/or indices output has the wrong dtype.
        x = torch.randn(5, 5, dtype=torch.float32, device=device)
        valid_values = torch.empty(5, dtype=torch.float32, device=device)
        valid_indices = torch.empty(5, dtype=torch.long, device=device)
        illegal_values = torch.empty(5, dtype=torch.int, device=device)
        illegal_indices = torch.empty(5, dtype=torch.double, device=device)

        # Matching dtypes: none of these calls may raise.
        torch.max(x, dim=0, out=(valid_values, valid_indices))
        torch.min(x, dim=0, out=(valid_values, valid_indices))
        torch.amax(x, dim=0, out=valid_values)
        torch.amin(x, dim=0, out=valid_values)

        rmsg = r'scalar type|dtype'
        bad_outputs = [
            (illegal_values, valid_indices),
            (valid_values, illegal_indices),
            (illegal_values, illegal_indices),
        ]
        for out_pair in bad_outputs:
            for reduce_fn in (torch.max, torch.min):
                with self.assertRaisesRegex(RuntimeError, rmsg):
                    reduce_fn(x, dim=0, out=out_pair)
2043

2044
    @dtypes(*all_types_and(torch.half, torch.bfloat16))
    def test_dim_arg_reduction_scalar(self, device, dtype):
        # argmax/argmin of a 0-d tensor return index 0 for every supported way
        # of passing `dim` (omitted, None, 0, and 0 with keepdim=True).
        example = 4.0

        for method_name in ('argmax', 'argmin'):
            x = torch.tensor(example, device=device, dtype=dtype)
            arg_reduce = getattr(x, method_name)
            self.assertEqual(arg_reduce().item(), 0)
            self.assertEqual(arg_reduce(dim=None).item(), 0)
            self.assertEqual(arg_reduce(dim=0).item(), 0)
            self.assertEqual(arg_reduce(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2059

2060

2061
    @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
    @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
    def test_dim_reduction(self, device, dtype):
        # Part 1 checks sum/mean/prod/min/max (and the a*/arg* variants) on a
        # small fixed 2x3 matrix against hand-computed results.
        # Part 2 checks, for a list of dim reductions on random input, that
        # keepdim only changes the shape and that the `out=` form matches the
        # functional form.
        example = [[-1, 2, 1], [5, 3, 6]]
        index_dtype = torch.int64

        def make_input():
            return torch.tensor(example, device=device, dtype=dtype)

        # For the dtypes this test runs with, sum and prod keep a floating
        # dtype and produce int64 for every integer dtype.
        accum_dtype = dtype if dtype.is_floating_point else torch.int64

        # This won't test for 256bit instructions, since we usually
        # only work on 1 cacheline (512bit) at a time and these
        # examples aren't big enough to trigger that.
        x = make_input()
        self.assertEqual(x.sum().item(), 16)
        self.assertEqual(x.sum(0), torch.tensor([4, 5, 7], dtype=accum_dtype))
        self.assertEqual(x.sum(1), torch.tensor([2, 14], dtype=accum_dtype))
        y = torch.tensor(example, device=device, dtype=accum_dtype)
        torch.sum(x, 0, out=y)
        self.assertEqual(x.sum(0), y)

        # Mean not supported for Int types
        if dtype.is_floating_point:
            x = make_input()
            self.assertEqual(x.mean().item(), 16.0 / 6)
            self.assertEqual(x.mean(0), torch.tensor([2.0, 2.5, 7.0 / 2], dtype=dtype))
            self.assertEqual(x.mean(1), torch.tensor([2.0 / 3, 14.0 / 3], dtype=dtype))
            self.assertEqual(x.mean(), x.mean((0, 1)))

        # prod is not supported for float16 & bfloat16 on CPU
        reduced_precision_on_cpu = (
            self.device_type == 'cpu' and dtype in [torch.float16, torch.bfloat16])
        if not reduced_precision_on_cpu:
            x = make_input()
            self.assertEqual(x.prod().item(), -180)
            self.assertEqual(x.prod(0), torch.tensor([-5, 6, 6], dtype=accum_dtype))
            self.assertEqual(x.prod(1), torch.tensor([-2, 90], dtype=accum_dtype))

        # ---- min / amin / argmin ----
        x = make_input()

        self.assertEqual(x.min().item(), -1)
        self.assertEqual(x.argmin().item(), 0)

        # TODO: torch.min does not support the same operation as argmin
        # for the same case, should we enable it?
        self.assertEqual(x.argmin(dim=None).item(), 0)

        # (dim, values, indices, values with keepdim, indices with keepdim)
        min_cases = [
            (0, [-1, 2, 1], [0, 0, 0], [[-1, 2, 1]], [[0, 0, 0]]),
            (1, [-1, 3], [0, 1], [[-1], [3]], [[0], [1]]),
        ]
        for dim, vals, idxs, kept_vals, kept_idxs in min_cases:
            want_vals = torch.tensor(vals, dtype=dtype)
            want_idxs = torch.tensor(idxs, dtype=index_dtype)
            self.assertEqual(x.min(dim), (want_vals, want_idxs))
            self.assertEqual(x.amin(dim), want_vals)
            self.assertEqual(x.argmin(dim), want_idxs)

            want_vals = torch.tensor(kept_vals, dtype=dtype)
            want_idxs = torch.tensor(kept_idxs, dtype=index_dtype)
            self.assertEqual(x.min(dim=dim, keepdim=True), (want_vals, want_idxs))
            self.assertEqual(x.amin(dim=dim, keepdim=True), want_vals)
            self.assertEqual(x.argmin(dim=dim, keepdim=True), want_idxs)

        # test that non-contiguous tensors work
        first_two_cols = x[:, :2]
        self.assertEqual(first_two_cols.min().item(), -1)
        self.assertEqual(first_two_cols.amin().item(), -1)
        self.assertEqual(first_two_cols.argmin().item(), 0)

        # ---- max / amax / argmax ----
        x = make_input()

        self.assertEqual(x.max().item(), 6)
        self.assertEqual(x.amax().item(), 6)
        self.assertEqual(x.argmax().item(), 5)

        # (dim, values, indices, values with keepdim, indices with keepdim)
        max_cases = [
            (0, [5, 3, 6], [1, 1, 1], [[5, 3, 6]], [[1, 1, 1]]),
            (1, [2, 6], [1, 2], [[2], [6]], [[1], [2]]),
        ]
        for dim, vals, idxs, kept_vals, kept_idxs in max_cases:
            want_vals = torch.tensor(vals, dtype=dtype)
            want_idxs = torch.tensor(idxs, dtype=index_dtype)
            self.assertEqual(x.max(dim), (want_vals, want_idxs))
            self.assertEqual(x.amax(dim), want_vals)
            self.assertEqual(x.argmax(dim=dim), want_idxs)

            want_vals = torch.tensor(kept_vals, dtype=dtype)
            want_idxs = torch.tensor(kept_idxs, dtype=index_dtype)
            self.assertEqual(x.max(dim=dim, keepdim=True), (want_vals, want_idxs))
            self.assertEqual(x.amax(dim=dim, keepdim=True), want_vals)
            self.assertEqual(x.argmax(dim=dim, keepdim=True), want_idxs)

        # test that non-contiguous tensors work
        first_two_cols = x[:, :2]
        self.assertEqual(first_two_cols.max().item(), 5)
        self.assertEqual(first_two_cols.amax().item(), 5)
        self.assertEqual(first_two_cols.argmax().item(), 2)

        # ---- generic keepdim / out= checks on random input ----
        dim_red_fns = [
            "mean", "median", "nanmedian", "mode", "norm", "prod",
            "std", "sum", "var", "max", "min", "amax", "amin"]
        # These return a (values, indices) pair and take a pair for `out=`.
        pair_returning = ['median', 'nanmedian', 'mode', 'max', 'min']

        def normfn_attr(t, dim, keepdim=False, out=None):
            # torch.norm takes the norm order (2 here) before the dim argument
            return torch.norm(t, 2, dim, keepdim, out=out)

        for fn_name in dim_red_fns:
            reducer = normfn_attr if fn_name == "norm" else getattr(torch, fn_name)

            def reduce_raw(t, dim, keepdim=False, out=None):
                return reducer(t, dim, keepdim=keepdim, out=out)

            def reduce_values(t, dim, keepdim=False, out=None):
                # For pair-returning reductions only the values are compared
                result = reduce_raw(t, dim, keepdim=keepdim, out=out)
                if isinstance(result, tuple):
                    return result[0]
                return result

            def check_keepdim(t, dim):
                self.assertEqual(reduce_values(t, dim).unsqueeze(dim),
                                 reduce_values(t, dim, keepdim=True))
                self.assertEqual(t.ndimension() - 1, reduce_values(t, dim).ndimension())
                self.assertEqual(t.ndimension(), reduce_values(t, dim, keepdim=True).ndimension())

            # general case
            x = torch.randn(3, 4, 5, device=device)
            dim = random.randint(0, 2)
            check_keepdim(x, dim)

            # check 1-d behavior
            x = torch.randn(1, device=device)
            dim = 0
            self.assertEqual(reduce_values(x, dim).shape, ())
            self.assertEqual(reduce_values(x, dim, keepdim=True).shape, (1,))

            # check reducing of a singleton dimension
            dims = [3, 4, 5]
            singleton_dim = random.randint(0, 2)
            dims[singleton_dim] = 1
            x = torch.randn(dims, device=device)
            check_keepdim(x, singleton_dim)

            # check reducing with output kwargs
            if fn_name in pair_returning:
                y = torch.randn(5, 3, device=device)
                values = torch.randn(5, 3, device=device)
                indices = torch.zeros(5, 3, device=device).long() - 1
                # Write into column views so the outputs are non-contiguous
                reduce_raw(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
                values_expected, indices_expected = reduce_raw(y, 1, keepdim=False)
                self.assertEqual(values[:, 1], values_expected,
                                 msg=f'{fn_name} values with out= kwarg')
                self.assertEqual(indices[:, 1], indices_expected,
                                 msg=f'{fn_name} indices with out= kwarg')
            else:
                x = torch.randn(5, 3, device=device)
                y = torch.randn(5, 3, device=device)
                reduce_values(y, 1, keepdim=False, out=x[:, 1])
                expected = reduce_values(y, 1, keepdim=False)
                self.assertEqual(x[:, 1], expected, msg=f'{fn_name} with out= kwarg')
2238

2239
    @onlyCUDA
    @largeTensorTest('10GB')
    def test_reduction_split(self, device):
        # Test reduction when there is a 32bit-indexing split
        # https://github.com/pytorch/pytorch/issues/37583
        input_ = torch.randn(5, 14400, 14400, device=device)
        result = input_.sum(dim=0)

        # Reference: add the five slices one by one, left to right
        expect = input_[0]
        for slice_index in range(1, 5):
            expect = expect + input_[slice_index]
        self.assertEqual(result, expect)
2248

2249
    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_input_corner(self, device, dtype):
        # A single non-zero entry is written near the start or near the end of an
        # otherwise all-zero tensor; sum / argmax / max / argmin / min must find it.
        # The 1-D cases also run on `y`, a view of `x` that starts one element in.
        # NOTE(review): judging by the test name this targets vectorized loads at
        # the edges of the reduced dimension -- confirm against the CUDA kernel.
        head_offsets = range(100)
        # Offsets from the end start at 1: `-0` is index 0, which is the head
        # again. (The 2-D sum tail loop previously started at 0 and so re-ran a
        # head case instead of a tail case.)
        tail_offsets = range(1, 100)

        # ---- 1D cases ----
        size = 1024 * 1024 * 64 + 3
        shift = 1
        ysize = size - shift
        # x is zeroed at the top of every iteration, so one allocation serves
        # both the sum and the argmax cases.
        x = torch.zeros(size, dtype=dtype, device=device)
        y = x[shift:]

        # 1D case: sum
        for i in head_offsets:
            x.zero_()
            x[i] = 1
            self.assertEqual(x.sum(), 1.0)
            # y does not contain the first `shift` elements of x
            self.assertEqual(y.sum(), 0.0 if i < shift else 1.0)
        for i in tail_offsets:
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.sum(), 1.0)
            self.assertEqual(y.sum(), 1.0)

        # 1D case: argmax
        for i in head_offsets:
            x.zero_()
            x[i] = 1
            self.assertEqual(x.argmax().item(), i)
            if i >= shift:
                self.assertEqual(y.argmax().item(), i - shift)
        for i in tail_offsets:
            x.zero_()
            x[-i] = 1
            self.assertEqual(x.argmax().item(), size - i)
            self.assertEqual(y.argmax().item(), ysize - i)

        # ---- 2D cases ----
        rows = 7
        size = (rows, 1024 * 1024 + 3)
        width = size[1]
        x = torch.zeros(size, dtype=dtype, device=device)
        # Column positions to probe: first the head, then the tail (negative).
        columns = [*head_offsets, *(-i for i in tail_offsets)]

        def fill_column(col, value_for_row):
            # Zero x, then give each row its own value in column `col`
            x.zero_()
            for j in range(rows):
                x[j][col] = value_for_row(j)

        # 2D case: sum
        for col in columns:
            fill_column(col, lambda j: j)
            xs = x.sum(dim=-1)
            for j in range(rows):
                self.assertEqual(xs[j].item(), float(j))

        # 2D case: max/argmax
        for col in columns:
            fill_column(col, lambda j: j + 1)
            xs1 = x.argmax(dim=-1)
            xs2 = x.max(dim=-1).indices
            # `%` maps a negative column to its non-negative index (width - i)
            expected_index = col % width
            for j in range(rows):
                self.assertEqual(xs1[j].item(), expected_index)
                self.assertEqual(xs2[j].item(), expected_index)

        # 2D case: min/argmin
        for col in columns:
            fill_column(col, lambda j: -(j + 1))
            xs1 = x.argmin(dim=-1)
            xs2 = x.min(dim=-1).indices
            expected_index = col % width
            for j in range(rows):
                self.assertEqual(xs1[j].item(), expected_index)
                self.assertEqual(xs2[j].item(), expected_index)
2346

2347
    @onlyCUDA
    @dtypes(torch.half, torch.float, torch.double, torch.bfloat16)
    def test_reduction_vectorize_along_output(self, device, dtype):
        # Put ones on the main diagonal of a 2-d input and reduce over dim 0:
        # column i must then have argmax i and sum 1. The inputs vary in width and
        # in storage offset (the "vec N" labels are kept from the original test).
        def run_test(input_):
            n_rows, n_cols = input_.shape
            diag_len = min(n_rows, n_cols)
            input_.zero_()
            for i in range(diag_len):
                input_[i][i] = 1
            col_argmax = input_.argmax(dim=0)
            col_sum = input_.sum(dim=0)
            for i in range(diag_len):
                self.assertEqual(col_argmax[i], i)
                self.assertEqual(col_sum[i], 1)

        def zeros(*shape):
            return torch.zeros(*shape, dtype=dtype, device=device)

        def offset_square(offset):
            # 64x64 view that begins `offset` elements into a flat buffer
            return zeros(64 * 64 + offset)[offset:].view(64, 64)

        # vec 4
        run_test(zeros(64, 64))
        # vec 2
        run_test(offset_square(2))
        run_test(zeros(64, 62))
        run_test(zeros(64, 2))
        # vec 1
        run_test(offset_square(1))
        run_test(zeros(64, 61))
        run_test(zeros(64, 1))
2370

2371
    @onlyCUDA
    def test_argminmax_large_axis(self, device):
        # Regression test for gh-32863
        # The extreme value sits in the last slot of a 2**31-element int8 tensor;
        # the arg-reductions must report that last index.
        x = torch.zeros(2**31, device=device, dtype=torch.int8)
        last_index = x.shape[0] - 1

        x[-1] = 1
        self.assertEqual(x.argmax(0), last_index)
        self.assertEqual(x.max(0).indices, last_index)

        x[-1] = -1
        self.assertEqual(x.argmin(0), last_index)
        self.assertEqual(x.min(0).indices, last_index)
2381

2382
    def test_argminmax_axis_with_dim_one(self, device):
        # See: https://github.com/pytorch/pytorch/issues/38922
        # argmax/argmin over a size-1 dimension of a (1, n) tensor must return
        # index 0 for every column, with and without keepdim, and for both
        # spellings of the dimension (0 and -2).
        n = 32768
        # Create the input on `device`; the argument was previously ignored, so
        # every device instantiation of this test only exercised the CPU.
        x = torch.zeros(1, n, device=device)

        expected_by_keepdim = {
            False: torch.zeros(n, dtype=torch.int64),
            True: torch.zeros(1, n, dtype=torch.int64),
        }
        for keepdim, expected in expected_by_keepdim.items():
            for dim in (0, -2):
                self.assertEqual(x.argmax(dim=dim, keepdim=keepdim), expected)
                self.assertEqual(x.argmin(dim=dim, keepdim=keepdim), expected)
2397

2398
    @dtypes(torch.int, torch.long, torch.float, torch.double)
    @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double)
    def test_median_real_values(self, device, dtype):
        # On NaN-free random input, median must equal nanmedian, equal the
        # lower-middle element of the sorted data, and (for odd element counts)
        # equal np.median. Checked for the full reduction and along every dim.
        # Generate random 0-3D sizes
        shapes = [random.sample(range(1, 32), rank) for rank in range(4) for _ in range(2)]
        for shape in shapes:
            # Create random input tensor
            t = torch.randn(shape, device=device).type(dtype)
            t_numpy = t.cpu().numpy()

            # Full reduction
            res = t.median()
            self.assertEqual(res, t.nanmedian())
            total = t.numel()
            k = (total - 1) // 2
            self.assertEqual(res, t.view(-1).sort()[0][k])
            if total % 2 == 1:
                # We can only test against numpy for odd reductions because numpy
                # returns the mean of the two medians and torch returns the lower
                self.assertEqual(res.cpu().numpy(), np.median(t_numpy))

            # Reduction along each dimension, keepdim=True
            for dim in range(t.ndim):
                res = t.median(dim, True)
                self.assertEqual(res, t.nanmedian(dim, True))
                extent = t.size(dim)
                k = (extent - 1) // 2
                sorted_values = t.sort(dim)[0]
                self.assertEqual(res[0], sorted_values.select(dim, k).unsqueeze(dim))
                # The returned indices must point at the returned values
                self.assertEqual(res[0], t.gather(dim, res[1]))
                if extent % 2 == 1:
                    # We can only test against numpy for odd reductions because numpy
                    # returns the mean of the two medians and torch returns the lower
                    self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False)
2426

2427
    @dtypes(torch.float, torch.double)
    @dtypesIfCUDA(torch.half, torch.float, torch.double)
    def test_median_nan_values(self, device, dtype):
        # median and nanmedian on input with NaNs: the result is compared with the
        # element picked from the sorted data and, where the count of non-NaN
        # elements is odd, with np.median / np.nanmedian.
        # Generate random 0-3D sizes
        shapes = [random.sample(range(1, 32), rank) for rank in range(4) for _ in range(2)]
        op_pairs = ((torch.median, np.median), (torch.nanmedian, np.nanmedian))
        for shape in shapes:
            # Create random input tensor with nan values
            t = torch.rand(shape, device=device, dtype=dtype)
            t.masked_fill_(t < 0.1, float('nan'))
            t_numpy = t.cpu().numpy()
            for op, numpy_op in op_pairs:
                is_plain_median = op == torch.median

                # Full reduction
                res = op(t)
                num_nan = t.isnan().sum()
                if is_plain_median and num_nan > 0:
                    # expected value is the last element of the sorted data
                    k = t.numel() - 1
                else:
                    # true division followed by int() truncates toward zero
                    k = int((t.numel() - num_nan - 1) / 2)
                self.assertEqual(res, t.view(-1).sort()[0][k])
                if (t.numel() - num_nan) % 2 == 1:
                    # We can only test against numpy for odd reductions because numpy
                    # returns the mean of the two medians and torch returns the lower
                    self.assertEqual(res.item(), numpy_op(t_numpy))

                # Reduction along each dimension, keepdim=True
                for dim in range(t.ndim):
                    res = op(t, dim, True)
                    size = t.size(dim)
                    num_nan = t.isnan().sum(dim, True)
                    if is_plain_median:
                        k = torch.where(num_nan > 0, size - 1, int((size - 1) / 2))
                    else:
                        # Keep true division plus the cast to long (truncation):
                        # a slice that is all NaN gives -0.5, which becomes 0.
                        k = ((size - num_nan - 1) / 2).type(torch.long)
                    self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k))
                    # The returned indices must point at the returned values
                    self.assertEqual(res[0], t.gather(dim, res[1]))
                    # We can only test against numpy for odd reductions because numpy
                    # returns the mean of the two medians and torch returns the lower
                    mask = (size - num_nan) % 2 == 1
                    res = res[0].masked_select(mask).cpu()
                    ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()]
                    self.assertEqual(res, torch.from_numpy(ref))
2466

2467
    def test_median_corner_cases(self, device):
        # median / nanmedian on corner-case inputs: NaN scalars, an empty tensor,
        # rows or columns made entirely of NaN, and strided / transposed views.
        def check(op, a, args, key):
            # Run op(tensor(a), *args) and compare with `key`:
            #   - no args:       `key` is the expected full-reduction result
            #   - len(key) == 1: only the values output is compared
            #   - otherwise:     `key` is the (values, indices) pair
            t = torch.tensor(a, device=device)
            res = op(t, *args)
            if not args:
                expected = torch.tensor(key, device=device)
            elif len(key) == 1:
                expected = torch.tensor(key[0], device=device)
                res = res[0]
            else:
                expected = (torch.tensor(key[0], device=device), torch.tensor(key[1], device=device))
            self.assertEqual(res, expected)

        nan = float('nan')
        check(torch.median, nan, [], nan)
        check(torch.median, [], [], nan)
        check(torch.nanmedian, nan, [], nan)
        check(torch.median, nan, [0], [nan, 0])
        check(torch.nanmedian, nan, [0], [nan, 0])
        # The two keepdim checks below used to be followed by exact copies of
        # themselves; the duplicates are removed.
        check(torch.median, [nan], [0, True], [[nan], [0]])
        check(torch.nanmedian, [nan], [0, True], [[nan], [0]])

        # Indices are not deterministic here so can only check values
        check(torch.median, [[nan, nan], [1, 2]], [0], [[nan, nan]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [0], [[1, 2.]])
        check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
        check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])

        # Discontiguous and strided tensors
        a = torch.arange(12, device=device)
        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))

        a.resize_(3, 4)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))

        a.resize_(2, 3, 2)
        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
2514

2515

2516
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_quantile(self, device, dtype):
        """Compare ``torch.quantile`` / ``torch.nanquantile`` with the NumPy functions of the same name.

        Randomly shaped inputs (with a few elements replaced by NaN) and fixed
        corner cases are reduced along every dimension as well as over the
        flattened tensor, for each interpolation mode. The ``out=`` overload is
        checked against the functional result.
        """
        op_names = ['quantile', 'nanquantile']

        # Random test cases: shapes of rank 1..3, then tuples of 0..4 quantiles.
        input_specs = [tuple(np.random.randint(2, 10, size=rank)) for rank in range(1, 4)]
        q_specs = [tuple(np.random.rand(count)) for count in range(0, 5)]
        keepdim_options = [True, False]

        # Corner cases. Tuples are interpreted as shapes; everything else is
        # used as literal tensor data.
        input_specs += [0.75, (1,), (1, 1), (1, 2, 1)]
        input_specs += [[float('nan')], [[float('nan'), float('nan')], [1, 2]]]
        input_specs += [[[float('nan'), float('nan')], [float('nan'), 2]]]
        q_specs += [0.5, [0., 1.], np.random.rand(10)]

        modes = ('linear', 'lower', 'higher', 'midpoint', 'nearest')

        for name, spec, q_spec, keepdim in product(op_names, input_specs, q_specs, keepdim_options):
            if isinstance(spec, tuple):
                a = torch.randn(spec, dtype=dtype, device=device)
                # Make some random elements NaN
                nan_mask = torch.randint_like(a, 20) == 0
                a.masked_fill_(nan_mask, float('nan'))
            else:
                a = torch.tensor(spec, dtype=dtype, device=device)

            q_tensor = torch.tensor(q_spec, dtype=dtype, device=device)

            torch_fn = getattr(torch, name)
            numpy_fn = getattr(np, name)

            # ``None`` reduces over the flattened tensor; the rest are real dims.
            dims = [None, *range(a.ndim)]
            for mode, dim in product(modes, dims):
                result = torch_fn(a, q_tensor, dim=dim, keepdim=keepdim, interpolation=mode)
                reference = numpy_fn(a.cpu().numpy(), q_tensor.cpu().numpy(), dim,
                                     interpolation=mode, keepdims=keepdim)
                reference = torch.from_numpy(np.array(reference)).type(result.type())
                self.assertEqual(result.cpu(), reference)

                # Same computation through the out= overload.
                out = torch.empty_like(result)
                torch_fn(a, q_tensor, dim=dim, keepdim=keepdim, interpolation=mode, out=out)
                self.assertEqual(out.cpu(), result.cpu())
2558

2559
    def test_quantile_backward(self, device):
2560
        def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)):
2561
            for op in ops:
2562
                t = torch.tensor(a, device=device, requires_grad=True)
2563
                op(t, torch.tensor(q, device=device), dim).sum().backward()
2564
                self.assertEqual(t.grad, expected_grad)
2565

2566
        check([1., 2, 3], 0.5, 0, [0, 1, 0])
2567
        check([1., 2, 3, 4], 0.5, 0, [0, 0.5, 0.5, 0])
2568
        check([3., 1, 4, 2], 0.5, 0, [0.5, 0, 0, 0.5])
2569
        check([1., 2, 3, 4], [0.25, 0.5, 0.75], 0, [0.25, 1.25, 1.25, 0.25])
2570
        check([[1., 2], [2, 1]], 0., 0, [[1, 0], [0, 1]])
2571
        check([[1., 2], [4, 3]], 1., 1, [[0, 1], [1, 0]])
2572
        check([1, float('nan'), 2], 0.5, 0, [0, 1, 0], [torch.quantile])
2573
        check([1, float('nan'), 2], 0.5, 0, [0.5, 0, 0.5], [torch.nanquantile])
2574

2575
    def test_quantile_error(self, device):
        """Verify the ``RuntimeError`` messages raised by ``torch.quantile`` for invalid arguments."""
        def expect_error(a, q, args, kwargs, message):
            # Every message is expected to carry the ``quantile() `` prefix.
            pattern = r'quantile\(\) ' + message
            with self.assertRaisesRegex(RuntimeError, pattern):
                input_tensor = torch.tensor(a, device=device)
                # Only list-valued q is converted to a tensor; scalars pass through.
                if isinstance(q, list):
                    q_arg = torch.tensor(q, device=device)
                else:
                    q_arg = q
                torch.quantile(input_tensor, q_arg, *args, **kwargs)

        int_out = torch.empty([], dtype=torch.int32, device=device)
        cases = (
            ([], 0.5, [], {}, r'input tensor must be non-empty'),
            ([1.], [[1.]], [], {}, r'q must be a scalar or 1D tensor'),
            ([1], 0.5, [], {}, r'input tensor must be either float or double dtype'),
            ([1.], [1], [], {}, r'q tensor must be same dtype as the input tensor'),
            ([1.], -1., [], {}, r'q must be in the range \[0, 1\] but got -1'),
            ([1.], 1.1, [], {}, r'q must be in the range \[0, 1\] but got 1.1'),
            ([1.], 0.5, [], {'out': int_out},
             r'out tensor must be same dtype as the input tensor'),
            ([1.], [1.], [None, False], {'interpolation': 'random_mode'},
             r"interpolation must be one of linear, lower, higher, midpoint or nearest, but got random_mode"),
        )
        for case in cases:
            expect_error(*case)

        # The range check on tensor-valued q is only asserted on CPU.
        if self.device_type == "cpu":
            expect_error([1.], [0.5, 1.1, -1], [], {}, r'q values must be in the range \[0, 1\]')

        # Device-mismatch errors need an input on CUDA and a q/out tensor on CPU.
        if self.device_type == "cuda":
            q_device_msg = r'quantile\(\) q tensor must be on the same device as the input tensor'
            with self.assertRaisesRegex(RuntimeError, q_device_msg):
                torch.randn(1, device=device).quantile(torch.tensor(0.5))

            out_device_msg = r'quantile\(\) out tensor must be on the same device as the input tensor'
            with self.assertRaisesRegex(RuntimeError, out_device_msg):
                torch.quantile(torch.randn(1, device=device), 0.5, out=torch.scalar_tensor(1))
2603

2604
    def test_std_mean(self, device):
2605
        x = torch.rand(100, 50, 20, device=device)
2606
        for dim in range(x.dim()):
2607
            for unbiased in [False, True]:
2608
                for keepdim in [False, True]:
2609
                    std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2610
                    std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2611
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
2612
                    self.assertEqual(std1, std2)
2613
                    self.assertEqual(mean1, mean2)
2614

2615
    def test_std_mean_all_dims(self, device):
2616
        x = torch.rand(100, 50, 20, device=device)
2617
        for unbiased in [False, True]:
2618
            std1, mean1 = torch.std_mean(x, unbiased=unbiased)
2619
            std2 = x.std(unbiased=unbiased)
2620
            mean2 = x.mean()
2621
            self.assertEqual(std1, std2)
2622
            self.assertEqual(mean1, mean2)
2623

2624
    def test_var_mean(self, device):
2625
        x = torch.rand(100, 300, 50, device=device)
2626
        for dim in range(x.dim()):
2627
            for unbiased in [False, True]:
2628
                for keepdim in [False, True]:
2629
                    var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2630
                    var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
2631
                    mean2 = x.mean(dim=dim, keepdim=keepdim)
2632
                    self.assertEqual(var1, var2)
2633
                    self.assertEqual(mean1, mean2)
2634

2635
    def test_var_mean_all_dims(self, device):
2636
        x = torch.rand(100, 50, 20, device=device)
2637
        for unbiased in [False, True]:
2638
            var1, mean1 = torch.var_mean(x, unbiased=unbiased)
2639
            var2 = x.var(unbiased=unbiased)
2640
            mean2 = x.mean()
2641
            self.assertEqual(var1, var2)
2642
            self.assertEqual(mean1, mean2)
2643

2644
    def test_std_mean_some_dims(self, device):
2645
        sizes = (4, 6, 7, 5, 3)
2646
        dims = len(sizes)
2647
        x = torch.rand(sizes, device=device)
2648
        for num_of_dims in range(2, dims):
2649
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2650
            for dim in dim_list:
2651
                for unbiased in [False, True]:
2652
                    for keepdim in [False, True]:
2653
                        std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2654
                        std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2655
                        mean2 = x.mean(dim=dim, keepdim=keepdim)
2656
                        self.assertEqual(std1, std2)
2657
                        self.assertEqual(mean1, mean2)
2658

2659
    def _compare_std_var_with_numpy(self, op, device, dtype, input, dim,
                                    keepdim, unbiased, use_out):
        """Compare ``torch.var`` / ``torch.std`` with ``np.var`` / ``np.std`` on ``input``.

        ``op`` selects the function pair (``'var'`` or ``'std'``); any other
        value fails the test. ``unbiased`` maps to NumPy's ``ddof`` (1 or 0).
        When ``dim`` is None, NumPy is called without ``axis`` and ``keepdims``.
        ``use_out`` routes the torch call through the ``out=`` overload, using
        an empty tensor of ``dtype`` on ``device``.
        """
        # NumPy side: bfloat16 inputs are upcast to float before conversion.
        if input.dtype is torch.bfloat16:
            a = input.float().cpu().numpy()
        else:
            a = input.cpu().numpy()

        numpy_kwargs = {'ddof': 1 if unbiased else 0}
        if dim is not None:
            numpy_kwargs['axis'] = dim
            numpy_kwargs['keepdims'] = keepdim

        op_table = {
            'var': (torch.var, np.var),
            'std': (torch.std, np.std),
        }
        if op not in op_table:
            self.fail("Unknown op!")
        torch_op, numpy_op = op_table[op]

        numpy_result = numpy_op(a, **numpy_kwargs)

        if use_out is False:
            if dim is None:
                torch_result = torch_op(input, unbiased)
            else:
                torch_result = torch_op(input, dim, unbiased, keepdim)
        else:
            # The out= path uses the same call whether or not ``dim`` is None.
            out = torch.empty(0, device=device, dtype=dtype)
            torch_result = torch_op(input, dim, unbiased, keepdim, out=out)

        # Result dtypes are only compared strictly outside bfloat16/complex inputs.
        loose_dtypes = (torch.bfloat16, torch.complex32, torch.complex64, torch.complex128)
        exact_dtype = input.dtype not in loose_dtypes
        self.assertEqual(torch_result, numpy_result, exact_dtype=exact_dtype)
2696

2697
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_vs_numpy(self, device, dtype):
        """Check ``torch.var`` against ``np.var`` for every dim/keepdim/unbiased/use_out combination."""
        _size = (20, 20)
        sample = torch.randn(_size, device=device, dtype=dtype)
        flags = (False, True)

        for dim, keepdim, unbiased, use_out in product((None, 0, 1), flags, flags, flags):
            self._compare_std_var_with_numpy('var', device, dtype, sample, dim,
                                             keepdim, unbiased, use_out)
2707

2708
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_vs_numpy(self, device, dtype):
        """Check ``torch.std`` against ``np.std`` for every dim/keepdim/unbiased/use_out combination."""
        _size = (20, 20)
        sample = torch.randn(_size, device=device, dtype=dtype)
        flags = (False, True)

        for dim, keepdim, unbiased, use_out in product((None, 0, 1), flags, flags, flags):
            self._compare_std_var_with_numpy('std', device, dtype, sample, dim,
                                             keepdim, unbiased, use_out)
2718

2719
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_correction_vs_numpy(self, device, dtype):
        """Compare ``torch.var(correction=...)`` with ``np.var(ddof=...)``.

        Infinite results are rewritten to NaN on both sides before comparing.
        """
        _size = (20, 20)
        # (dim, correction, keepdim) combinations, plus one negative correction.
        test_args = list(product((None, 0, 1), (None, 0, 10, 30), (False, True)))
        test_args.append([None, -100, True])

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            # NumPy default is not compatible with torch.std (gh-50010)
            ddof = 1 if correction is None else correction

            numpy_res = np.asarray(np.var(array, axis=dim, ddof=ddof, keepdims=keepdim))
            torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)
2752

2753
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_correction_vs_numpy(self, device, dtype):
        """Compare ``torch.std(correction=...)`` with ``np.std(ddof=...)``.

        Infinite results are rewritten to NaN on both sides before comparing.
        """
        _size = (20, 20)
        # (dim, correction, keepdim) combinations, plus one negative correction.
        test_args = list(product((None, 0, 1), (None, 0, 10, 30), (False, True)))
        test_args.append([None, -100, True])

        tensor = make_tensor(_size, device=device, dtype=dtype)
        array = tensor.cpu().numpy()

        for dim, correction, keepdim in test_args:
            # NumPy default is incompatible with torch.std (gh-50010)
            ddof = 1 if correction is None else correction

            numpy_res = np.asarray(np.std(array, axis=dim, ddof=ddof, keepdims=keepdim))
            torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)

            # inf vs. nan results are sensitive to machine precision,
            # just treat them as equivalent
            numpy_res[np.isinf(numpy_res)] = np.nan
            torch_res[torch_res.isinf()] = np.nan

            self.assertEqual(torch_res, numpy_res)
2786

2787
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_std_mean_correction(self, device, dtype):
        """``torch.std_mean`` with ``correction`` must match separate ``std`` and ``mean`` calls."""
        _size = (20, 20)
        # (dim, correction, keepdim) combinations, plus one negative correction.
        test_args = list(product((None, 0, 1), (None, 0, 10, 30), (False, True)))
        test_args.append([None, -100, True])

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            expected_std = torch.std(tensor, **kwargs)
            if dim is None:
                # Full reduction: emulate keepdim by reshaping the scalar mean
                # to one singleton per input dimension.
                expected_mean = torch.mean(tensor)
                if keepdim:
                    expected_mean = expected_mean.reshape((1,) * tensor.ndim)
            else:
                expected_mean = torch.mean(tensor, dim=dim, keepdim=keepdim)
            fused_std, fused_mean = torch.std_mean(tensor, **kwargs)

            self.assertEqual(expected_std, fused_std)
            self.assertEqual(expected_mean, fused_mean)
2817

2818
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_var_mean_correction(self, device, dtype):
        """``torch.var_mean`` with ``correction`` must match separate ``var`` and ``mean`` calls."""
        _size = (20, 20)
        # (dim, correction, keepdim) combinations, plus one negative correction.
        test_args = list(product((None, 0, 1), (None, 0, 10, 30), (False, True)))
        test_args.append([None, -100, True])

        tensor = make_tensor(_size, device=device, dtype=dtype)

        for dim, correction, keepdim in test_args:
            kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
            expected_var = torch.var(tensor, **kwargs)
            if dim is None:
                # Full reduction: emulate keepdim by reshaping the scalar mean
                # to one singleton per input dimension.
                expected_mean = torch.mean(tensor)
                if keepdim:
                    expected_mean = expected_mean.reshape((1,) * tensor.ndim)
            else:
                expected_mean = torch.mean(tensor, dim=dim, keepdim=keepdim)
            fused_var, fused_mean = torch.var_mean(tensor, **kwargs)

            self.assertEqual(expected_var, fused_var)
            self.assertEqual(expected_mean, fused_mean)
2848

2849
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_warn_invalid_degrees_of_freedom(self, device, dtype):
        """std/var/var_mean/std_mean must warn when ``correction`` leaves no degrees of freedom.

        The reduced dimension has exactly ``correction`` elements, so the
        degrees of freedom are zero for every row.
        """
        def _assert_warning(_func, _tensor, _correction):
            with warnings.catch_warnings(record=True) as w:
                # Without this, an ambient "ignore"/"once" filter can suppress
                # the warning and leave nothing recorded.
                warnings.simplefilter("always")
                _func(_tensor, dim=-1, correction=_correction)
            messages = [str(item.message) for item in w]
            # Accept the warning at any position so an unrelated warning
            # emitted first does not mask it.
            self.assertTrue(
                any('degrees of freedom is <= 0' in message for message in messages),
                f"expected a 'degrees of freedom is <= 0' warning, got: {messages}")

        correction = 20
        size = (10, correction)
        tensor = make_tensor(size, dtype=dtype, device=device)
        for f in [torch.std, torch.var, torch.var_mean, torch.std_mean]:
            _assert_warning(f, tensor, correction)
2861

2862
    def test_amin_amax_some_dims(self, device):
2863
        sizes = (4, 6, 7, 5, 3)
2864
        dims = len(sizes)
2865
        x = torch.rand(sizes, device=device)
2866
        for num_of_dims in range(2, dims):
2867
            dim_list = list(combinations(list(range(dims)), r=num_of_dims))
2868
            for dim in dim_list:
2869
                for keepdim in [False, True]:
2870
                    amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
2871
                    amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
2872
                    amin2 = x
2873
                    amax2 = x
2874
                    for i, d in enumerate(dim):
2875
                        if not keepdim:
2876
                            d -= i
2877
                        amin2 = torch.amin(amin2, dim=d, keepdim=keepdim)
2878
                        amax2 = torch.amax(amax2, dim=d, keepdim=keepdim)
2879
                    self.assertEqual(amin1, amin2)
2880
                    self.assertEqual(amax1, amax2)
2881

2882
    def test_histc(self, device):
        """Check ``torch.histc`` on fixed corner cases and against ``np.histogram``."""
        # Shorthand for a float32 tensor on the device under test.
        floats = partial(torch.tensor, dtype=torch.float, device=device)

        # negative nbins throws
        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            torch.histc(floats([1]), bins=-1)

        # empty tensor
        actual = torch.histc(torch.tensor([], device=device), min=0, max=3)
        expected = torch.zeros(100, dtype=torch.float, device=device)
        self.assertEqual(expected, actual)

        # without nbins: the two values land in the first and last of 100 bins
        actual = torch.histc(floats([2, 5]))
        expected = torch.zeros(100, dtype=torch.float, device=device)
        expected[0] = 1
        expected[99] = 1
        self.assertEqual(expected, actual)

        # tensor with the same element
        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
        self.assertEqual(floats([0, 0, 5, 0, 0]), actual)

        # no element falls between [min, max]
        actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
        self.assertEqual(floats([0, 0, 0, 0, 0]), actual)

        # integral bin size, with explicit min and max
        actual = torch.histc(floats([2, 4, 2, 2, 5, 4]), bins=5, min=1, max=5)
        self.assertEqual(floats([0, 3, 0, 2, 1]), actual)

        # non-integral bin size
        actual = torch.histc(floats([1, 2, 1]), bins=4, min=0, max=3)
        self.assertEqual(floats([0, 2, 1, 0]), actual)

        # double input
        actual = torch.histc(
            torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
        self.assertEqual(torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device), actual)
        self.assertEqual(actual.dtype, torch.double)

        # mixed input
        actual = torch.histc(floats([1., 2, 1]), bins=4, min=0, max=3)
        self.assertEqual(floats([0, 2, 1, 0]), actual)
        self.assertEqual(actual.dtype, torch.float)

        # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar.
        actual = torch.histc(floats(0), bins=1, min=0, max=3)
        self.assertEqual(floats([1]), actual)

        # tensors with inf; min, max not provided -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, r'range of \[inf, inf\] is not finite'):
            torch.histc(floats([float("inf")]))
        with self.assertRaisesRegex(RuntimeError, r'range of \[1, inf\] is not finite'):
            torch.histc(floats([1., 2., float("inf")]))

        # tensors with inf; min, max provided
        self.assertEqual(
            torch.histc(floats([float("inf")]), bins=1, min=0, max=3),
            floats([0]))
        self.assertEqual(
            torch.histc(floats([1., 2., float("inf")]), bins=4, max=3),
            floats([0, 1, 1, 0]))

        # tensor with nan; min, max not provided -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, r'range of \[nan, nan\] is not finite'):
            torch.histc(floats([float("nan")]))

        # tensor with nan; min, max provided -- nan is ignored
        self.assertEqual(
            torch.histc(floats([1., 2., float("nan")]), bins=4, max=3),
            floats([0, 1, 1, 0]))

        # tensors with min > max -- should throw a RuntimeError
        with self.assertRaisesRegex(RuntimeError, "max must be larger than min"):
            torch.histc(floats([1., 2., 3.]), bins=4, min=5, max=1)

        # test against numpy.histogram()
        def compare_with_numpy(tensor, bins=100, lo=0, hi=0):
            # lo == hi == 0 means "use the data range", computed here explicitly
            # so the same range is passed to both libraries.
            if lo == 0 and hi == 0:
                lo = tensor.min().item()
                hi = tensor.max().item()
            as_numpy = tensor.cpu().numpy()
            actual = torch.histc(tensor, bins=bins, min=lo, max=hi)
            counts = np.histogram(as_numpy, bins=bins, range=(lo, hi))[0]
            expected = torch.from_numpy(counts)
            actual_cpu = actual.cpu()
            # NumPy returns int64 counts; convert them to match histc's output.
            self.assertEqual(actual, expected.to(actual_cpu))

        compare_with_numpy(torch.tensor([1., 2, 1], device=device))
        compare_with_numpy(torch.randn(5000, device=device))

        # Test bins arg
        compare_with_numpy(torch.randn(301, device=device), bins=10)

        # Test truncated range
        compare_with_numpy(torch.randn(201, device=device), lo=0.1, hi=1)

        noncontig = torch.randn(100, 3, device=device)[:, 2]
        compare_with_numpy(noncontig)

        multidim = torch.randn(3, 5, 7, 2, device=device)
        compare_with_numpy(multidim)

        expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
        compare_with_numpy(expanded)

        linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device)
        compare_with_numpy(linear, bins=20, lo=0, hi=0.99)
3004

3005
    @onlyCPU
    @dtypes(torch.bfloat16, torch.half)
    def test_histc_lowp(self, device, dtype):
        """``torch.histc`` on a low-precision input returns counts in the input dtype."""
        values = torch.tensor([1, 2, 1], dtype=dtype, device=device)
        expected = torch.tensor([0, 2, 1, 0], dtype=dtype, device=device)

        actual = torch.histc(values, bins=4, min=0, max=3)

        self.assertEqual(expected, actual)
        self.assertEqual(actual.dtype, dtype)
3014

3015
    def _test_histogram_numpy(self, t, bins, bin_range, weights, density):
        """Run torch.histogram and numpy.histogram on the given inputs and assert equal output.

        Args:
            t: input tensor passed to ``torch.histogram``.
            bins: a bin count, or a tensor of bin edges.
            bin_range: optional range; only forwarded to ``torch.histogram``
                when truthy.
            weights: optional weight tensor, or None.
            density: forwarded as ``density`` to both implementations.

        The ``out=`` overload is also exercised with non-contiguous output tensors.
        """
        def to_np(value):
            # Tensors become NumPy arrays; anything else is passed through.
            return value.cpu().numpy() if torch.is_tensor(value) else value

        # Wrapper around numpy.histogram performing conversions between torch tensors and numpy arrays.
        def reference_histogram(t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])
            (np_hist, np_bin_edges) = np.histogram(np_t, np_bins, range=bin_range, weights=np_weights, density=density)
            return (torch.from_numpy(np_hist).to(dtype), torch.from_numpy(np_bin_edges).to(dtype))

        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
        if bin_range:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, range=bin_range, weight=weights, density=density)
        else:
            (actual_hist, actual_bin_edges) = torch.histogram(t, bins, weight=weights, density=density)

        (expected_hist, expected_bin_edges) = reference_histogram(t, bins, bin_range, weights, density, actual_hist.dtype)

        # Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        # When bin edges are not explicitly defined, histogram uses the linspace operator internally
        # to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        # from numpy.linspace output.
        # Issue: https://github.com/pytorch/pytorch/issues/58758
        if not torch.is_tensor(bins):
            self.assertEqual(actual_bin_edges, expected_bin_edges, atol=1e-5, rtol=1e-5)
            # Calls numpy.histogram again, passing torch's actual_bin_edges as the bins argument
            (expected_hist, expected_bin_edges) = reference_histogram(
                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)

        self.assertEqual(actual_hist, expected_hist)
        self.assertEqual(actual_bin_edges, expected_bin_edges)

        # Test passing non-contiguous output tensors
        hist_out = make_tensor(expected_hist.shape, device=expected_hist.device, dtype=expected_hist.dtype,
                               noncontiguous=True)
        bin_edges_out = make_tensor(expected_bin_edges.shape, device=expected_bin_edges.device, dtype=expected_bin_edges.dtype,
                                    noncontiguous=True)

        # Doesn't pass a 'range' kwarg unless necessary because the override of histogram with Tensor bins doesn't accept one
        if bin_range:
            torch.histogram(t, bins, range=bin_range, weight=weights, density=density, out=(hist_out, bin_edges_out))
        else:
            torch.histogram(t, bins, weight=weights, density=density, out=(hist_out, bin_edges_out))

        self.assertEqual(hist_out, expected_hist)
        self.assertEqual(bin_edges_out, expected_bin_edges)
3070

3071
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogram(self, device, dtype):
        """Compare ``torch.histogram`` with NumPy over shapes, bin specs, weights and density.

        Also checks that calling with only a bin count matches passing the
        defaults ``range=None, weight=None, density=False`` explicitly.
        """
        shapes = (
            (),
            (0,),
            (1,),
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5))
        flags = [True, False]
        common = dict(dtype=dtype, device=device)

        combos = product(flags, flags, range(1, 10), flags, flags, shapes)
        for contig, bins_contig, bin_ct, weighted, density, shape in combos:
            values = make_tensor(shape, low=-9, high=9, noncontiguous=not contig, **common)
            if weighted:
                weights = make_tensor(shape, low=0, high=9, noncontiguous=not contig, **common)
            else:
                weights = None

            # Tests passing just the bin_ct
            self._test_histogram_numpy(values, bin_ct, None, weights, density)

            # Tests with caller-specified histogram range
            bin_range = sorted((random.uniform(-9, 9), random.uniform(-9, 9)))
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with range min=max
            bin_range[1] = bin_range[0]
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)

            # Tests with caller-specified bin edges
            bin_edges = make_tensor(bin_ct + 1, low=-9, high=9, **common).msort()
            if not bins_contig:
                # Necessary because msort always produces contiguous output
                strided_edges = make_tensor(bin_ct + 1, noncontiguous=True, **common)
                strided_edges.copy_(bin_edges)
                bin_edges = strided_edges
            self.assertEqual(bin_edges.is_contiguous(), bins_contig)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input tensor in which all elements are equal
            # (bin_range still holds the min == max range from above)
            elt = random.uniform(-9, 9)
            values = make_tensor(shape, low=elt, high=elt, noncontiguous=not contig, **common)
            self._test_histogram_numpy(values, bin_ct, bin_range, weights, density)
            self._test_histogram_numpy(values, bin_edges, None, weights, density)

            # Tests with input equal to bin_edges
            if weighted:
                weights = make_tensor(bin_ct + 1, low=0, high=9, noncontiguous=not contig, **common)
            else:
                weights = None
            self._test_histogram_numpy(bin_edges, bin_edges, None, weights, density)

        # Tests values of default args
        for bin_ct, shape in product(range(1, 10), shapes):
            values = make_tensor(shape, low=-9, high=9, **common)
            (actual_hist, actual_bin_edges) = torch.histogram(values, bin_ct)
            (expected_hist, expected_bin_edges) = torch.histogram(
                values, bin_ct, range=None, weight=None, density=False)
            self.assertEqual(actual_hist, expected_hist)
            self.assertEqual(actual_bin_edges, expected_bin_edges)
3131

3132
    def _test_histogramdd_numpy(self, t, bins, bin_range, weights, density):
        """
        Runs torch.histogramdd and numpy.histogramdd on the specified input parameters
        and asserts that their output is equal.

        Args:
            t: input tensor; its last dimension holds the D coordinates of each sample.
            bins: bin specification forwarded to both implementations.
            bin_range: None (or empty) to leave the range unspecified, otherwise a flat
                sequence (min_0, max_0, ..., min_{D-1}, max_{D-1}).
            weights: optional tensor with one weight per sample, or None.
            density: forwarded to both implementations as the ``density`` argument.
        """
        def to_np(obj):
            # Recursively converts tensors (possibly nested in lists) to numpy arrays;
            # any other value (ints, None, ...) is passed through unchanged.
            if isinstance(obj, list):
                return list(map(to_np, obj))
            if not torch.is_tensor(obj):
                return obj
            return obj.cpu().numpy()

        # Wrapper around numpy.histogramdd performing conversions between torch tensors and numpy arrays.
        def reference_histogramdd(t, bins, bin_range, weights, density, dtype):
            (np_t, np_bins, np_weights) = map(to_np, [t, bins, weights])

            # numpy.histogramdd accepts only (N, D) shapes
            D = np_t.shape[-1]
            N = np.prod(np_t.shape[:-1])
            reshaped_t = np.reshape(np_t, (N, D))
            reshaped_wt = np.reshape(np_weights, (N,)) if np_weights is not None else None

            # numpy.histogramdd throws an error for D=0
            if D == 0:
                return (torch.tensor(float('nan') if density else 0.), [])

            # numpy.histogramdd expects range to be specified as a sequence of D (lower, upper) tuples
            reshaped_range = None if not bin_range else [(bin_range[2 * i], bin_range[2 * i + 1]) for i in range(D)]

            (np_hist, np_bin_edges) = np.histogramdd(reshaped_t, np_bins,
                                                     range=reshaped_range, weights=reshaped_wt, density=density)

            return (torch.from_numpy(np_hist).to(dtype), [torch.from_numpy(edges).to(dtype) for edges in np_bin_edges])

        (actual_hist, actual_bin_edges) = torch.histogramdd(t, bins, range=bin_range, weight=weights, density=density)
        (expected_hist, expected_bin_edges) = reference_histogramdd(t, bins, bin_range, weights, density, actual_hist.dtype)

        D = len(actual_bin_edges)
        self.assertEqual(D, len(expected_bin_edges))

        # Works around linspace discrepancies by passing torch's constructed bin_edges to numpy.
        # When bin edges are not explicitly defined, histogram uses the linspace operator internally
        # to construct the sequence of bin edges. In some cases, torch.linspace output differs slightly
        # from numpy.linspace output.
        # Issue: https://github.com/pytorch/pytorch/issues/58758
        if not torch.is_tensor(bins):
            for dim in range(D):
                self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
            # Calls numpy.histogramdd again, passing torch's actual_bin_edges as the bins argument
            (expected_hist, expected_bin_edges) = reference_histogramdd(
                t, actual_bin_edges, bin_range, weights, density, actual_hist.dtype)
            self.assertEqual(D, len(expected_bin_edges))

        self.assertEqual(actual_hist, expected_hist)
        for dim in range(D):
            self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])
3190

3191
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogramdd(self, device, dtype):
        """Exercise torch.histogramdd over a grid of input configurations.

        For each combination of input/bin contiguity, weighting, density and
        input shape, several bin specifications (single count, per-dimension
        counts, explicit ranges, explicit edges) are handed to
        ``self._test_histogramdd_numpy``.
        """
        input_shapes = (
            (1, 5),
            (3, 5),
            (1, 5, 1),
            (2, 3, 5),
            (7, 7, 7, 7),
            (16, 8, 4, 2),
            (10, 10, 10),
            (7, 0, 3),
            (5, 0),
        )
        flags = [True, False]

        for contig, bins_contig, weighted, density, shape in product(flags, flags, flags, flags, input_shapes):
            # The innermost dimension of the input is the histogram dimensionality.
            n_dims = shape[-1]
            strided = not contig

            values = make_tensor(shape, dtype=dtype, device=device, low=-9, high=9, noncontiguous=strided)
            weights = None
            if weighted:
                weights = make_tensor(shape[:-1], dtype=dtype, device=device, low=0, high=9, noncontiguous=strided)

            # A single bin count shared by every dimension.
            shared_bins = random.randint(1, 5)
            self._test_histogramdd_numpy(values, shared_bins, None, weights, density)

            # One bin count per dimension.
            per_dim_bins = [random.randint(1, 5) for _ in range(n_dims)]
            self._test_histogramdd_numpy(values, per_dim_bins, None, weights, density)

            # Caller-specified range, flattened as [min_0, max_0, min_1, max_1, ...].
            flat_range = []
            for _ in range(n_dims):
                flat_range.extend(sorted((random.uniform(-9, 9), random.uniform(-9, 9))))
            self._test_histogramdd_numpy(values, per_dim_bins, flat_range, weights, density)

            # Degenerate range: every max is set equal to its min.
            flat_range[1::2] = flat_range[0::2]
            self._test_histogramdd_numpy(values, per_dim_bins, flat_range, weights, density)

            # Caller-specified, sorted bin edges for each dimension.
            edges = [make_tensor(ct + 1, dtype=dtype, device=device, low=-9, high=9).msort() for ct in per_dim_bins]
            if not bins_contig:
                # msort always produces contiguous output, so the sorted edges
                # are copied into separately allocated non-contiguous tensors.
                sorted_edges = edges
                edges = [make_tensor(ct + 1, dtype=dtype, device=device, noncontiguous=True) for ct in per_dim_bins]
                for dst, src in zip(edges, sorted_edges):
                    dst.copy_(src)
            for dim_edges in edges:
                self.assertEqual(dim_edges.is_contiguous(), bins_contig)
            self._test_histogramdd_numpy(values, edges, None, weights, density)
3248

3249
    @onlyCPU
    @dtypes(torch.float32)
    def test_histogram_error_handling(self, device, dtype):
        """Checks that torch.histogram rejects invalid arguments with the expected errors.

        Inputs are built outside each ``assertRaisesRegex`` block so that only the
        torch.histogram call itself can satisfy the expected exception.
        """
        # Integer input is not supported.
        values = make_tensor((), dtype=torch.int32, device=device)
        with self.assertRaisesRegex(RuntimeError, 'not implemented for'):
            torch.histogram(values, 1)

        # A floating dtype guaranteed to differ from the dtype under test.
        inconsistent_dtype = torch.float32 if dtype != torch.float32 else torch.float64

        values = make_tensor((), dtype=dtype, device=device)
        bins = make_tensor((), dtype=inconsistent_dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'input tensor and bins tensors should have the same dtype'):
            torch.histogram(values, bins)

        values = make_tensor((), dtype=dtype, device=device)
        weight = make_tensor((), dtype=inconsistent_dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'input tensor and weight tensor should have the same dtype'):
            torch.histogram(values, 1, weight=weight)

        values = make_tensor((), dtype=dtype, device=device)
        hist = make_tensor((), dtype=inconsistent_dtype, device=device)
        bin_edges = make_tensor((), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'input tensor and hist tensor should have the same dtype'):
            torch.histogram(values, 1, out=(hist, bin_edges))

        values = make_tensor((), dtype=dtype, device=device)
        hist = make_tensor((), dtype=dtype, device=device)
        bin_edges = make_tensor((), dtype=inconsistent_dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'input tensor and bin_edges tensor should have the same dtype'):
            torch.histogram(values, 1, out=(hist, bin_edges))

        t = make_tensor((2, 2), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have one dimension'):
            torch.histogram(t, t)

        t = make_tensor((0,), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'bins tensor should have at least 1 element'):
            torch.histogram(t, t)

        values = make_tensor((), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
            torch.histogram(values, -1)

        values = make_tensor((2, 2), dtype=dtype, device=device)
        weight = make_tensor((1,), dtype=dtype, device=device)
        weight_shape_msg = (
            'if weight tensor is provided it should have the same shape '
            'as the input tensor excluding its innermost dimension'
        )
        with self.assertRaisesRegex(RuntimeError, weight_shape_msg):
            torch.histogram(values, 1, weight=weight)

        # Passing both explicit bin edges and a range is rejected at argument parsing.
        values = make_tensor((), dtype=dtype, device=device)
        bin_edges = make_tensor((), dtype=dtype, device=device)
        with self.assertRaisesRegex(TypeError, 'received an invalid combination of arguments'):
            torch.histogram(values, bin_edges, range=(0, 1))

        values = make_tensor((), dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, 'min should not exceed max'):
            torch.histogram(values, 2, range=(1, 0))

        values = torch.tensor([float("nan")], device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, r'range \[nan, nan\] is not finite'):
            torch.histogram(values, 2)
3310

3311
    # Tests to ensure that reduction functions employing comparison operators are usable when there
3312
    # exists a zero dimension (i.e. when the tensors are empty) in the tensor. These tests specifically
3313
    # cater to functions where specifying the `dim` parameter is necessary.
3314
    def test_tensor_compare_ops_empty(self, device):
        """Comparison-based reductions over a (2, 0, 4) tensor with an explicit ``dim``.

        Reducing the non-empty last dimension must give an empty result whose
        shape matches both the expected torch shape and the numpy counterpart;
        reducing the zero-sized dimension must raise IndexError.
        """
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)

        def values_of(torch_fn):
            # torch.max/min/median with ``dim`` return (values, indices); keep the values only.
            return lambda *args, **kwargs: torch_fn(*args, **kwargs).values

        test_functions = [
            ('amax', torch.amax, np.amax),
            ('amin', torch.amin, np.amin),
            ('max', values_of(torch.max), np.max),
            ('min', values_of(torch.min), np.min),
            ('median', values_of(torch.median), np.median),
        ]

        for name, fn, np_function in test_functions:
            error_msg = f"test function: {name}"

            # Reduce the last dimension (addressed as 2 and as -1), first without and
            # then with keepdim, comparing against the expected shape and against numpy.
            for keepdim in (False, True):
                torch_kw = {'keepdim': True} if keepdim else {}
                np_kw = {'keepdims': True} if keepdim else {}
                reduced_shape = (2, 0, 1) if keepdim else (2, 0)
                for dim in (2, -1):
                    self.assertEqual(torch.empty(reduced_shape, device=device),
                                     fn(master_input, dim=dim, **torch_kw), msg=error_msg)
                    self.assertEqual(np_function(np_input, axis=dim, **np_kw),
                                     fn(master_input, dim=dim, **torch_kw).cpu().numpy(),
                                     msg=error_msg, exact_dtype=False)

            # Reducing over the zero-sized dimension is an error.
            with self.assertRaisesRegex(IndexError, "Expected reduction dim"):
                fn(master_input, dim=1)
3350

3351
    # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using comparison operators
3352
    # raises an error if no `dim` parameter is specified. This exists separately from tests in
3353
    # test_tensor_compare_ops_empty because not specifying a `dim` parameter in the former tests does
3354
    # not throw errors. Also, checking the return type of argmax requires supplying a different dtype
3355
    # argument than that for the input tensor. There is also variation in numpy testing.
3356
    def test_tensor_compare_ops_argmax_argmix_kthvalue_dim_empty(self, device):
        """argmax / argmin / kthvalue over a (2, 0, 4) tensor.

        Reducing the non-empty last dimension must give an empty result of the
        expected shape and dtype (matching numpy where a counterpart exists).
        Reducing the zero-sized dimension must raise IndexError, as must
        calling argmax/argmin without ``dim``.
        """
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)

        def torch_kthvalue(*args, k=1, **kwargs):
            # Forward the requested k (the previous wrapper always used k=1).
            return torch.kthvalue(*args, k=k, **kwargs).values

        def np_kthvalue(*args, k=1, axis=None, **kwargs):
            # The k-th smallest value (1-indexed) sits at sorted position k - 1, so
            # partition around that position along the same axis it is read from.
            return np.partition(*args, k - 1, axis=axis, **kwargs).take(k - 1, axis=axis)

        # (name, torch function, extra kwargs for the expected tensor, numpy reference)
        test_functions = [
            ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),
            ('argmin', torch.argmin, {'dtype': torch.int64}, np.argmin),
            ('kthvalue', torch_kthvalue, {}, np_kthvalue),
        ]

        for name, fn, expected_kwargs, np_function in test_functions:
            error_msg = f"test function: {name}"

            for dim in (2, -1):
                self.assertEqual(torch.empty((2, 0), device=device, **expected_kwargs),
                                 fn(master_input, dim=dim), msg=error_msg)
                self.assertEqual(
                    np_function(np_input, axis=dim), fn(master_input, dim=dim).cpu().numpy(),
                    msg=error_msg, exact_dtype=False
                )

            # keepdim variant does not exist for numpy
            for dim in (2, -1):
                self.assertEqual(torch.empty((2, 0, 1), device=device, **expected_kwargs),
                                 fn(master_input, dim=dim, keepdim=True), msg=error_msg)

            # Check if function raises error on specified zero'd dimension as reduction dim.
            with self.assertRaisesRegex(IndexError, "Expected reduction dim"):
                fn(master_input, dim=1)
            if name != 'kthvalue':
                # kthvalue is skipped here; argmax/argmin must fail without a dim.
                with self.assertRaisesRegex(IndexError, "Expected reduction dim"):
                    fn(master_input)
3389

3390
    # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using math operators works when a
3391
    # non-zero dim is specified for the reduction and throws an error when the dim specified is 0. Although
3392
    # there is some repetition with test_tensor_compare_ops_optional_dim_empty and test_tensor_compare_ops_empty,
3393
    # these tests are kept separate since tests for math operators also require checking for correctness of the
3394
    # returned data using allclose() or isinf(), which do not exist in the former tests.
3395
    @skipIfNoSciPy
    def test_tensor_reduce_ops_empty(self, device):
        """Math reductions over a (2, 0, 4) tensor.

        Reducing the non-empty last dimension must give an empty result;
        reducing the zero-sized dimension must give a tensor filled with the
        per-op value listed in ``test_functions``. Results are also compared
        against the numpy/scipy counterparts where those support the case.
        """
        from scipy.special import logsumexp
        shape = (2, 0, 4)
        master_input = torch.randn(shape, device=device)
        np_input = np.empty(shape)
        # (name, torch function, value expected when reducing the empty dim, numpy/scipy reference)
        test_functions = [
            ('prod', torch.prod, 1., np.prod),
            ('sum', torch.sum, 0., np.sum),
            ('norm', torch.norm, 0., np.linalg.norm),
            ('mean', torch.mean, nan, np.mean),
            ('var', torch.var, nan, np.var),
            ('std', torch.std, nan, np.std),
            ('logsumexp', torch.logsumexp, -inf, logsumexp),
        ]
        # (torch kwargs, numpy kwargs) without and with the reduced dimension kept.
        keepdim_variants = (({}, {}), ({'keepdim': True}, {'keepdims': True}))

        for name, fn, return_value, np_function in test_functions:
            error_msg = f"test function: {name}"

            # Reduction along the non-empty last dimension (addressed as 2 and as -1).
            for torch_kw, np_kw in keepdim_variants:
                empty_shape = (2, 0, 1) if torch_kw else (2, 0)
                for dim in (2, -1):
                    self.assertEqual(torch.empty(empty_shape, device=device),
                                     fn(master_input, dim=dim, **torch_kw), msg=error_msg)
                    expected = np_function(np_input, axis=dim, **np_kw)
                    actual = fn(master_input, dim=dim, **torch_kw)
                    if not torch_kw:
                        # Only the non-keepdim comparison converts the torch result to numpy.
                        actual = actual.cpu().numpy()
                    self.assertEqual(expected, actual, msg=error_msg, exact_dtype=False)

            # Reduction along the zero-sized dimension (addressed as 1 and as -2).
            for torch_kw, _ in keepdim_variants:
                filled_shape = (2, 1, 4) if torch_kw else (2, 4)
                for dim in (1, -2):
                    self.assertEqual(torch.full(filled_shape, return_value, device=device),
                                     fn(master_input, dim=dim, **torch_kw), msg=error_msg)

            if name == 'logsumexp':
                # logsumexp requires dim, so calling it without one is a TypeError.
                with self.assertRaises(TypeError):
                    fn(master_input)
            else:
                # The scipy logsumexp reference is skipped for the zero-sized dimension,
                # so the numpy comparison runs for the other ops only.
                for torch_kw, np_kw in keepdim_variants:
                    for dim in (1, -2):
                        self.assertEqual(np.float32(np_function(np_input, axis=dim, **np_kw)),
                                         fn(master_input, dim=dim, **torch_kw).cpu().numpy(),
                                         msg=error_msg)

                # Full reduction without specifying dim.
                self.assertEqual(torch.full((), return_value, device=device), fn(master_input), msg=error_msg)
3456

3457
    # Tests to ensure that any() and all() functions work with zero-dim tensors. Kept separate from
3458
    # other tests for checking reduction with zero-dim tensors because these tests have significantly
3459
    # different testing behaviour than that used for the former tests.
3460
    def test_reduction_empty_any_all(self, device):
        """any() / all() over a (2, 0, 4) tensor for every dtype in the test grid.

        Reducing the non-empty last dimension must give an empty result of the
        right shape. Reducing the zero-sized dimension, or everything, must
        give all-zeros for ``any`` and all-ones for ``all``.
        """
        shape = (2, 0, 4)
        x = torch.randn(shape, device=device)

        for dtype in all_types_and_complex_and(torch.half, torch.bool):
            # Refer: [all, any uint8 compatibility]
            if dtype == torch.uint8:
                out_dtype = torch.uint8
            else:
                out_dtype = torch.bool  # output of all/any is bool irrespective of input dtype

            xb = x.to(dtype)

            # ``any`` over nothing is false (zeros); ``all`` over nothing is true (ones).
            for reduce, expected_of in ((xb.any, torch.zeros), (xb.all, torch.ones)):
                self.assertEqual((2, 0), reduce(2).shape)
                self.assertEqual((2, 0, 1), reduce(2, keepdim=True).shape)
                self.assertEqual(expected_of((2, 4), device=device, dtype=out_dtype), reduce(1))
                self.assertEqual(expected_of((2, 1, 4), device=device, dtype=out_dtype), reduce(1, keepdim=True))
                self.assertEqual(expected_of((), device=device, dtype=out_dtype), reduce())
3486

3487
    # TODO: can these be merged with their respective OpInfos?
3488
    def test_reduce_dtype(self, device):
        """Gradients through reductions computed in double must match float ones.

        For each op, the gradient w.r.t. a float32 input is computed once with
        the op run as-is and once with the computation done in double. Both
        gradients must be equal and the second must still be float32.
        """
        def check_grad_dtype(op, has_no_dim, takes_dtype=True):
            x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device)

            if has_no_dim:
                # Full reduction to a scalar, so no explicit grad_outputs is needed.
                (grad_plain,) = torch.autograd.grad([op(x)], [x])
                (grad_double,) = torch.autograd.grad([op(x, dtype=torch.double)], [x])
                self.assertEqual(grad_plain, grad_double)
                self.assertEqual(grad_double.dtype, torch.float)

            grad_out = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device)
            (grad_plain,) = torch.autograd.grad([op(x, dim=0)], [x], grad_out)
            if takes_dtype:
                double_result = op(x, dim=0, dtype=torch.double)
            else:
                # Ops called without a dtype argument get a double input instead.
                double_result = op(x.double(), dim=0)
            (grad_double,) = torch.autograd.grad([double_result], [x], grad_out.double())
            self.assertEqual(grad_plain, grad_double)
            self.assertEqual(grad_double.dtype, torch.float)

        # (op, has_no_dim, takes_dtype)
        cases = (
            (torch.sum, True, True),
            (torch.prod, True, True),
            (torch.cumsum, False, True),
            (torch.cumprod, False, True),
            (torch.logcumsumexp, False, False),
        )
        for op, has_no_dim, takes_dtype in cases:
            check_grad_dtype(op, has_no_dim, takes_dtype=takes_dtype)
3512

3513
    @ops(reference_masked_ops)
    def test_reference_masked(self, device, dtype, op):
        """Test masked reduction operations on strided-only tensors using
        numpy reductions as reference.
        """

        def to_numpy(tensor):
            # bfloat16 tensors are upcast to float32 before the numpy conversion.
            host = tensor.cpu()
            if tensor.dtype is torch.bfloat16:
                host = host.to(torch.float32)
            return host.numpy()

        for sample_input in op.sample_inputs_func(op, device, dtype, requires_grad=False):
            t = sample_input.input
            actual = op(t, *sample_input.args, **sample_input.kwargs)

            # Dtypes are compared exactly unless the input is bfloat16 or the op
            # promotes a non-floating input to float.
            exact_dtype = t.dtype is not torch.bfloat16 and (
                not op.promotes_int_to_float or torch.is_floating_point(t))

            # `identity` is mapped to numpy reduction `initial` argument
            ref_kwargs = dict(
                identity=torch.masked._reduction_identity(op.name, t),
                **sample_input.kwargs)
            expected = op.ref(to_numpy(t), *sample_input.args, **ref_kwargs)

            # Workaround https://github.com/pytorch/pytorch/issues/66556
            expected = np.asarray(expected)  # transform numpy scalars to numpy.ndarray instances

            # Numpy differs, producing uint32 on Windows
            if expected.dtype in [np.uint64, np.uint32]:
                exact_dtype = False

            # Only small inputs are printed in full in the failure message.
            if t.numel() < 10:
                msg = ("Failed to produce expected results! Input tensor was"
                       f" {t}, torch result is {actual}, and reference result is"
                       f" {expected}.")
            else:
                msg = None

            self.assertEqual(actual, expected, msg, exact_dtype=exact_dtype)
3549

3550
    @onlyCUDA
    @largeTensorTest("8GB")
    @dtypes(torch.half, torch.chalf, torch.bfloat16)
    def test_reductions_large_half_tensors(self, device, dtype):
        """sum/mean of 2**31 half-precision elements, half +1 and half -1, must be 0."""
        numel = 2**31
        t = torch.ones(numel, device=device, dtype=dtype)
        # Second half is -1 so the exact total is zero.
        t[numel // 2:] = -1
        zero = torch.tensor(0, device=device, dtype=dtype)
        self.assertEqual(torch.sum(t), zero)

        # mean_cuda is not implemented for ComplexHalf
        if dtype is torch.chalf:
            guard = self.assertRaisesRegex(RuntimeError, "not implemented for 'ComplexHalf'")
        else:
            guard = contextlib.nullcontext()
        with guard:
            self.assertEqual(torch.mean(t), zero)
3565

3566
# Hand TestReductions and this module's globals to the device-type test machinery.
instantiate_device_type_tests(TestReductions, globals())

# Run the tests only when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
3570

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.