pytorch

Форк
0
/
_ops.py 
1795 строк · 63.8 Кб
1

2
import warnings
3

4
# A workaround to support both TorchScript and MyPy:
5
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
6

7
import torch
8
from torch import Tensor
9
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
10
from . import _docs
11
from torch._prims_common import corresponding_real_dtype
12
from torch import sym_float
13

14
# Type aliases used in the signatures below. Static type checkers see the
# precise definitions; at runtime simplified stand-ins are bound instead.
if TYPE_CHECKING:
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]


# Names of the public functions of this module. The list starts empty and is
# extended by the _apply_docstring_templates decorator.
__all__: List[str] = []
25

26
# All masked reduction/normalization operations have the same
27
# signatures. Here we introduce docstring templates that are applied
28
# to docstrings of reduction/normalization functions via
29
# _apply_docstring_templates decorator.
30

31

32
def _apply_docstring_templates(func):
    """Decorator that installs the pre-generated docstring on ``func``.

    The docstring is looked up in the ``_docs`` module under the attribute
    name ``<function name>_docstring``. When no such attribute exists a
    warning is issued and the function docstring is left untouched.

    In both cases the function name is appended to ``__all__`` and the
    very same function object is returned.
    """
    name = func.__name__
    generated_doc = getattr(_docs, f"{name}_docstring", None)

    if generated_doc is not None:
        func.__doc__ = generated_doc
    else:
        warnings.warn(
            f"No documentation string available for {name}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings."
        )

    # Expose function as public symbol
    __all__.append(name)

    return func
51

52

53
def _generate_docstring(func):
    """A utility function called from tools/update_masked_docs.py
    script to update the module torch.masked._docs.py

    Builds the documentation string of the masked operation ``func`` from
    the templates defined below. The operation is identified by
    ``func.__name__`` and must be listed in ``args_and_kwargs`` and in one
    of ``reduction_names``/``normalization_names``. ``func`` is called once
    on a small example input to produce the output shown in the
    ``Example::`` section.

    When ``func.__doc__`` is None, the standard sections of the operation
    kind are joined with blank lines; otherwise ``func.__doc__`` itself is
    used as the template and formatted with the computed template data.

    Returns:
        The formatted documentation string.

    Raises:
        KeyError: if ``func.__name__`` is not listed in ``args_and_kwargs``.
        AssertionError: if ``func.__name__`` is neither a known reduction
            nor a known normalization.
    """
    docstring_templates = dict(
        reduction_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        reduction_descr="""\
Returns {operation name} of all the elements in the :attr:`input`
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
elements are masked out according to the boolean tensor
:attr:`mask`.""",
        reduction_args="""\
If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
size 1. Otherwise, :attr:`dim` is squeezed (see
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
``len(dim)``) fewer dimension(s).

The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True
then the corresponding element in :attr:`input` tensor will be
included in {operation name} computation, otherwise the element is
ignored.

When all elements of :attr:`input` along the given dimension
:attr:`dim` are ignored (fully masked-out), the corresponding element
of the output tensor will have undefined value: it may or may not
correspond to the identity value of {operation name} operation; the
choice may correspond to the value that leads to the most efficient
storage of :attr:`output` tensor.

The mask of the output tensor can be computed as
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
dtype=torch.bool)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        reduction_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
        reduction_identity="""\
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
        reduction_identity_dtype="""\
The identity value of {operation name} operation, which is used to start the
reduction, depends on input dtype. For instance, for float32, uint8,
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
        normalization_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        normalization_descr="""\
Returns {operation name} of all the slices in the :attr:`input` tensor
along :attr:`dim` while the :attr:`input` elements are masked out
according to the boolean tensor :attr:`mask`.

{definition}""",
        normalization_args="""\
The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True then
the corresponding element in :attr:`input` tensor will be included in
{operation name} computation, otherwise the element is ignored.

The values of masked-out elements of the output tensor have undefined
value: it may or may not be set to zero or nan; the choice may correspond to
the value that leads to the most efficient storage of :attr:`output`
tensor.

The mask of the {operation name} output tensor can be computed as
``torch.broadcast_to(mask, input.shape)``.

The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than of the :attr:`input` tensor.

Args:
    input (Tensor): the input tensor
    {args_declarations}

Keyword args:
    {kwargs_declarations}""",
        normalization_example="""\
Example::

    >>> input = {example_input}
    >>> input
    {indent_example_input}
    >>> mask = {example_mask}
    >>> mask
    {indent_example_mask}
    >>> {full_function_name}(input, {example_args}, mask=mask)
    {indent_example_output}
""",
    )

    args_and_kwargs = dict(
        # argument name suffixes separated by double underscore will
        # be removed in the final documentation string.
        sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
        cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
        amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        norm=(
            (
                "ord",
                "dim",
            ),
            ("keepdim=False", "dtype=None", "mask=None"),
        ),
        var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
        softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
        normalize=(
            (
                "ord__required",
                "dim__as_int",
            ),
            ("eps=1e-12", "dtype=None", "mask=None"),
        ),
    )

    argument_declarations = dict(
        dim="""\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
  Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
        dim__as_int="""\
dim (int): the dimension along which {operation name} is computed.""",
        ord="""\
ord (int, float, optional): the order of vector norm. Default: 2.
  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        ord__required="""\
ord (int, float): the order of vector norm. Default: 2.
  See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        unbiased="""\
unbiased (bool): when True, use Bessel’s correction, otherwise, compute
  the uncorrected sample variance.""",
        eps="""\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
        keepdim="""\
keepdim (bool, optional): whether the output tensor has
  :attr:`dim` retained or not. Default: {default}.""",
        dtype="""\
dtype (:class:`torch.dtype`, optional): the desired data type
  of returned tensor.  If specified, the input tensor is
  casted to :attr:`dtype` before the operation is
  performed. Default: {default}.""",
        mask="""\
mask (:class:`torch.Tensor`, optional): the boolean tensor
  containing the binary mask of validity of input tensor
  elements.
  Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
    )

    definitions = dict(
        softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
        log_softmax="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
        softmin="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
        normalize="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
        cumsum="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i])``.""",
        # Fixed: this text previously said "Cumsum" (copied from the entry above).
        cumprod="""\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumprod of i-th element in ``x`` is
defined as ``prod(x[:i])``.""",
    )

    reduction_names = dict(
        sum="sum",
        prod="product",
        amax="maximum",
        amin="minimum",
        argmax="argmax",
        argmin="argmin",
        mean="mean",
        median="median",
        norm="norm",
        var="variance",
        std="standard_deviation",
        logsumexp="logsumexp",
    )

    normalization_names = dict(
        softmax="softmax",
        log_softmax="log_softmax",
        softmin="softmin",
        normalize="normalize",
        cumsum="cumulative_sum",
        cumprod="cumulative_prod",
    )

    operation_names = {}
    operation_names.update(reduction_names)
    operation_names.update(normalization_names)

    # Default example data:
    example_dim = 1
    example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
    example_mask = torch.tensor([[True, False, True], [False, False, False]])
    example_args: Tuple[Any, ...]
    if func.__name__ in {"norm", "normalize"}:
        example_args = (2.0, example_dim)
        example_input = example_input.to(dtype=torch.float32)
    elif func.__name__ in {"var", "std"}:
        example_args = (example_dim, False)
    elif func.__name__ == "median":
        example_args = (example_dim,)
        example_input = example_input.to(dtype=torch.float32)
    else:
        example_args = (example_dim,)

    operation_args: Tuple[str, ...]
    operation_kwargs: Tuple[str, ...]
    operation_args, operation_kwargs = args_and_kwargs[func.__name__]
    arg_declarations = [
        "\n    ".join(
            argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
        )
        for a in operation_args
    ]
    # Keyword entries have the form "name=default"; the default value is
    # substituted into the declaration text. The fallback text for an
    # undocumented keyword uses only the name part.
    kwarg_declarations = [
        "\n    ".join(
            argument_declarations.get(
                a.split("=", 1)[0],
                f'{a.split("=", 1)[0].split("__", 1)[0]}: TBD.',
            )
            .format(default=a.split("=", 1)[1])
            .splitlines()
        )
        for a in operation_kwargs
    ]

    if func.__name__ in reduction_names:
        op_kind = "reduction"
        doc_sections = ["signature", "descr", "identity", "args", "example"]
    elif func.__name__ in normalization_names:
        op_kind = "normalization"
        doc_sections = ["signature", "descr", "args", "example"]
        example_input = example_input.to(dtype=torch.float32)
    else:
        # Raised explicitly (rather than ``assert 0``) so that the check is
        # not stripped when Python runs with -O.
        raise AssertionError(
            f"add {func.__name__} to operation names dictionaries"
        )
    example_output = func(example_input, *example_args, mask=example_mask)

    template_data = {
        "function_name": func.__name__,
        "full_function_name": func.__module__ + "." + func.__name__,
        "operation name": operation_names[func.__name__],
        "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
        "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
        # one-line representation of a tensor:
        "example_input": " ".join(str(example_input).split()),
        "example_args": ", ".join(map(str, example_args)),
        "example_mask": " ".join(str(example_mask).split()),
        # multi-line representation of a tensor with indent
        "indent_example_input": ("\n    ").join(str(example_input).splitlines()),
        "indent_example_mask": ("\n    ").join(str(example_mask).splitlines()),
        "indent_example_output": ("\n    ").join(str(example_output).splitlines()),
    }

    # op_kind is either "reduction" or "normalization" at this point.
    if op_kind == "reduction":
        template_data.update(
            identity_uint8=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.uint8)
            ),
            identity_int32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.int32)
            ),
            identity_float32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.float32)
            ),
        )
        if func.__name__ == "norm":
            template_data.update(
                identity_ord_ninf=_reduction_identity(
                    func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
                )
            )
    else:
        template_data.update(definition=definitions[func.__name__])
    template_data.update(
        args_declarations=("\n    ".join(arg_declarations)).format_map(template_data)
    )
    template_data.update(
        kwargs_declarations=("\n    ".join(kwarg_declarations)).format_map(
            template_data
        )
    )

    # Apply function name info to docstring templates:
    templates = {
        k: v.format_map(template_data)
        for k, v in docstring_templates.items()
        if k.startswith(op_kind)
    }
    templates.update(
        (k, v.format_map(template_data) if isinstance(v, str) else v)
        for k, v in template_data.items()
    )

    # Apply docstring templates to function docstring:
    if func.__doc__ is None:
        doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
    else:
        doc_template = func.__doc__
    return doc_template.format_map(templates)
398

399

400
def _reduction_identity(op_name: str, input: Tensor, *args):
    """Return the identity value of a reduction as a scalar tensor.

    The identity value ``e`` of an operation ``op`` satisfies
    ``op(e, value) == value`` for every value in the domain of the
    operation, i.e. adding it to or removing it from the list of operands
    does not change the reduction result.

    The result is created on the device of ``input``. ``None`` is returned
    when the identity cannot be uniquely defined (mean, var, std).
    ``op_name`` may carry a module prefix; only the part after the last dot
    is used. For ``norm`` the first extra positional argument is the order
    of the norm (2 when omitted).

    Raises NotImplementedError for unsupported operation/dtype
    combinations.

    See https://github.com/pytorch/rfcs/pull/27 for more information.
    """
    dtype: DType = input.dtype
    device = input.device
    is_float = torch.is_floating_point(input)
    name = op_name.rsplit(".", 1)[-1]  # drop the module name when present

    def scalar(value, value_dtype=dtype):
        return torch.tensor(value, dtype=value_dtype, device=device)

    if name in {"sum", "cumsum"}:
        return scalar(0)
    if name in {"prod", "cumprod"}:
        return scalar(1)
    if name in {"mean", "var", "std"}:
        # For mean, strictly speaking, the identity value is the mean of the
        # input. It depends on the dim argument and may be a non-scalar
        # tensor, so it is considered ambiguous; moreover, the mean of an
        # empty input is undefined.
        return None
    if name == "median":
        # NaN is used for now because the implementation is currently using
        # torch.nanmedian, which ignores NaN values.
        return scalar(torch.nan, dtype if is_float else torch.float)
    if name == "norm":
        order = args[0] if args else 2
        if order == float("-inf"):
            assert is_float, input.dtype
            return scalar(torch.inf)
        return scalar(0)

    wants_lowest = name in {"amax", "argmax", "logsumexp"}
    wants_highest = name in {"amin", "argmin"}
    if wants_lowest or wants_highest:
        if is_float:
            return scalar(-torch.inf if wants_lowest else torch.inf)
        if torch.is_signed(input) or dtype == torch.uint8:
            limits = torch.iinfo(dtype)
            return scalar(limits.min if wants_lowest else limits.max)

    raise NotImplementedError(f"identity of {name} on {dtype} input")
452

453

454
def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
    """Return dim argument as a tuple of sorted dim values.

    Args:
        dim: a single dimension, a sequence of dimensions, or None. Both
            ``None`` and the empty tuple select all dimensions.
        ndim: the number of dimensions of the tensor being reduced. A
            value of 0 is treated as 1 when validating explicit dims.

    Returns:
        The selected dimensions as non-negative integers in ascending
        order.

    Raises:
        IndexError: if a dimension is outside ``[-ndim, ndim - 1]``.
        RuntimeError: if a dimension is listed more than once, including
            when it is spelled once with a non-negative and once with a
            negative index (e.g. ``(1, -1)`` for ``ndim=2``).
    """
    if dim == ():
        # Currently, `dim=()` in reductions operations means "reduce
        # over all dimensions" while in future, it will read "no
        # reduce". See https://github.com/pytorch/pytorch/issues/29137
        # When gh-29137 is resolved, this if-block must be deleted.
        dim = None
    if dim is None:
        return tuple(range(ndim))
    ndim = max(ndim, 1)
    dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
    dims: List[int] = []
    for d in dim_:
        if d >= ndim or d < -ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})"
            )
        # Compare normalized values: dims holds non-negative indices only,
        # so checking the raw (possibly negative) d would miss aliases.
        normalized = d % ndim
        if normalized in dims:
            raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
        dims.append(normalized)
    return tuple(sorted(dims))
476

477

478
def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
    """Flatten N-D indices to 1-D indices.

    Row ``d`` of ``indices`` holds the coordinates along dimension ``d``;
    ``shape`` gives the extent of each dimension. Coordinates are combined
    with dimension 0 as the most significant one. The result has one entry
    per column of ``indices`` and is created with ``indices.new_zeros``.
    """
    flattened = indices.new_zeros(indices.size(1))
    for axis, extent in enumerate(shape):
        flattened = flattened * extent + indices[axis]
    return flattened
485

486

487
def _any(input: Tensor, dim: tuple, keepdim: bool):
    """Apply ``Tensor.any`` over every dimension listed in ``dim``.

    Support torch.any with tuple dim argument.
    Workaround of https://github.com/pytorch/pytorch/issues/56586

    The dimensions are reduced one at a time, starting from the last entry
    of ``dim``. With an empty ``dim`` the input is returned as is.
    """
    result = input
    for axis in dim[::-1]:
        result = result.any(dim=axis, keepdim=keepdim)
    return result
494

495

496
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.

    Args:
        mask: boolean sparse COO tensor with the same shape and the same
            number of dense dimensions as ``input``.
        input: sparse COO tensor; it is coalesced internally.
        fill_value: value used for masked-out elements in the dense part
            of the values.

    Returns:
        A coalesced sparse COO tensor of shape ``input.shape``.
    """

    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # mask.values() has shape (nnz, *dense_shape): a specified element
        # is masked-in when any entry of its dense part is True, so the
        # reduction must run over the dense dimensions 1..dense_dim (the
        # number of sparse dimensions is unrelated to the values' rank).
        mask_values = _any(
            mask.values(), tuple(range(1, mask.dense_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # Returns the unique values of i1 and i2 combined, together with the
        # torch.where result selecting the values that occur more than once.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # Values occurring exactly once in the concatenation of i1 and i2,
        # subsequently intersected with i1.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        # Materialize an (object, where-result) pair as the selected items.
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
606

607

608
def _sparse_coo_scatter_reduction_helper(
    op,
    mask_input: Tensor,
    dims: Tuple[int, ...],
    keepdim: bool,
    dtype: Optional[DType] = None,
) -> Tensor:
    """Reduce a sparse COO tensor over ``dims`` using ``scatter_reduce_``.

    ``op`` is the dense reduction function; its ``__name__`` selects the
    reduction and must be one of sum, prod, amax, amin (ValueError
    otherwise). Values are converted to ``dtype`` before reducing.

    Dense dimensions are reduced by calling ``op`` on the values, which is
    only implemented for sum (``NotImplemented`` is returned for the other
    operations). When every sparse dimension is reduced, the values are
    reduced along their first dimension and the result is converted with
    ``to_sparse()``. Otherwise the reduced sparse dimensions are zeroed
    (``keepdim=True``) or dropped (``keepdim=False``) from the indices and
    values sharing the same remaining index are combined.
    """
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )

    output_dtype = dtype
    values, indices = mask_input._values(), mask_input._indices()
    input_dims = mask_input.dim()
    num_sparse_dims = mask_input.sparse_dim()

    # promote dtype if specified
    if values.dtype != output_dtype:
        values = values.to(output_dtype)

    input_shape = mask_input.shape
    if keepdim:
        output_shape = tuple(
            1 if axis in dims else extent for axis, extent in enumerate(input_shape)
        )
    else:
        output_shape = tuple(
            extent for axis, extent in enumerate(input_shape) if axis not in dims
        )

    # Split the requested dims into sparse dims and positions within the
    # values tensor (dense dims); dims beyond the input rank are ignored.
    usable_dims = [d for d in dims if d < input_dims]
    reduced_sparse_dims = [d for d in usable_dims if d < num_sparse_dims]
    reduced_dense_dims = [
        d + 1 - num_sparse_dims for d in usable_dims if d >= num_sparse_dims
    ]

    # Reduce dense dimensions
    if not reduced_dense_dims:
        new_values = values.clone()
    elif reduce == "sum":
        new_values = op(values, dim=reduced_dense_dims, keepdim=bool(keepdim))
    else:
        # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
        return NotImplemented

    # Reduce sparse dimensions
    if len(reduced_sparse_dims) == num_sparse_dims:
        if reduce in {"amax", "amin"} and new_values.size(0) == 0:
            # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
            # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
            # See https://github.com/pytorch/pytorch/issues/61901
            new_values = _reduction_identity(reduce, new_values)
        else:
            new_values = op(new_values, dim=0)
        if keepdim:
            # prepend one singleton dimension per sparse dimension
            new_values = new_values.reshape(
                (1,) * num_sparse_dims + tuple(new_values.shape)
            )
        return new_values.to(dtype=output_dtype).to_sparse()

    new_indices = indices.clone()
    if keepdim:
        # zero out reduced sparse dimensions if keepdim = True
        # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
        new_indices[reduced_sparse_dims, :] = 0
    elif reduced_sparse_dims:
        # remove reduced sparse dimensions if keepdim = False
        retained_sparse_dims = [
            axis for axis in range(num_sparse_dims) if axis not in reduced_sparse_dims
        ]
        new_indices = new_indices.index_select(
            0, torch.tensor(retained_sparse_dims).to(mask_input.device)
        )

    # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
    if new_indices.numel() > 0:
        # lexsort indices and get index tensor for scatter reduction
        new_indices, inverse_indices = torch.unique(
            new_indices, return_inverse=True, dim=1
        )
        out_shape = list(new_values.shape)
        out_shape[0] = new_indices.shape[1]
        # give the inverse mapping trailing singleton dims so that it
        # expands to the shape of the values
        inverse_indices = inverse_indices.reshape(
            (inverse_indices.numel(),) + (1,) * (new_values.ndim - 1)
        )
        scatter_indices = inverse_indices.expand(new_values.shape)
        # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
        upcast = output_dtype in {torch.bfloat16, torch.float16}
        if upcast:
            new_values = new_values.to(torch.float)
        new_values = new_values.new_empty(out_shape).scatter_reduce_(
            0, scatter_indices, new_values, reduce=reduce, include_self=False
        )
        if upcast:
            new_values = new_values.to(dtype=output_dtype)

    return torch.sparse_coo_tensor(
        new_indices,
        new_values,
        output_shape,
        dtype=output_dtype,
        device=mask_input.device,
    )
726

727

728
def _sparse_csr_segment_reduction_helper(
729
    op,
730
    mask_input: Tensor,
731
    dims: Tuple[int, ...],
732
    keepdim: bool,
733
    dtype: Optional[DType] = None,
734
) -> Tensor:
735
    # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
736
    # FIXME: when dense dimensions are implemented for CSR tensors
737
    assert (
738
        keepdim
739
    ), "reduction operations on CSR tensors with keepdim=False is unsupported"
740
    reduce = op.__name__
741
    valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
742
    if reduce not in valid_reductions:
743
        raise ValueError(
744
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
745
        )
746
    device = mask_input.device
747
    output_dtype = dtype
748
    values, crow_indices, col_indices = (
749
        mask_input.values(),
750
        mask_input.crow_indices(),
751
        mask_input.col_indices(),
752
    )
753

754
    # promote dtype if specified
755
    if values.dtype != output_dtype:
756
        values = values.to(output_dtype)
757

758
    if len(dims) == 0:
759
        return mask_input
760
    if len(dims) == 1:
761
        if dims[0] == 0:
762
            new_col_indices, scatter_indices = torch.unique(
763
                col_indices, return_inverse=True
764
            )
765
            new_nnz = new_col_indices.shape[0]
766
            new_crow_indices = torch.tensor([0, new_nnz])
767
            new_values = values.new_empty(new_col_indices.shape)
768
            new_values.scatter_reduce_(
769
                0, scatter_indices, values, reduce, include_self=False
770
            )
771
            new_shape = [1, mask_input.size(1)]
772
        else:
773
            assert (
774
                dims[0] == 1
775
            ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
776
            # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
777
            # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
778
            new_crow_indices = torch.cat(
779
                (
780
                    crow_indices.new_zeros(1),
781
                    torch.cumsum(torch.diff(crow_indices) != 0, 0),
782
                ),
783
                0,
784
            )
785
            new_nnz = new_crow_indices[-1]
786
            new_col_indices = col_indices.new_zeros(new_nnz)
787
            new_values = torch._segment_reduce(values, reduce, offsets=crow_indices)  # type: ignore[attr-defined]
788
            new_shape = [mask_input.size(0), 1]
789
    else:
790
        assert len(dims) == 2
791
        nnz = min(1, values.numel())
792
        if nnz == 1:
793
            op_kwargs = {"keepdim": True, "dtype": output_dtype}
794
            # amax and amin do not support dtype kwarg
795
            if reduce in ["amax", "amin"]:
796
                del op_kwargs["dtype"]
797
            new_values = op(values, 0, **op_kwargs)
798
        else:
799
            new_values = torch.empty(0, dtype=output_dtype)
800
        new_col_indices = col_indices.new_zeros(nnz)
801
        new_crow_indices = torch.tensor([0, nnz])
802
        new_shape = [1, nnz]
803

804
    return torch.sparse_csr_tensor(
805
        new_crow_indices,
806
        new_col_indices,
807
        new_values,
808
        new_shape,
809
        dtype=output_dtype,
810
        device=device,
811
    )
812

813

814
def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors.

    Both ``mask`` and ``input`` are converted to sparse COO, combined with
    ``_sparse_coo_where``, and the result is converted back to sparse CSR.
    """
    # TODO: implement sparse CSR specific where operator for efficiency
    coo_mask = mask.to_sparse_coo()
    coo_input = input.to_sparse_coo()
    coo_result = _sparse_coo_where(coo_mask, coo_input, fill_value)
    return coo_result.to_sparse_csr()
820

821

822
def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
823
    """torch.where with sparse inputs support.
824

825
    _where implements the following invariant:
826

827
      _where(mask, input, fill_value).to_dense(fill_value) ==
828
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
829

830
    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
831
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
832
    that the unspecified elements are mapped to `fill_value` rather
833
    than to `0`.
834

835
    Returns a sparse tensor with the following features:
836

837
    - all specified elements correspond to masked-in elements that
838
      have the values of the input tensor. If there exists a masked-in
839
      element (as specified by mask) that is not specified in the
840
      input, in the result tensor, the corresponding element has value
841
      0. In the dense part of the sparse tensor, the masked-out
842
      elements are replaced with fill_value.
843

844
    - all unspecified elements correspond to masked-out elements.
845
    """
846
    if mask.layout == torch.strided:
847
        return torch.where(mask, input, fill_value)
848
    elif mask.layout == torch.sparse_coo:
849
        return _sparse_coo_where(mask, input, fill_value)
850
    elif mask.layout == torch.sparse_csr:
851
        return _sparse_csr_where(mask, input, fill_value)
852
    else:
853
        raise ValueError(
854
            f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
855
        )
856

857

858
def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
859
    """Return canonical input mask.
860

861
    A canonical input mask is defined as a boolean mask tensor that
862
    shape and layout matches with the shape and the layout of the
863
    input.
864

865
    The canonical input mask is computed from the :attr:`mask` tensor
866
    content to meet the following criteria:
867

868
    1. The shape of the canonical input mask is the same as the shape
869
       of :attr:`input` tensor. If the mask tensor has a smaller shape
870
       than the shape of the :attr:`input`, broadcasting rules will be
871
       applied. Downcasting of mask is not supported.
872

873
    2. The layout of the canonical input mask is the same as the
874
       layout of the :attr:`input` tensor. If the mask has different
875
       layout, it will be converted to the expected layout.  In the
876
       case of sparse COO layout, the canonical input mask will be
877
       coalesced.
878

879
    3. The dtype of the canonical input mask is torch.bool. If the
880
       mask dtype is not bool then it will be converted to bool dtype
881
       using `.to(dtype=bool)` method call.
882

883
    4. The elements of the canonical input mask have boolean values
884
       copied from the content of the :attr:`mask` tensor (after
885
       possible broadcasting and dtype conversion transforms).  In
886
       general, the sparsity pattern of the sparse canonical input
887
       mask need not to be the same as the sparsity pattern of the
888
       sparse :attr:`input` tensor.
889

890
    """
891
    if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
892
        raise ValueError(
893
            f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
894
        )
895

896
    mask = kwargs.get("mask")
897

898
    # default mask
899
    if mask is None:
900
        raise ValueError("_input_mask requires explicit mask")
901

902
    # mask shape must match with input shape
903
    if mask.shape != input.shape:
904
        if mask.ndim > input.ndim:
905
            raise IndexError(
906
                "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
907
            )
908
        if mask.layout == torch.strided:
909
            mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
910
        elif mask.layout == torch.sparse_coo:
911
            mask = torch._sparse_broadcast_to(mask, input.shape)
912
        else:
913
            assert mask.layout == torch.sparse_csr
914
            # Broadcasting of CSR tensors is not implemented. Working
915
            # around by using COO layout.
916
            mask = torch._sparse_broadcast_to(
917
                mask.to_sparse(), input.shape
918
            ).to_sparse_csr()
919

920
    # mask layout must match with input layout
921
    if mask.layout != input.layout:
922
        if input.layout == torch.strided:
923
            mask = mask.to_dense()
924
        elif input.layout == torch.sparse_coo:
925
            if mask.layout == torch.strided:
926
                mask = mask.to_sparse(input.sparse_dim())
927
            else:
928
                mask = mask.to_sparse()
929
        else:
930
            assert input.layout == torch.sparse_csr
931
            mask = mask.to_sparse_csr()
932

933
    # sparse mask must be coalesced
934
    if mask.layout == torch.sparse_coo:
935
        mask = mask.coalesce()
936

937
    # mask is a boolean tensor
938
    mask = mask.to(dtype=torch.bool)
939

940
    return mask
941

942

943
def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
944
    """Return output mask of masked operation applied to given arguments."""
945
    if callable(op):
946
        is_reduction = op.__name__ in {
947
            "sum",
948
            "prod",
949
            "amax",
950
            "amin",
951
            "argmax",
952
            "argmin",
953
            "mean",
954
            "median",
955
            "norm",
956
            "var",
957
            "std",
958
            "logsumexp",
959
        }
960
        is_normalization = op.__name__ in {
961
            "softmax",
962
            "log_softmax",
963
            "softmin",
964
            "normalize",
965
            "cumsum",
966
            "cumprod",
967
        }
968
        if is_reduction:
969
            if op.__name__ == "norm":
970
                if args:
971
                    args = args[1:]  # lstrip ord argument
972
            dim = args[0] if args else kwargs.get("dim")
973
            outmask = _input_mask(input, *args, **kwargs)
974
            keepdim = kwargs.get("keepdim", False)
975
            dim_ = _canonical_dim(dim, input.ndim)
976
            return _any(outmask, dim_, bool(keepdim))
977
        elif is_normalization:
978
            return _input_mask(input, *args, **kwargs)
979
        else:
980
            raise ValueError(
981
                f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
982
            )
983
    else:
984
        raise ValueError(
985
            f"_output_mask expected masked operation (got {type(op).__name__} object)"
986
        )
987

988

989
def _combine_input_and_mask(
990
    op, input: Union[MaskedTensor, Tensor], mask, *args
991
) -> Tensor:
992
    def helper(input, mask):
993
        if mask is None:
994
            return input
995
        canonical_mask = _input_mask(input, mask=mask)
996
        if callable(op):
997
            fill_value = _reduction_identity(op.__name__, input, *args)
998
            return _where(canonical_mask, input, fill_value)
999
        else:
1000
            raise ValueError(
1001
                f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
1002
            )
1003

1004
    class Combine(torch.autograd.Function):
1005
        @staticmethod
1006
        def forward(ctx, input, mask):
1007
            """Return input with masked-out elements eliminated for the given operations."""
1008
            ctx.save_for_backward(mask)
1009

1010
            if mask is not None:
1011
                ctx.mark_non_differentiable(mask)
1012

1013
            return helper(input, mask)
1014

1015
        @staticmethod
1016
        def backward(ctx, grad_output):
1017
            (mask,) = ctx.saved_tensors
1018
            grad_data = (
1019
                grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
1020
            )
1021
            result = as_masked_tensor(grad_data, mask)
1022
            return result, None
1023

1024
    return (
1025
        Combine.apply(input.get_data(), input.get_mask())  # type: ignore[union-attr]
1026
        if is_masked_tensor(input)
1027
        else helper(input, mask)
1028
    )
1029

1030

1031
@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        is_small_int = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout != torch.sparse_csr:
            dtype = torch.int64 if is_small_int else input.dtype
        elif is_small_int:
            # csr.to(dtype=torch.int64) is not implemented, so
            # using coo.to on input to ensure the promoted dtype.
            # dtype is deliberately left as None in this case.
            input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
        else:
            dtype = input.dtype

    reduced_dims = _canonical_dim(dim, input.ndim)
    combined = _combine_input_and_mask(sum, input, mask)
    keep = bool(keepdim)
    layout = combined.layout

    # Dispatch on the layout of the combined (masked-out filled) input.
    if layout == torch.strided:
        return torch.sum(combined, reduced_dims, keep, dtype=dtype)
    if layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, combined, reduced_dims, keep, dtype
        )
    if layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            combined, dim=list(reduced_dims), keepdim=keep, dtype=dtype
        )
    raise ValueError(
        f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {layout} tensor)"
    )
1082

1083

1084
@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        is_small_int = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout != torch.sparse_csr:
            dtype = torch.int64 if is_small_int else input.dtype
        elif is_small_int:
            # csr.to(dtype=torch.int64) is not implemented, so
            # using coo.to on input to ensure the promoted dtype.
            # dtype is deliberately left as None in this case.
            input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
        else:
            dtype = input.dtype

    reduced_dims = _canonical_dim(dim, input.ndim)
    combined = _combine_input_and_mask(prod, input, mask)
    keep = bool(keepdim)
    layout = combined.layout

    if layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # Reduce one dimension at a time, starting from the last one.
        partial = combined.to(dtype=dtype)
        for axis in reversed(reduced_dims):
            partial = partial.prod(dim=axis, keepdim=keep)
        return partial

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, combined, reduced_dims, keep, dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            combined, dim=list(reduced_dims), keepdim=keep, dtype=dtype
        )

    raise ValueError(
        f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {layout} tensor)"
    )
1159

1160

1161
@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is assigned by the _apply_docstring_templates decorator, so no
    # literal docstring is written here.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `sum` operation's identity.
    combined = _combine_input_and_mask(sum, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {combined.layout} tensor)"
        )
    return torch.cumsum(combined, axis, dtype=out_dtype).to(dtype=out_dtype)
1179

1180

1181
@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is assigned by the _apply_docstring_templates decorator, so no
    # literal docstring is written here.
    out_dtype = input.dtype if dtype is None else dtype
    axis = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements are filled using the `prod` operation's identity.
    combined = _combine_input_and_mask(prod, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {combined.layout} tensor)"
        )
    return torch.cumprod(combined, axis, dtype=out_dtype).to(dtype=out_dtype)
1199

1200

1201
@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    out_dtype = input.dtype if dtype is None else dtype

    combined = _combine_input_and_mask(amax, input, mask)
    reduced_dims = _canonical_dim(dim, combined.ndim)
    keep = bool(keepdim)
    layout = combined.layout

    if layout == torch.strided:
        return torch.amax(combined, reduced_dims, keep).to(dtype=out_dtype)

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, combined, reduced_dims, keep, out_dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, combined, reduced_dims, keep, out_dtype
        )

    raise ValueError(
        f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {layout} tensor)"
    )
1249

1250

1251
@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    #
    # Flow: default the output dtype to the input dtype, fill masked-out
    # elements via _combine_input_and_mask, then dispatch on the layout of
    # the combined tensor. Sparse layouts require an explicit mask and raise
    # ValueError otherwise; unsupported layouts raise ValueError as well.
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # Fixed: this message previously named "amax" (copy-paste slip).
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
1299

1300

1301
@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    out_dtype = input.dtype if dtype is None else dtype
    combined = _combine_input_and_mask(argmax, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked argmax expects strided tensor (got {combined.layout} tensor)"
        )
    # The resulting indices are converted to the output dtype.
    indices = torch.argmax(combined, dim, bool(keepdim))
    return indices.to(dtype=out_dtype)
1325

1326

1327
@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    out_dtype = input.dtype if dtype is None else dtype
    combined = _combine_input_and_mask(argmin, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked argmin expects strided tensor (got {combined.layout} tensor)"
        )
    # The resulting indices are converted to the output dtype.
    indices = torch.argmin(combined, dim, bool(keepdim))
    return indices.to(dtype=out_dtype)
1351

1352

1353
@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    if dtype is None:
        dtype = input.dtype

    layout = input.layout
    if layout == torch.strided:
        # mean = masked sum of the values / masked count of the elements,
        # where the count is the masked sum of an all-ones int64 tensor.
        if mask is None:
            # TODO: compute count analytically
            ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
            count = sum(ones, dim, keepdim=keepdim)
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            ones = inmask.new_ones(input.shape, dtype=torch.int64)
            count = sum(ones, dim, keepdim=keepdim, mask=inmask)
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count

    if layout == torch.sparse_csr:
        combined = _combine_input_and_mask(mean, input, mask)
        reduced_dims = _canonical_dim(dim, combined.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, combined, reduced_dims, bool(keepdim), dtype
        )

    raise ValueError(
        f"masked mean expects strided or sparse_csr tensor (got {layout} tensor)"
    )
1412

1413

1414
@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:

    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # The docstring above is a template; its {placeholders} must stay as-is.
    if dtype is None:
        dtype = input.dtype
    axis = _canonical_dim(dim, input.ndim)[0]

    # Non-floating-point inputs are computed in float and converted back
    # to `dtype` at the end.
    input_is_float = torch.is_floating_point(input)
    if not input_is_float:
        input = input.to(dtype=torch.float)

    combined = _combine_input_and_mask(median, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked median expects strided tensor (got {combined.layout} tensor)"
        )

    result = torch.nanmedian(combined, axis, keepdim).values
    if input_is_float:
        return result
    # A nan in the result cannot be converted back to a non-float dtype.
    if torch.isnan(result).any():
        raise ValueError(
            "masked median expects no fully masked out rows if dtype is not floating point"
        )
    return result.to(dtype=dtype)
1456

1457

1458
@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is assigned by the _apply_docstring_templates decorator, so no
    # literal docstring is written here.
    out_dtype = input.dtype if dtype is None else dtype
    reduced_dims = _canonical_dim(dim, input.ndim)
    combined = _combine_input_and_mask(logsumexp, input, mask)
    if combined.layout != torch.strided:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {combined.layout} tensor)"
        )
    reduced = torch.logsumexp(combined, reduced_dims, keepdim=keepdim)
    return reduced.to(dtype=out_dtype)
1477

1478

1479
# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
1480
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
tensor. The :attr:`input` elements are masked out according to the boolean tensor
:attr:`input_mask` and the :attr:`other` elements are masked out according to the
boolean tensor :attr:`other_mask`.

The shapes of a mask tensor and the tensor to be masked
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the mask
tensor must not be greater than of the tensor to be masked.

Both :attr:`input` and :attr:`other` must have strided layout; a
``ValueError`` is raised otherwise.

Args:
    input (Tensor): the input tensor
    other (Tensor): the second input tensor

Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired data type
      of returned tensor.  If specified, the output tensor is
      cast to :attr:`dtype` after the operation is
      performed. Default: None, in which case the dtype of
      :attr:`input` is used.
    input_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`input` tensor elements.
      Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
    other_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`other` tensor elements.
      Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

Example::

    >>> input = torch.tensor([-100.0, -200, -300])
    >>> input
    tensor([-100., -200., -300.])
    >>> other = torch.tensor([-1.0, -2, -3])
    >>> other
    tensor([-1., -2., -3.])
    >>> mask = torch.tensor([True, False, True])
    >>> mask
    tensor([ True, False,  True])
    >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
    tensor([-1., -inf, -3.])
"""
    out_dtype = input.dtype if dtype is None else dtype
    if input.layout != torch.strided or other.layout != torch.strided:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )
    # Masked-out elements of both operands are filled using the logsumexp
    # operation's identity before the element-wise logaddexp.
    filled_input = _combine_input_and_mask(logsumexp, input, input_mask)
    filled_other = _combine_input_and_mask(logsumexp, other, other_mask)
    return torch.logaddexp(filled_input, filled_other).to(dtype=out_dtype)
1540

1541

1542
@_apply_docstring_templates
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.

{reduction_args}

{reduction_example}"""
    # NOTE: the docstring above is a template with ``{...}`` placeholders, not
    # final documentation; keep it free of stray braces. The decorator swaps in
    # the generated text from ``_docs`` when one is available.

    # The result dtype falls back to the dtype of the input.
    result_dtype = input.dtype if dtype is None else dtype

    # ``ord`` is forwarded to the combining helper along with the mask.
    # NOTE(review): presumably the fill value for masked-out elements depends
    # on ``ord`` (see the identity remark in the template) -- confirm in
    # _combine_input_and_mask.
    filled = _combine_input_and_mask(norm, input, mask, ord)

    # Only strided tensors are handled by this reduction.
    if filled.layout != torch.strided:
        raise ValueError(
            f"masked norm expects strided tensor (got {filled.layout} tensor)"
        )

    reduce_dims = _canonical_dim(dim, input.ndim)
    return torch.linalg.vector_norm(
        filled, ord, reduce_dims, bool(keepdim), dtype=result_dtype
    )
1576

1577

1578
def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    """Shared implementation behind the masked ``var`` and ``std`` reductions.

    Computes the sum of squared absolute deviations from the (masked) sample
    mean along ``dim`` and divides it by ``count - correction``, where
    ``count`` is the number of elements taking part in the reduction. The
    divisor is clamped from below at zero when a non-zero correction is used.

    Args:
        input: tensor to reduce; must have strided layout.
        dim: dimension(s) to reduce over, forwarded to the masked ``sum``.
        unbiased: legacy switch; ``True`` means correction 1, ``False`` means
            correction 0. Must be ``None`` when ``correction_opt`` is given.
        correction_opt: explicit correction subtracted from the element count.
            When both this and ``unbiased`` are ``None`` the correction is 1.
        keepdim: whether the reduced dimensions are kept in the output.
        dtype: output dtype. When ``None`` the input dtype is used, unless it
            is neither floating point nor complex, in which case ``float32``.
        mask: optional validity mask for ``input``.
        take_sqrt: if true, return the square root of the result (``std``)
            instead of the result itself (``var``).

    Returns:
        The reduced tensor, cast to ``dtype`` before the optional square root.

    Raises:
        ValueError: if ``input`` does not have strided layout.
    """
    assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given"

    # Default correction is 1; ``unbiased`` selects between 1 and 0, and an
    # explicit ``correction_opt`` takes precedence over both.
    correction = 1.0 if (unbiased is None or unbiased) else 0.0
    if correction_opt is not None:
        correction = sym_float(correction_opt)

    # Resolve the output dtype, then the dtype used for accumulating the
    # squared deviations (always floating point or complex).
    if dtype is None:
        dtype = input.dtype
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.float32
    compute_dtype = (
        dtype if (dtype.is_floating_point or dtype.is_complex) else torch.float32
    )

    if input.layout != torch.strided:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )

    # Number of contributing elements and their total along ``dim``. Both are
    # computed with keepdim=True so they broadcast against ``input`` below.
    if mask is not None:
        inmask = _input_mask(input, mask=mask)
        ones = inmask.new_ones(input.shape, dtype=torch.int64)
        count = sum(ones, dim, keepdim=True, mask=inmask)
        sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
    else:
        # TODO: compute count analytically
        ones = torch.ones(input.shape, dtype=torch.int64, device=input.device)
        count = sum(ones, dim, keepdim=True)
        sample_total = sum(input, dim, keepdim=True, dtype=dtype)

    # TODO: replace torch.subtract/divide/square/maximum with
    # masked subtract/divide/square/maximum when these will be
    # available.
    sample_mean = torch.divide(sample_total, count)
    deviation = torch.subtract(input, sample_mean)
    # |deviation|^2; the conjugate makes this valid for complex inputs too.
    squared_deviation = deviation * deviation.conj()

    if mask is not None:
        total = sum(
            squared_deviation, dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask  # type: ignore[possibly-undefined]
        )
    else:
        total = sum(squared_deviation, dim, keepdim=keepdim, dtype=compute_dtype)

    # ``count`` was reduced with keepdim=True; align it with ``total``.
    if not keepdim:
        count = count.reshape(total.shape)

    if correction != 0:
        # The divisor is real-valued even when accumulating in a complex dtype.
        if compute_dtype.is_complex:
            real_dtype = corresponding_real_dtype(compute_dtype)
        else:
            real_dtype = compute_dtype
        corrected = torch.subtract(count.to(real_dtype), correction)
        # Never divide by a negative number of degrees of freedom.
        count = torch.maximum(corrected, corrected.new_zeros([]))

    result = torch.divide(total, count).to(dtype=dtype)
    return torch.sqrt(result) if take_sqrt else result
1648

1649

1650
@_apply_docstring_templates
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # NOTE: the docstring above is a ``{...}`` placeholder template; keep it
    # free of stray braces.
    # Variance is the shared std/var computation without the final sqrt.
    variance = _std_var(
        input,
        dim,
        unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )
    return variance
1679

1680

1681
@_apply_docstring_templates
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # NOTE: the docstring above is a ``{...}`` placeholder template; keep it
    # free of stray braces.
    #
    # ``correction`` is annotated as ``Optional[Union[int, float]]`` to match
    # ``var`` and the shared ``_std_var`` helper, which converts the value
    # with ``sym_float`` and therefore accepts fractional corrections as well.
    #
    # Standard deviation is the shared std/var computation followed by sqrt.
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )
1710

1711

1712
@_apply_docstring_templates
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No inline docstring on purpose: the decorator installs the generated one
    # from ``_docs`` when available.

    # The result dtype falls back to the dtype of the input.
    result_dtype = input.dtype if dtype is None else dtype
    # Softmax works along a single dimension: take the first canonical entry.
    axis = _canonical_dim(dim, input.ndim)[0]
    # NOTE(review): ``amax`` is passed as the reference op; presumably
    # masked-out elements are filled with its identity value -- confirm in
    # _combine_input_and_mask.
    filled = _combine_input_and_mask(amax, input, mask)

    if filled.layout != torch.strided:
        raise ValueError(
            f"masked softmax expects strided tensor (got {filled.layout} tensor)"
        )
    return torch.nn.functional.softmax(filled, axis, dtype=result_dtype)
1730

1731

1732
@_apply_docstring_templates
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No inline docstring on purpose: the decorator installs the generated one
    # from ``_docs`` when available.

    # The result dtype falls back to the dtype of the input.
    result_dtype = input.dtype if dtype is None else dtype
    # Log-softmax works along a single dimension: take the first canonical entry.
    axis = _canonical_dim(dim, input.ndim)[0]
    # NOTE(review): ``amax`` is passed as the reference op; presumably
    # masked-out elements are filled with its identity value -- confirm in
    # _combine_input_and_mask.
    filled = _combine_input_and_mask(amax, input, mask)

    if filled.layout != torch.strided:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {filled.layout} tensor)"
        )
    return torch.nn.functional.log_softmax(filled, axis, dtype=result_dtype)
1750

1751

1752
@_apply_docstring_templates
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No inline docstring on purpose: the decorator installs the generated one
    # from ``_docs`` when available.

    # The result dtype falls back to the dtype of the input.
    result_dtype = input.dtype if dtype is None else dtype
    # Softmin works along a single dimension: take the first canonical entry.
    axis = _canonical_dim(dim, input.ndim)[0]
    # NOTE(review): ``amin`` (not ``amax``) is the reference op here;
    # presumably masked-out elements are filled with its identity value --
    # confirm in _combine_input_and_mask.
    filled = _combine_input_and_mask(amin, input, mask)

    if filled.layout != torch.strided:
        raise ValueError(
            f"masked softmin expects strided tensor (got {filled.layout} tensor)"
        )
    return torch.nn.functional.softmin(filled, axis, dtype=result_dtype)
1770

1771

1772
@_apply_docstring_templates
def normalize(
    input: Union[Tensor, MaskedTensor],
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # No inline docstring on purpose: the decorator installs the generated one
    # from ``_docs`` when available.

    # The dtype used for the norm falls back to the dtype of the input.
    norm_dtype = input.dtype if dtype is None else dtype
    # ``dim_`` is not used further down (``norm`` receives the raw ``dim``).
    # NOTE(review): the call is kept because it may reject an invalid ``dim``
    # before any computation happens -- confirm in _canonical_dim.
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    numerator = _combine_input_and_mask(sum, input, mask)

    if numerator.layout != torch.strided:
        raise ValueError(
            f"masked normalize expects strided tensor (got {numerator.layout} tensor)"
        )

    # keepdim=True so the norm broadcasts against the input in the division.
    nrm_ = norm(input, ord, dim, keepdim=True, dtype=norm_dtype, mask=mask)
    # Clamp the divisor from below by ``eps`` to avoid division by zero.
    # TODO: replace torch.maximum with masked maximum when available.
    floor = nrm_.new_full([], eps)
    denom = torch.maximum(nrm_, floor)
    # TODO: replace torch.divide with masked divide when available.
    return torch.divide(numerator, denom)
1797

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.