# mypy: ignore-errors

""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.

Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
"""
from __future__ import annotations

import functools
from typing import Optional, TYPE_CHECKING

import torch

from . import _dtypes_impl, _util


if TYPE_CHECKING:
    from ._normalizations import (
        ArrayLike,
        AxisLike,
        DTypeLike,
        KeepDims,
        NotImplementedType,
        OutArray,
    )


def _deco_axis_expand(func):
    """
    Generically handle axis arguments in reductions.
    axis is *always* the 2nd arg of the wrapped function, so there is no need to inspect its signature.
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_tuple(axis, a.ndim)

        if axis == ():
            # A reduction over an empty tuple of axes is a no-op in NumPy: each
            # element reduces to itself. We cannot return a.clone() as this would
            # sidestep the checks inside the function, so we insert a length-one
            # axis and run the reduction along it.
            newshape = _util.expand_shape(a.shape, axis=0)
            a = a.reshape(newshape)
            axis = (0,)

        return func(a, axis, *args, **kwds)

    return wrapped

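# For illustration: how the decorator rewrites an `axis=()` call. For a tensor
# `t` of shape (2, 3), reducing over an empty tuple of axes is a no-op, so the
# wrapper reshapes to (1, 2, 3) and reduces along the inserted axis 0, e.g.
#
#   sum(t, axis=())   # effectively t.reshape(1, 2, 3).sum(dim=(0,))
#                     # result has shape (2, 3) and equals t element-wise
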

def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    For inputs that are boolean or integer dtypes, this returns the default
    float dtype; floating-point and complex dtypes are passed through unchanged.
    If `dtype` is None, `other_dtype` is used instead.
    """
    if dtype is None:
        dtype = other_dtype
    if not (dtype.is_floating_point or dtype.is_complex):
        return _dtypes_impl.default_dtypes().float_dtype
    return dtype

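# For illustration (assuming the default float dtype is torch.float64):
#
#   _atleast_float(None, torch.int64)              # -> torch.float64 (integers promote)
#   _atleast_float(None, torch.float32)            # -> torch.float32 (passed through)
#   _atleast_float(torch.complex64, torch.int64)   # -> torch.complex64 (passed through)
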

@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    return a.count_nonzero(axis)


@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmax_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmax(a, axis)


@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmin_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmin(a, axis)

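# For illustration: torch.argmax/argmin reject boolean tensors (hence the
# bool -> uint8 cast above), so after the cast e.g.
#
#   argmax(torch.tensor([False, True, False]))   # -> tensor(1), as np.argmax would
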

@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.any(a, **axis_kw)


@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.all(a, **axis_kw)

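# For illustration: the axis_kw dance keeps the two torch call forms apart,
# since torch.any(t) reduces over all elements while torch.any(t, dim=d)
# reduces a single dimension. E.g. for t = torch.zeros(2, 3):
#
#   any(t)            # -> tensor(False), a full reduction
#   any(t, axis=1)    # -> tensor([False, False]), shape (2,)
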

@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amax with dtype={a.dtype}")

    return a.amax(axis)


max = amax


@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    if a.is_complex():
        raise NotImplementedError(f"amin with dtype={a.dtype}")

    return a.amin(axis)


min = amin


@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    return a.amax(axis) - a.amin(axis)

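# For illustration: ptp ("peak to peak") is simply max minus min, e.g.
#
#   ptp(torch.tensor([1.0, 5.0, 3.0]), axis=0)   # -> tensor(4.)
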

@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    assert dtype is None or isinstance(dtype, torch.dtype)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.sum(dtype=dtype, **axis_kw)


@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    axis = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    axis_kw = {} if axis is None else {"dim": axis}
    return a.prod(dtype=dtype, **axis_kw)


product = prod

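# For illustration: summing or multiplying booleans yields integers, as in NumPy.
# With the default dtype=None torch already promotes bool to int64, e.g.
#
#   sum(torch.tensor([True, True, False]))   # -> tensor(2)
#
# The dtype == torch.bool branches above additionally upgrade an explicitly
# requested bool accumulator to the default int dtype.
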

@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    dtype = _atleast_float(dtype, a.dtype)

    axis_kw = {} if axis is None else {"dim": axis}
    result = a.mean(dtype=dtype, **axis_kw)

    return result

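# For illustration: _atleast_float makes the mean of integer input come out
# floating-point, as in NumPy (assuming the default float dtype is torch.float64):
#
#   mean(torch.tensor([1, 2, 3, 4]))   # -> tensor(2.5000, dtype=torch.float64)
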

@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.std(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.var(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)

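# For illustration: NumPy's ddof maps directly onto torch's `correction`, so
# ddof=0 gives the population estimate and ddof=1 the sample estimate, e.g. for
# t = torch.tensor([1.0, 2.0, 3.0, 4.0]):
#
#   std(t)           # divisor N     -> tensor(1.1180)
#   std(t, ddof=1)   # divisor N - 1 -> tensor(1.2910)
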

# cumsum / cumprod are almost reductions:
#   1. no keepdims
#   2. axis=None flattens


def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)


def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


cumproduct = cumprod

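# For illustration: with axis=None the input is flattened first, matching NumPy:
#
#   cumsum(torch.tensor([[1, 2], [3, 4]]))           # -> tensor([1, 3, 6, 10])
#   cumsum(torch.tensor([[1, 2], [3, 4]]), axis=1)   # -> tensor([[1, 3],
#                                                    #            [3, 7]])
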

def average(
    a: ArrayLike,
    axis=None,
    weights: ArrayLike = None,
    returned=False,
    *,
    keepdims=False,
):
    if weights is None:
        result = mean(a, axis=axis)
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        if not a.dtype.is_floating_point:
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # set up weights to broadcast along axis
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        # do the work
        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)
        wsum = sum(weights, axis, dtype=result_dtype)
        result = numerator / wsum

    # We process keepdims manually because the decorator does not deal with variadic returns.
    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    else:
        return result

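# For illustration of the 1-D weights branch above: a (3,) weights vector is
# broadcast to shape (1, 3) and then swapped onto the reduction axis, e.g.
#
#   a = torch.tensor([[1.0, 2.0, 3.0],
#                     [4.0, 5.0, 6.0]])
#   w = torch.tensor([1.0, 1.0, 2.0])
#   average(a, axis=1, weights=w)   # -> tensor([2.2500, 5.2500])
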

# Not using _deco_axis_expand as it assumes that axis is the second arg
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    if overwrite_input:
        # raise NotImplementedError("overwrite_input in quantile not implemented.")
        # NumPy documents that `overwrite_input` MAY modify inputs:
        # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
        # Here we choose to work out-of-place.
        pass

    if not a.dtype.is_floating_point:
        dtype = _dtypes_impl.default_dtypes().float_dtype
        a = a.to(dtype)

    # edge case: torch.quantile only supports float32 and float64
    if a.dtype == torch.float16:
        a = a.to(torch.float32)

    if axis is None:
        a = a.flatten()
        q = q.flatten()
        axis = (0,)
    else:
        axis = _util.normalize_axis_tuple(axis, a.ndim)

    # FIXME(Mario) Doesn't np.quantile accept a tuple of axes?
    # torch.quantile only accepts a single integer dim. If we don't want to implement
    # the tuple behaviour (it's definitely low priority), change `normalize_axis_tuple`
    # into `normalize_axis_index` above.
    axis = _util.allow_only_single_axis(axis)

    q = _util.cast_if_needed(q, a.dtype)

    return torch.quantile(a, q, axis=axis, interpolation=method)


def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    # np.percentile(float_tensor, 30): q.dtype is int64 => q / 100.0 would be float32,
    # so promote an integer q to the default float dtype first.
    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)
    qq = q / 100.0

    return quantile(
        a,
        qq,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )

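# For illustration: percentile is quantile with q rescaled from [0, 100] to [0, 1]:
#
#   t = torch.tensor([1.0, 2.0, 3.0, 4.0])
#   percentile(t, torch.tensor(50))           # -> 2.5 (the median)
#   percentile(t, torch.tensor([25., 75.]))   # -> [1.75, 3.25]
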

def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    return quantile(
        a,
        torch.as_tensor(0.5),
        axis=axis,
        overwrite_input=overwrite_input,
        out=out,
        keepdims=keepdims,
    )