4
# A workaround to support both TorchScript and MyPy:
5
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
8
from torch import Tensor
9
from torch.masked import as_masked_tensor, is_masked_tensor, MaskedTensor
11
from torch._prims_common import corresponding_real_dtype
12
from torch import sym_float
15
from torch.types import _dtype as DType
17
DimOrDims = Optional[Union[int, Tuple[int], List[int]]]
19
# The JIT doesn't understand Union, nor torch.dtype here
21
DimOrDims = Optional[Tuple[int]]
24
__all__: List[str] = []
26
# All masked reduction/normalization operations have the same
27
# signatures. Here we introduce docstring templates that are applied
28
# to docstrings of reduction/normalization functions via
29
# _apply_docstring_templates decorator.
32
def _apply_docstring_templates(func):
33
"""Decorator that applies docstring templates to function docstring
34
and returns the function instance.
37
doc_string = getattr(_docs, f"{func.__name__}_docstring", None)
38
if doc_string is None:
40
f"No documentation string available for {func.__name__}."
41
" PyTorch team should run `python tools/update_masked_docs.py`"
42
" to generate the missing docstrings."
45
func.__doc__ = doc_string
47
# Expose function as public symbol
48
__all__.append(func.__name__)
53
def _generate_docstring(func):
54
"""A utility function called from tools/update_masked_docs.py
55
script to update the module torch.masked._docs.py
57
docstring_templates = dict(
58
reduction_signature="""\
59
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
61
Returns {operation name} of all the elements in the :attr:`input`
62
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
63
elements are masked out according to the boolean tensor
66
If :attr:`keepdim` is ``True``, the output tensor is of the same size
67
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
68
size 1. Otherwise, :attr:`dim` is squeezed (see
69
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
70
``len(dim)``) fewer dimension(s).
72
The boolean tensor :attr:`mask` defines the "validity" of
73
:attr:`input` tensor elements: if :attr:`mask` element is True
74
then the corresponding element in :attr:`input` tensor will be
75
included in {operation name} computation, otherwise the element is
78
When all elements of :attr:`input` along the given dimension
79
:attr:`dim` are ignored (fully masked-out), the corresponding element
80
of the output tensor will have undefined value: it may or may not
81
correspond to the identity value of {operation name} operation; the
82
choice may correspond to the value that leads to the most efficient
83
storage of :attr:`output` tensor.
85
The mask of the output tensor can be computed as
86
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
89
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
90
don't need to match, but they must be :ref:`broadcastable
91
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
92
tensor must not be greater than of the :attr:`input` tensor.
95
input (Tensor): the input tensor
99
{kwargs_declarations}""",
100
reduction_example="""\
103
>>> input = {example_input}
105
{indent_example_input}
106
>>> mask = {example_mask}
108
{indent_example_mask}
109
>>> {full_function_name}(input, {example_args}, mask=mask)
110
{indent_example_output}
112
reduction_identity="""\
113
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
114
reduction_identity_dtype="""\
115
The identity value of {operation name} operation, which is used to start the
116
reduction, depends on input dtype. For instance, for float32, uint8,
117
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
118
normalization_signature="""\
119
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
120
normalization_descr="""\
121
Returns {operation name} of all the slices in the :attr:`input` tensor
122
along :attr:`dim` while the :attr:`input` elements are masked out
123
according to the boolean tensor :attr:`mask`.
126
normalization_args="""\
127
The boolean tensor :attr:`mask` defines the "validity" of
128
:attr:`input` tensor elements: if :attr:`mask` element is True then
129
the corresponding element in :attr:`input` tensor will be included in
130
{operation name} computation, otherwise the element is ignored.
132
The values of masked-out elements of the output tensor have undefined
133
value: it may or may not be set to zero or nan; the choice may correspond to
134
the value that leads to the most efficient storage of :attr:`output`
137
The mask of the {operation name} output tensor can be computed as
138
``torch.broadcast_to(mask, input.shape)``.
140
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
141
don't need to match, but they must be :ref:`broadcastable
142
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
143
tensor must not be greater than of the :attr:`input` tensor.
146
input (Tensor): the input tensor
150
{kwargs_declarations}""",
151
normalization_example="""\
154
>>> input = {example_input}
156
{indent_example_input}
157
>>> mask = {example_mask}
159
{indent_example_mask}
160
>>> {full_function_name}(input, {example_args}, mask=mask)
161
{indent_example_output}
165
args_and_kwargs = dict(
166
# argument name sufficies separated by double underscore will
167
# be removed in the final documentation string.
168
sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
169
prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
170
cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
171
cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
172
amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
173
amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
174
argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
175
argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
176
mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
177
median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
183
("keepdim=False", "dtype=None", "mask=None"),
185
var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
186
std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
187
logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
188
softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
189
log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
190
softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
196
("eps=1e-12", "dtype=None", "mask=None"),
200
argument_declarations = dict(
202
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
203
Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
205
dim (int): the dimension along which {operation name} is computed.""",
207
ord (int, float, optional): the order of vector norm. Default: 2.
208
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
210
ord (int, float): the order of vector norm. Default: 2.
211
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
213
unbiased (bool): when True, use Bessel’s correction, otherwise, compute
214
the uncorrected sample variance.""",
216
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
218
keepdim (bool, optional): whether the output tensor has
219
:attr:`dim` retained or not. Default: {default}.""",
221
dtype (:class:`torch.dtype`, optional): the desired data type
222
of returned tensor. If specified, the input tensor is
223
casted to :attr:`dtype` before the operation is
224
performed. Default: {default}.""",
226
mask (:class:`torch.Tensor`, optional): the boolean tensor
227
containing the binary mask of validity of input tensor
229
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
234
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
235
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
236
defined as ``exp(x[i])/sum(exp(x))``.""",
238
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
239
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
240
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
242
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
243
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
244
defined as ``exp(-x[i])/sum(exp(-x))``.""",
246
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
247
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
248
defined as ``x[i]/max(norm(x, p), eps)``.""",
250
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
251
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
252
defined as ``sum(x[:i])``.""",
254
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
255
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
256
defined as ``prod(x[:i])``.""",
259
reduction_names = dict(
270
std="standard_deviation",
271
logsumexp="logsumexp",
274
normalization_names = dict(
276
log_softmax="log_softmax",
278
normalize="normalize",
279
cumsum="cumulative_sum",
280
cumprod="cumulative_prod",
284
operation_names.update(reduction_names)
285
operation_names.update(normalization_names)
287
# Default example data:
289
example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
290
example_mask = torch.tensor([[True, False, True], [False, False, False]])
291
example_args: Tuple[Any, ...]
292
if func.__name__ in {"norm", "normalize"}:
293
example_args = (2.0, example_dim)
294
example_input = example_input.to(dtype=torch.float32)
295
elif func.__name__ in {"var", "std"}:
296
example_args = (example_dim, False)
297
elif func.__name__ == "median":
298
example_args = (example_dim,)
299
example_input = example_input.to(dtype=torch.float32)
301
example_args = (example_dim,)
303
operation_args: Tuple[str, ...]
304
operation_kwargs: Tuple[str, ...]
305
operation_args, operation_kwargs = args_and_kwargs[func.__name__]
308
argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines()
310
for a in operation_args
312
kwarg_declarations = [
314
argument_declarations.get(
315
a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.'
317
.format(default=a.split("=", 1)[1])
320
for a in operation_kwargs
323
if func.__name__ in reduction_names:
324
op_kind = "reduction"
325
doc_sections = ["signature", "descr", "identity", "args", "example"]
326
elif func.__name__ in normalization_names:
327
op_kind = "normalization"
328
doc_sections = ["signature", "descr", "args", "example"]
329
example_input = example_input.to(dtype=torch.float32)
331
assert 0 # add function name to operation names dictionaries
332
example_output = func(example_input, *example_args, mask=example_mask)
335
"function_name": func.__name__,
336
"full_function_name": func.__module__ + "." + func.__name__,
337
"operation name": operation_names[func.__name__],
338
"operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
339
"operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
340
# one-line representation of a tensor:
341
"example_input": " ".join(str(example_input).split()),
342
"example_args": ", ".join(map(str, example_args)),
343
"example_mask": " ".join(str(example_mask).split()),
344
# multi-line representation of a tensor with indent
345
"indent_example_input": ("\n ").join(str(example_input).splitlines()),
346
"indent_example_mask": ("\n ").join(str(example_mask).splitlines()),
347
"indent_example_output": ("\n ").join(str(example_output).splitlines()),
350
if func.__name__ in reduction_names:
351
template_data.update(
352
identity_uint8=_reduction_identity(
353
func.__name__, torch.tensor(0, dtype=torch.uint8)
355
identity_int32=_reduction_identity(
356
func.__name__, torch.tensor(0, dtype=torch.int32)
358
identity_float32=_reduction_identity(
359
func.__name__, torch.tensor(0, dtype=torch.float32)
362
if func.__name__ == "norm":
363
template_data.update(
364
identity_ord_ninf=_reduction_identity(
365
func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
368
elif func.__name__ in normalization_names:
369
template_data.update(definition=definitions[func.__name__])
371
assert 0 # add function name to operation names dictionaries
372
template_data.update(
373
args_declarations=("\n ".join(arg_declarations)).format_map(template_data)
375
template_data.update(
376
kwargs_declarations=("\n ".join(kwarg_declarations)).format_map(
381
# Apply function name info to docstring templates:
383
k: v.format_map(template_data)
384
for k, v in docstring_templates.items()
385
if k.startswith(op_kind)
388
(k, v.format_map(template_data) if isinstance(v, str) else v)
389
for k, v in template_data.items()
392
# Apply docstring templates to function doctring:
393
if func.__doc__ is None:
394
doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
396
doc_template = func.__doc__
397
return doc_template.format_map(templates)
400
def _reduction_identity(op_name: str, input: Tensor, *args):
401
"""Return identity value as scalar tensor of a reduction operation on
402
given input, or None, if the identity value cannot be uniquely
403
defined for the given input.
405
The identity value of the operation is defined as the initial
406
value to reduction operation that has a property ``op(op_identity,
407
value) == value`` for any value in the domain of the operation.
408
Or put it another way, including or excluding the identity value in
409
a list of operands will not change the reduction result.
411
See https://github.com/pytorch/rfcs/pull/27 for more information.
414
dtype: DType = input.dtype
415
device = input.device
416
op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present
417
if op_name in {"sum", "cumsum"}:
418
return torch.tensor(0, dtype=dtype, device=device)
419
elif op_name in {"prod", "cumprod"}:
420
return torch.tensor(1, dtype=dtype, device=device)
421
elif op_name in {"amax", "argmax", "logsumexp"}:
422
if torch.is_floating_point(input):
423
return torch.tensor(-torch.inf, dtype=dtype, device=device)
424
elif torch.is_signed(input) or dtype == torch.uint8:
425
return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
426
elif op_name in {"amin", "argmin"}:
427
if torch.is_floating_point(input):
428
return torch.tensor(torch.inf, dtype=dtype, device=device)
429
elif torch.is_signed(input) or dtype == torch.uint8:
430
return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
431
elif op_name == "mean":
432
# Strictly speaking, the identity value of the mean operation
433
# is the mean of the input. Since the mean value depends on
434
# the dim argument and it may be a non-scalar tensor, we
435
# consider the identity value of the mean operation ambiguous.
436
# Moreover, the mean value of empty input is undefined.
438
elif op_name == "norm":
439
ord = args[0] if args else 2
440
if ord == float("-inf"):
441
assert torch.is_floating_point(input), input.dtype
442
return torch.tensor(torch.inf, dtype=dtype, device=device)
443
return torch.tensor(0, dtype=dtype, device=device)
444
elif op_name == "median":
445
# We use NaN for now because the implementation is currently using torch.nanmedian
446
# and NaN is the identity for that function since it gets ignored
447
dtype = input.dtype if torch.is_floating_point(input) else torch.float
448
return torch.tensor(torch.nan, dtype=dtype, device=device)
449
elif op_name in {"var", "std"}:
451
raise NotImplementedError(f"identity of {op_name} on {dtype} input")
454
def _canonical_dim(dim: DimOrDims, ndim: int) -> Tuple[int, ...]:
455
"""Return dim argument as a tuple of sorted dim values."""
458
# Currently, `dim=()` in reductions operations means "reduce
459
# over all dimensions" while in future, it will read "no
460
# reduce". See https://github.com/pytorch/pytorch/issues/29137
461
# When gh-29137 is resolved, this if-block must be deleted.
464
return tuple(range(ndim))
466
dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
469
raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
470
if d >= ndim or d < -ndim:
472
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim-1}], but got {d})"
474
dims.append(d % ndim)
475
return tuple(sorted(dims))
478
def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
479
# Flatted N-D indices to 1-D indices
480
flat_indices = indices.new_zeros(indices.size(1))
481
for d, sz in enumerate(shape):
482
flat_indices.mul_(sz)
483
flat_indices.add_(indices[d])
487
def _any(input: Tensor, dim: tuple, keepdim: bool):
488
# Support torch.any with tuple dim argument.
489
# Workaround of https://github.com/pytorch/pytorch/issues/56586
491
for d in reversed(dim):
492
r = r.any(dim=d, keepdim=keepdim)
496
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
497
"""Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.
499
_sparse_coo_where implements the following invariant:
501
_sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
502
torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
504
where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
505
tensor, and `to_dense(fill_value)` is like `to_dense()` except
506
that the unspecified elements are mapped to `fill_value` rather
509
Returns a sparse COO tensor with the following features:
511
- all specified elements correspond to masked-in elements that
512
have the values of the input tensor. If there exists a masked-in
513
element (as specified by mask) that is not specified in the
514
input, in the result tensor, the corresponding element has value
515
0. In the dense part of the sparse tensor, the masked-out
516
elements are replaced with fill_value.
518
- all unspecified elements correspond to masked-out elements.
521
assert input.layout == torch.sparse_coo
522
assert mask.layout == input.layout
523
assert mask.shape == input.shape
524
assert mask.dense_dim() == input.dense_dim() # TODO: eliminate this restriction
526
input = input.coalesce()
528
# For set operations on sparse tensor indices, we'll convert
529
# multi-dimensional indices to 1-D indices for efficiency.
530
input_flat_indices = _sparse_coo_flatten_indices(
531
input.indices(), input.shape[: input.sparse_dim()]
533
mask_flat_indices = _sparse_coo_flatten_indices(
534
mask.indices(), mask.shape[: mask.sparse_dim()]
537
# the set of mask flat indices that define masked-in elements:
538
if mask.dense_dim() > 0:
540
mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
543
mask_values = mask.values()
544
maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]
546
def intersection(i1, i2):
547
union, counts = torch.cat([i1, i2]).unique(return_counts=True)
548
return union, torch.where(counts.gt(1))
551
union, counts = torch.cat([i1, i2]).unique(return_counts=True)
552
return intersection(union[torch.where(counts.eq(1))], i1)
558
# the set of input flat indices of specified and masked-in elements:
559
maskin_input_flat_indices = _apply(
560
intersection(maskin_flat_indices, input_flat_indices)
562
_, w = intersection(input_flat_indices, maskin_input_flat_indices)
564
# the indices and values of masked-in elements
565
where_input_indices = input.indices()[(slice(None),) + w]
566
where_input_values = input.values()[w]
568
if mask.dense_dim() > 0:
569
# apply mask to the dense part of the input values:
570
_, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
571
where_mask_values = mask.values()[w1]
572
where_input_values = torch.where(
573
where_mask_values, where_input_values, fill_value
576
# the set of flat indices of unspecified input and masked-in elements:
577
maskin_zero_flat_indices = _apply(
578
minus(maskin_flat_indices, maskin_input_flat_indices)
581
# the indices of masked-in zero elements
582
_, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
583
where_zero_indices = mask.indices()[(slice(None),) + w]
586
n = where_zero_indices.size(1)
588
# the input is coalesced, hence input_flat_indices are ordered
589
# and the result is guaranteed to be coalesced:
590
result = torch.sparse_coo_tensor(
591
where_input_indices, where_input_values, input.shape
593
return result._coalesced_(True)
595
where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
596
where_values = torch.cat(
599
where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
602
result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)
604
# appending zero elements leads to uncoalesced sparse tensor
605
return result.coalesce()
608
def _sparse_coo_scatter_reduction_helper(
611
dims: Tuple[int, ...],
613
dtype: Optional[DType] = None,
616
valid_reductions = ["sum", "prod", "amax", "amin"]
617
if reduce not in valid_reductions:
619
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
623
values, indices = mask_input._values(), mask_input._indices()
624
input_dims = mask_input.dim()
625
num_sparse_dims = mask_input.sparse_dim()
626
reduced_sparse_dims = []
627
retained_sparse_dims = []
628
reduced_dense_dims = []
630
# promote dtype if specified
631
if values.dtype != output_dtype:
632
values = values.to(output_dtype)
635
output_shape = tuple(
636
1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
639
output_shape = tuple(
640
si for (i, si) in enumerate(mask_input.shape) if i not in dims
647
if d < num_sparse_dims:
648
reduced_sparse_dims.append(d)
650
reduced_dense_dims.append(d + 1 - num_sparse_dims)
652
# Reduce dense dimensions
653
if len(reduced_dense_dims) > 0:
656
new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
658
# FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
659
return NotImplemented
661
new_values = values.clone()
663
# Reduce sparse dimensions
664
if len(reduced_sparse_dims) == num_sparse_dims:
665
if reduce in {"amax", "amin"} and new_values.size(0) == 0:
666
# IndexError: amax(): Expected reduction dim 0 to have non-zero size.
667
# sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
668
# See https://github.com/pytorch/pytorch/issues/61901
669
new_values = _reduction_identity(reduce, new_values)
671
new_values = op(new_values, dim=0)
673
for _ in range(num_sparse_dims):
674
new_values = new_values.unsqueeze(0)
675
return new_values.to(dtype=output_dtype).to_sparse()
677
new_indices = indices.clone()
679
# zero out reduced sparse dimensions if keepdim = True
680
# ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
681
new_indices[reduced_sparse_dims, :] = 0
683
# remove reduced sparse dimensions if keepdim = False
684
if len(reduced_sparse_dims) > 0:
685
retained_sparse_dims = [
687
for i in range(num_sparse_dims)
688
if i not in set(reduced_sparse_dims)
690
new_indices = new_indices.index_select(
691
0, torch.tensor(retained_sparse_dims).to(mask_input.device)
694
# Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
695
if new_indices.numel() > 0:
696
# lexsort indices and get index tensor for scatter reduction
697
new_indices, inverse_indices = torch.unique(
698
new_indices, return_inverse=True, dim=1
700
out_shape = list(new_values.shape)
701
out_shape[0] = new_indices.shape[1]
702
for _ in range(new_values.ndim - 1):
703
inverse_indices = inverse_indices.unsqueeze(-1)
704
scatter_indices = inverse_indices.expand(new_values.shape)
705
# FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
706
if output_dtype in {torch.bfloat16, torch.float16}:
707
new_values = new_values.to(torch.float)
708
out = new_values.new_empty(out_shape)
709
new_values = out.scatter_reduce_(
710
0, scatter_indices, new_values, reduce=reduce, include_self=False
712
new_values = new_values.to(dtype=output_dtype)
714
out = new_values.new_empty(out_shape)
715
new_values = out.scatter_reduce_(
716
0, scatter_indices, new_values, reduce=reduce, include_self=False
719
return torch.sparse_coo_tensor(
724
device=mask_input.device,
728
def _sparse_csr_segment_reduction_helper(
731
dims: Tuple[int, ...],
733
dtype: Optional[DType] = None,
735
# Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
736
# FIXME: when dense dimensions are implemented for CSR tensors
739
), "reduction operations on CSR tensors with keepdim=False is unsupported"
741
valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
742
if reduce not in valid_reductions:
744
f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
746
device = mask_input.device
748
values, crow_indices, col_indices = (
750
mask_input.crow_indices(),
751
mask_input.col_indices(),
754
# promote dtype if specified
755
if values.dtype != output_dtype:
756
values = values.to(output_dtype)
762
new_col_indices, scatter_indices = torch.unique(
763
col_indices, return_inverse=True
765
new_nnz = new_col_indices.shape[0]
766
new_crow_indices = torch.tensor([0, new_nnz])
767
new_values = values.new_empty(new_col_indices.shape)
768
new_values.scatter_reduce_(
769
0, scatter_indices, values, reduce, include_self=False
771
new_shape = [1, mask_input.size(1)]
775
), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
776
# all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
777
# except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
778
new_crow_indices = torch.cat(
780
crow_indices.new_zeros(1),
781
torch.cumsum(torch.diff(crow_indices) != 0, 0),
785
new_nnz = new_crow_indices[-1]
786
new_col_indices = col_indices.new_zeros(new_nnz)
787
new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
788
new_shape = [mask_input.size(0), 1]
790
assert len(dims) == 2
791
nnz = min(1, values.numel())
793
op_kwargs = {"keepdim": True, "dtype": output_dtype}
794
# amax and amin do not support dtype kwarg
795
if reduce in ["amax", "amin"]:
796
del op_kwargs["dtype"]
797
new_values = op(values, 0, **op_kwargs)
799
new_values = torch.empty(0, dtype=output_dtype)
800
new_col_indices = col_indices.new_zeros(nnz)
801
new_crow_indices = torch.tensor([0, nnz])
804
return torch.sparse_csr_tensor(
814
def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
815
"""Sparse variant of torch.where. Supports sparse CSR tensors."""
816
# TODO: implement sparse CSR specific where operator for efficiency
817
return _sparse_coo_where(
818
mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
822
def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
823
"""torch.where with sparse inputs support.
825
_where implements the following invariant:
827
_where(mask, input, fill_value).to_dense(fill_value) ==
828
torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
830
where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
831
tensor, and `to_dense(fill_value)` is like `to_dense()` except
832
that the unspecified elements are mapped to `fill_value` rather
835
Returns a sparse tensor with the following features:
837
- all specified elements correspond to masked-in elements that
838
have the values of the input tensor. If there exists a masked-in
839
element (as specified by mask) that is not specified in the
840
input, in the result tensor, the corresponding element has value
841
0. In the dense part of the sparse tensor, the masked-out
842
elements are replaced with fill_value.
844
- all unspecified elements correspond to masked-out elements.
846
if mask.layout == torch.strided:
847
return torch.where(mask, input, fill_value)
848
elif mask.layout == torch.sparse_coo:
849
return _sparse_coo_where(mask, input, fill_value)
850
elif mask.layout == torch.sparse_csr:
851
return _sparse_csr_where(mask, input, fill_value)
854
f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
858
def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
859
"""Return canonical input mask.
861
A canonical input mask is defined as a boolean mask tensor that
862
shape and layout matches with the shape and the layout of the
865
The canonical input mask is computed from the :attr:`mask` tensor
866
content to meet the following criteria:
868
1. The shape of the canonical input mask is the same as the shape
869
of :attr:`input` tensor. If the mask tensor has a smaller shape
870
than the shape of the :attr:`input`, broadcasting rules will be
871
applied. Downcasting of mask is not supported.
873
2. The layout of the canonical input mask is the same as the
874
layout of the :attr:`input` tensor. If the mask has different
875
layout, it will be converted to the expected layout. In the
876
case of sparse COO layout, the canonical input mask will be
879
3. The dtype of the canonical input mask is torch.bool. If the
880
mask dtype is not bool then it will be converted to bool dtype
881
using `.to(dtype=bool)` method call.
883
4. The elements of the canonical input mask have boolean values
884
copied from the content of the :attr:`mask` tensor (after
885
possible broadcasting and dtype conversion transforms). In
886
general, the sparsity pattern of the sparse canonical input
887
mask need not to be the same as the sparsity pattern of the
888
sparse :attr:`input` tensor.
891
if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
893
f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
896
mask = kwargs.get("mask")
900
raise ValueError("_input_mask requires explicit mask")
902
# mask shape must match with input shape
903
if mask.shape != input.shape:
904
if mask.ndim > input.ndim:
906
"_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
908
if mask.layout == torch.strided:
909
mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
910
elif mask.layout == torch.sparse_coo:
911
mask = torch._sparse_broadcast_to(mask, input.shape)
913
assert mask.layout == torch.sparse_csr
914
# Broadcasting of CSR tensors is not implemented. Working
915
# around by using COO layout.
916
mask = torch._sparse_broadcast_to(
917
mask.to_sparse(), input.shape
920
# mask layout must match with input layout
921
if mask.layout != input.layout:
922
if input.layout == torch.strided:
923
mask = mask.to_dense()
924
elif input.layout == torch.sparse_coo:
925
if mask.layout == torch.strided:
926
mask = mask.to_sparse(input.sparse_dim())
928
mask = mask.to_sparse()
930
assert input.layout == torch.sparse_csr
931
mask = mask.to_sparse_csr()
933
# sparse mask must be coalesced
934
if mask.layout == torch.sparse_coo:
935
mask = mask.coalesce()
937
# mask is a boolean tensor
938
mask = mask.to(dtype=torch.bool)
943
def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
944
"""Return output mask of masked operation applied to given arguments."""
946
is_reduction = op.__name__ in {
960
is_normalization = op.__name__ in {
969
if op.__name__ == "norm":
971
args = args[1:] # lstrip ord argument
972
dim = args[0] if args else kwargs.get("dim")
973
outmask = _input_mask(input, *args, **kwargs)
974
keepdim = kwargs.get("keepdim", False)
975
dim_ = _canonical_dim(dim, input.ndim)
976
return _any(outmask, dim_, bool(keepdim))
977
elif is_normalization:
978
return _input_mask(input, *args, **kwargs)
981
f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
985
f"_output_mask expected masked operation (got {type(op).__name__} object)"
def _combine_input_and_mask(
    op, input: Union[MaskedTensor, Tensor], mask, *args
) -> Tensor:
    """Replace masked-out elements of ``input`` with the identity value of ``op``.

    Args:
        op: the masked operation; its ``__name__`` selects the fill value via
            ``_reduction_identity``.
        input: a (strided or sparse) tensor, or a ``MaskedTensor``. For a
            ``MaskedTensor`` its own mask is used and ``mask`` is ignored.
        mask: optional mask for a plain tensor ``input``. ``None`` means every
            element is valid, and ``input`` is returned unchanged.
        *args: extra operation arguments forwarded to ``_reduction_identity``
            (``norm`` passes ``ord`` here).

    Raises:
        ValueError: if a mask has to be applied but ``op`` is not callable.
    """

    def helper(input, mask):
        if mask is None:
            return input
        canonical_mask = _input_mask(input, mask=mask)
        if callable(op):
            fill_value = _reduction_identity(op.__name__, input, *args)
            return _where(canonical_mask, input, fill_value)
        else:
            raise ValueError(
                f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
            )

    class Combine(torch.autograd.Function):
        # Wraps `helper` so that MaskedTensor inputs get a MaskedTensor
        # gradient carrying the same mask.
        @staticmethod
        def forward(ctx, input, mask):
            """Return input with masked-out elements eliminated for the given operations."""
            ctx.save_for_backward(mask)

            if mask is not None:
                ctx.mark_non_differentiable(mask)

            return helper(input, mask)

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            grad_data = (
                grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
            )
            result = as_masked_tensor(grad_data, mask)
            # No gradient is propagated to the mask.
            return result, None

    return (
        Combine.apply(input.get_data(), input.get_mask())  # type: ignore[union-attr]
        if is_masked_tensor(input)
        else helper(input, mask)
    )
@_apply_docstring_templates
def sum(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                # (dtype stays None: the promoted input dtype is used).
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    # Masked-out elements are replaced by the reduction identity of `sum`.
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def prod(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        if input.layout == torch.sparse_csr:
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = input.dtype
            if input.dtype in {
                torch.uint8,
                torch.bool,
                torch.int8,
                torch.int16,
                torch.int32,
            }:
                dtype = torch.int64
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # (reduce one dimension at a time; reversed so that earlier
        # dimension indices stay valid when keepdim is False).
        result = mask_input
        result = result.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # Output dtype defaults to the input dtype (no integer promotion here).
    if dtype is None:
        dtype = input.dtype
    # Only a single dimension is supported; take the canonical one.
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements take the identity of `sum`.
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    # Output dtype defaults to the input dtype (no integer promotion here).
    if dtype is None:
        dtype = input.dtype
    # Only a single dimension is supported; take the canonical one.
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements take the identity of `prod`.
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout == torch.strided:
        return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def amax(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    # The docstring above is a template; its placeholders are filled in by
    # the module's docstring generator, so keep it verbatim.
    if dtype is None:
        dtype = input.dtype

    # Masked-out elements are replaced by the reduction identity of `amax`.
    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def amin(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

{reduction_identity_dtype}

{reduction_args}

{reduction_example}"""
    # The docstring above is a template; its placeholders are filled in by
    # the module's docstring generator, so keep it verbatim.
    if dtype is None:
        dtype = input.dtype

    # Masked-out elements are replaced by the reduction identity of `amin`.
    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # (error message previously named `amax` by copy-paste mistake).
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def argmax(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim.
    if dtype is None:
        dtype = input.dtype
    # Masked-out elements take the identity of `argmax`, so they are not
    # selected unless a whole slice is masked out.
    mask_input = _combine_input_and_mask(argmax, input, mask)
    if mask_input.layout == torch.strided:
        # NOTE(review): the int64 indices are cast to `dtype`, which defaults
        # to input.dtype; for low-precision floating inputs (e.g. float16)
        # large indices are not exactly representable. Confirm this is intended.
        return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def argmin(
    input: Union[Tensor, MaskedTensor],
    dim: Optional[int] = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim.
    if dtype is None:
        dtype = input.dtype
    # Masked-out elements take the identity of `argmin`.
    mask_input = _combine_input_and_mask(argmin, input, mask)
    if mask_input.layout == torch.strided:
        # NOTE(review): the int64 indices are cast to `dtype`, which defaults
        # to input.dtype; for low-precision floating inputs (e.g. float16)
        # large indices are not exactly representable. Confirm this is intended.
        return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def mean(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.

{reduction_args}

{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim.
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided:
        # mean = (masked sum) / (number of unmasked elements); a fully
        # masked-out slice gives count 0 and hence a nan/inf result.
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=keepdim,
            )
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = sum(
                inmask.new_ones(input.shape, dtype=torch.int64),
                dim,
                keepdim=keepdim,
                mask=inmask,
            )
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count
    elif input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
        )
@_apply_docstring_templates
def median(
    input: Union[Tensor, MaskedTensor],
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined. Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim.
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    is_float = torch.is_floating_point(input)
    if not is_float:
        # Compute in float so masked-out elements can hold the median's fill
        # value (presumably nan, which nanmedian then skips).
        input = input.to(dtype=torch.float)
    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout == torch.strided:
        output = torch.nanmedian(mask_input, dim_, keepdim).values
        # Fully masked-out slices produce nan. That is acceptable for floating
        # inputs but cannot be represented when casting back to an integer
        # dtype. The requested `dtype` is now honored for floating inputs too
        # (it was previously ignored for them).
        if is_float or not torch.isnan(output).any():
            return output.to(dtype=dtype)
        else:
            raise ValueError(
                "masked median expects no fully masked out rows if dtype is not floating point"
            )
    else:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)
    # Masked-out elements take the reduction identity of `logsumexp`.
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    if mask_input.layout == torch.strided:
        return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )
# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
def logaddexp(
    input: Union[Tensor, MaskedTensor],
    other: Union[Tensor, MaskedTensor],
    *,
    dtype: Optional[DType] = None,
    input_mask: Optional[Tensor] = None,
    other_mask: Optional[Tensor] = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
tensor. The :attr:`input` elements are masked out according to the boolean tensor
:attr:`input_mask` and the :attr:`other` elements are masked out according to the boolean tensor
:attr:`other_mask`.

The shapes of a mask tensor and the tensor to be masked
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the mask
tensor must not be greater than that of the tensor to be masked.

Args:
    input (Tensor): the input tensor
    other (Tensor): the second input tensor

Keyword args:
    dtype (:class:`torch.dtype`, optional): the desired data type
      of returned tensor.  If specified, the output tensor is
      cast to :attr:`dtype` after the operation is
      performed. Default: None.
    input_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`input` tensor elements.
      Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
    other_mask (:class:`torch.Tensor`, optional): the boolean tensor
      containing the binary mask of validity of :attr:`other` tensor elements.
      Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

Example::

    >>> input = torch.tensor([-100.0, -200, -300])
    >>> input
    tensor([-100., -200., -300.])
    >>> other = torch.tensor([-1.0, -2, -3])
    >>> other
    tensor([-1., -2., -3.])
    >>> mask = torch.tensor([True, False, True])
    >>> mask
    tensor([ True, False,  True])
    >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
    tensor([-1., -inf, -3.])
"""
    if dtype is None:
        dtype = input.dtype
    if input.layout == torch.strided and other.layout == torch.strided:
        # Masked-out elements of each operand take the identity of logsumexp,
        # so (per the example above) an element masked in both yields -inf.
        mask_input = _combine_input_and_mask(logsumexp, input, input_mask)
        mask_other = _combine_input_and_mask(logsumexp, other, other_mask)
        return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
    else:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )
@_apply_docstring_templates
def norm(
    input: Union[Tensor, MaskedTensor],
    ord: Optional[float] = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}

{reduction_descr}

The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.

{reduction_args}

{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim.
    if dtype is None:
        dtype = input.dtype
    # `ord` is forwarded because the fill value depends on it (see docstring).
    mask_input = _combine_input_and_mask(norm, input, mask, ord)
    if mask_input.layout == torch.strided:
        dim_ = _canonical_dim(dim, input.ndim)
        return torch.linalg.vector_norm(
            mask_input, ord, dim_, bool(keepdim), dtype=dtype
        )
    else:
        raise ValueError(
            f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
        )
def _std_var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims,
    unbiased: Optional[bool],
    *,
    correction_opt: Optional[Union[int, float]],
    keepdim: Optional[bool],
    dtype: Optional[DType],
    mask: Optional[Tensor],
    take_sqrt: Optional[bool],
) -> Tensor:
    """Shared implementation of masked ``var`` and ``std``.

    Args:
        input: the strided input tensor.
        dim: dimension(s) to reduce over.
        unbiased: legacy flag; ``True`` means correction 1, ``False`` means 0.
        correction_opt: explicit degrees-of-freedom correction. At most one of
            ``unbiased`` and ``correction_opt`` may be given; if neither is
            given the correction defaults to 1.
        keepdim: keep reduced dimensions with size 1.
        dtype: output dtype. Defaults to ``input.dtype``, or float32 when that
            is neither floating point nor complex. Accumulation always uses a
            floating or complex dtype.
        mask: optional validity mask of ``input`` elements.
        take_sqrt: return the standard deviation instead of the variance.

    Raises:
        AssertionError: if both ``unbiased`` and ``correction_opt`` are given.
        ValueError: if ``input`` is not strided.
    """
    assert (unbiased is None or correction_opt is None), "Only one of unbiased and correction may be given"
    correction = 1.0
    if unbiased is not None:
        correction = 1.0 if unbiased else 0.0
    if correction_opt is not None:
        correction = sym_float(correction_opt)

    if dtype is None:
        dtype = input.dtype
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.float32
    compute_dtype = dtype
    if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
        compute_dtype = torch.float32
    if input.layout == torch.strided:
        inmask = None
        if mask is None:
            # TODO: compute count analytically
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=True,
            )
            sample_total = sum(input, dim, keepdim=True, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = sum(
                inmask.new_ones(input.shape, dtype=torch.int64),
                dim,
                keepdim=True,
                mask=inmask,
            )
            sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
        # TODO: replace torch.subtract/divide/square/maximum with
        # masked subtract/divide/square/maximum when these will be
        # available.
        sample_mean = torch.divide(sample_total, count)
        x = torch.subtract(input, sample_mean)
        # x * conj(x) gives |x|^2, which also covers complex inputs.
        # mask=None is equivalent to passing no mask to `sum`.
        total = sum(
            x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask
        )
        if not keepdim:
            count = count.reshape(total.shape)
        if correction != 0:
            real_dtype = (corresponding_real_dtype(compute_dtype)
                          if compute_dtype.is_complex else compute_dtype)
            count = count.to(real_dtype)
            count = torch.subtract(count, correction)
            # Clamp so that count - correction never goes negative.
            count = torch.maximum(count, count.new_zeros([]))
        output = torch.divide(total, count).to(dtype=dtype)
        if take_sqrt:
            output = torch.sqrt(output)
        return output
    else:
        raise ValueError(
            f"masked std/var expects strided tensor (got {input.layout} tensor)"
        )
@_apply_docstring_templates
def var(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim. Variance is _std_var without the sqrt.
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )
@_apply_docstring_templates
def std(
    input: Union[Tensor, MaskedTensor],
    dim: DimOrDims = None,
    unbiased: Optional[bool] = None,
    *,
    correction: Optional[Union[int, float]] = None,
    keepdim: Optional[bool] = False,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # The docstring above is a template filled in by the module's docstring
    # generator; keep it verbatim. `correction` accepts int or float, matching
    # `var` and `_std_var` (which converts it with sym_float).
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )
@_apply_docstring_templates
def softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements take the identity of `amax` (presumably the lowest
    # value / -inf), so they receive (near-)zero probability.
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def log_softmax(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # Masked-out elements take the identity of `amax`, as in masked softmax.
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def softmin(
    input: Union[Tensor, MaskedTensor],
    dim: int,
    *,
    dtype: Optional[DType] = None,
    mask: Optional[Tensor] = None,
) -> Tensor:
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    # softmin(x) == softmax(-x), so masked-out elements take the identity of
    # `amin` (presumably the highest value / +inf).
    mask_input = _combine_input_and_mask(amin, input, mask)
    if mask_input.layout == torch.strided:
        return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
    else:
        raise ValueError(
            f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
        )
1772
@_apply_docstring_templates
1774
input: Union[Tensor, MaskedTensor],
1779
dtype: Optional[DType] = None,
1780
mask: Optional[Tensor] = None,
1784
dim_ = _canonical_dim(dim, input.ndim)[0]
1785
# TODO: eliminate mask_input as unnecessary when using masked divide.
1786
mask_input = _combine_input_and_mask(sum, input, mask)
1787
if mask_input.layout == torch.strided:
1788
nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
1789
# TODO: replace torch.maximum with masked maximum when available.
1790
denom = torch.maximum(nrm_, nrm_.new_full([], eps))
1791
# TODO: replace torch.divide with masked divide when available.
1792
return torch.divide(mask_input, denom)
1795
f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"