""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.

Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
"""
from __future__ import annotations

import functools
from typing import Optional, TYPE_CHECKING

import torch

from . import _dtypes_impl, _util

if TYPE_CHECKING:
    from ._normalizations import (
        ArrayLike,
        AxisLike,
        DTypeLike,
        KeepDims,
        NotImplementedType,
        OutArray,
    )


def _deco_axis_expand(func):
    """
    Generically handle axis arguments in reductions.
    axis is *always* the 2nd arg in the function so no need to have a look at its signature
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_tuple(axis, a.ndim)

        if axis == ():
            # NumPy treats a reduction over an empty axis tuple as an identity
            # operation, so insert a length-one axis and reduce along it.
            newshape = _util.expand_shape(a.shape, axis=0)
            a = a.reshape(newshape)
            axis = (0,)

        return func(a, axis, *args, **kwds)

    return wrapped


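# Illustration of the axis=() branch above (mirrors NumPy semantics):
#   >>> np.sum(np.ones(2), axis=())
#   array([1., 1.])
# i.e. nothing is reduced; reducing along a freshly inserted length-one axis
# reproduces this identity behaviour.

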
def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    For inputs that are boolean or integer dtypes, this returns the default
    float dtype; real floating-point and complex dtypes (`float*`, `complex*`)
    get passed through unchanged (e.g. float32 is not upcast to float64).
    """
    if dtype is None:
        dtype = other_dtype
    if not (dtype.is_floating_point or dtype.is_complex):
        return _dtypes_impl.default_dtypes().float_dtype
    return dtype


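# Illustrative mapping (assuming NumPy-compatible defaults, where the default
# float dtype is torch.float64):
#   _atleast_float(torch.int64, torch.int64)     -> torch.float64
#   _atleast_float(None, torch.bool)             -> torch.float64
#   _atleast_float(torch.float32, torch.float32) -> torch.float32  (unchanged)

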
@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    return a.count_nonzero(axis)


@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # torch.argmax is not implemented for Bool tensors, so view as uint8
        a = a.to(torch.uint8)

    return torch.argmax(a, axis)


@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # torch.argmin is not implemented for Bool tensors, so view as uint8
        a = a.to(torch.uint8)

    return torch.argmin(a, axis)


@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.any(a, **axis_kw)


@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.all(a, **axis_kw)


@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amax with dtype={a.dtype}")

    return a.amax(axis)


max = amax


@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amin with dtype={a.dtype}")

    return a.amin(axis)


min = amin


@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    return a.amax(axis) - a.amin(axis)


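# Illustration: ptp ("peak to peak") is the range of the data,
#   np.ptp(np.array([2, 7, 1])) -> 6    # == amax - amin == 7 - 1

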
@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    assert dtype is None or isinstance(dtype, torch.dtype)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.sum(dtype=dtype, **axis_kw)


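# Illustration: a bool accumulator dtype is swapped for the default int dtype
# because summing booleans counts the True values, as in NumPy:
#   >>> np.sum(np.array([True, True, True]))
#   3

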
@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.prod(dtype=dtype, **axis_kw)


@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    dtype = _atleast_float(dtype, a.dtype)

    axis_kw = {} if axis is None else {"dim": axis}
    result = a.mean(dtype=dtype, **axis_kw)

    return result


@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.std(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.var(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


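# Note on ddof in std()/var() above: the divisor is N - ddof, so ddof=0 gives
# the population statistic and ddof=1 the Bessel-corrected sample statistic;
# torch spells the same knob `correction`, hence the pass-through.

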
# cumsum / cumprod are almost reductions:
#   1. no keepdims
#   2. axis=None flattens


def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)


def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


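# Illustration: with axis=None the input is flattened first, matching NumPy:
#   >>> np.cumsum(np.array([[1, 2], [3, 4]]))
#   array([ 1,  3,  6, 10])

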
def average(
    a: ArrayLike,
    axis=None,
    weights: Optional[ArrayLike] = None,
    returned=False,
    *,
    keepdims=False,
):
    if weights is None:
        result = mean(a, axis=axis)
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        if not a.dtype.is_floating_point:
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # set the weights up to broadcast along axis
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)  # sum() above, not the builtin
        wsum = sum(weights, axis, dtype=result_dtype)
        result = numerator / wsum

    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    else:
        return result


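# Illustrative worked example: weights rescale each element's contribution,
#   np.average(np.array([1.0, 2.0, 3.0]), weights=np.array([3.0, 1.0, 1.0]))
#   -> (1*3 + 2*1 + 3*1) / (3 + 1 + 1) = 8 / 5 = 1.6

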
# Not using _deco_axis_expand: it assumes axis is the function's 2nd argument,
# while here the 2nd argument is q.
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    # `overwrite_input` is accepted for NumPy API compatibility; NumPy allows,
    # but does not require, in-place work, and we always work out of place.

    if not a.dtype.is_floating_point:
        dtype = _dtypes_impl.default_dtypes().float_dtype
        a = a.to(dtype)

    # edge case: torch.quantile only operates on float32 and float64
    if a.dtype == torch.float16:
        a = a.to(torch.float32)

    if axis is not None:
        axis = _util.normalize_axis_tuple(axis, a.ndim)
    axis = _util.allow_only_single_axis(axis)

    q = _util.cast_if_needed(q, a.dtype)

    return torch.quantile(a, q, axis=axis, interpolation=method)


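# Illustration: `method` is forwarded to torch.quantile's `interpolation`
# keyword, so np.quantile(x, 0.5, method="linear") corresponds to
# torch.quantile(x, torch.as_tensor(0.5), interpolation="linear").

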
def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    # np.percentile(float_tensor, 30): q.dtype is int64 => q / 100.0 is float32,
    # so cast an integer q to the default float dtype before rescaling
    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)

    qq = q / 100.0

    return quantile(
        a,
        qq,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )


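# Illustration: percentile is quantile with q rescaled by 100, e.g.
#   np.percentile(x, 30) == np.quantile(x, 0.3)

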
def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    return quantile(
        a,
        torch.as_tensor(0.5),
        axis=axis,
        overwrite_input=overwrite_input,
        out=out,
        keepdims=keepdims,
    )