# Owner(s): ["module: masked operators"]

"""Tests for masked operations.
"""
import itertools
import unittest
from functools import wraps
from typing import Any, List

import torch
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, onlyNativeDeviceTypes, precisionOverride)
from torch.testing._internal.common_methods_invocations import \
    (SampleInput, op_db)
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.common_utils import \
    (TestCase, parametrize, suppress_warnings, _TestParametrizer, run_tests)
def apply_masked_reduction_along_dim(op, input, *args, **kwargs):
    """Applies reduction op along given dimension to strided x
    elements that are valid according to mask tensor.

    The op is applied to each elementary slice of input with args and
    kwargs with the following constraints:

    1. Prior applying the op:

       A. if kwargs contains an item with key 'dim_position' then it is
          removed from kwargs. The value of 'dim_position' is an
          integer that describes the dim argument position: while
          typically the dim argument appears at the 0-th position of
          the op arguments (excluding input), for instance, sum(input,
          dim), then there exists reductions that have extra arguments
          prior the dim argument, for instance, norm(input, ord, dim).

       B. if args or kwargs contains dim or keepdim arguments, these
          will be removed or replaced with None so that the op is
          applied to elementary slice using the default dim and keepdim
          value.

    2. The elementary slice of the input is defined as the flattened
       slice that has no masked out elements and when op is applied,
       the result will be a scalar value (assuming keepdim=False). For
       example, an input tensor to a reduction operation op having
       dim=0 and keepdim=True argument:

         [[1 *]
          [2 3]
          [* 4]
          [* 5]]

       (* denotes masked out elements) has the following elementary
       slices: [1, 2] and [3, 4, 5]. The result of
       apply_masked_reduction_along_dim is

         [[op([1, 2], *args0, **kwargs, dim=None, keepdim=False)]
          [op([3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)]]

       where args0 is args where dim value is replased with None if
       present.

       Using the same example data, if the op is called with dim=(0, 1)
       and keepdim=False, there is one elementary slice: [1, 2, 3, 4,
       5]; and the corresponding result of the op is:

         op([1, 2, 3, 4, 5], *args0, **kwargs, dim=None, keepdim=False)

    3. If the elementary slice is empty, the corresponding output
       value is nan if dtype is float, otherwise, 0.  An empty
       elementary slice corresponds to fully masked-out output, so, the
       corresponding specific value of the output will not be important
       because we used masked equality check for comparing the results
       of masked operations.
    """
    # NOTE(review): control-flow lines lost in extraction were restored
    # from the upstream test_masked.py; verify against the original file.
    # eliminate mask and dim_position keyword arguments:
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)

    dtype = kwargs.get('dtype', input.dtype)
    if input.ndim == 0:
        # scalar input is an elementary slice
        return op(input, *args, **kwargs).to(dtype=dtype)

    # eliminate keepdim keyword argument if specified:
    keepdim = kwargs.pop('keepdim', False)

    # eliminate dim argument that may appear both as args or kwargs
    # element:
    if dim_pos < len(args):
        # dim is specified in args
        assert 'dim' not in kwargs, (args, kwargs)
        dim = args[dim_pos]
        args0 = args[:dim_pos] + (None,) + args[dim_pos + 1:]
    else:
        # dim may be specified in kwargs
        dim = kwargs.pop('dim', None)
        args0 = args

    # dimensions along which the reduction operation is applied:
    dim_ = torch.masked._canonical_dim(dim, input.ndim)
    # slices in product(*ranges) define all elementary slices:
    ranges: List[Any] = []
    # shape of output for the keepdim=True case:
    shape = []
    for i in range(input.ndim):
        if i in dim_:
            ranges.append((slice(None),))
            shape.append(1)
        else:
            ranges.append(range(input.shape[i]))
            shape.append(input.shape[i])

    # keepdim=True version of the output, filled with nan or 0:
    output = input.new_full(shape, float('nan') if dtype.is_floating_point else 0, dtype=dtype)

    # apply op to all elementary slices:
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    for s in itertools.product(*ranges):
        # data of an elementary slice is 1D sequence and has only
        # masked-in elements:
        data = input[s].flatten()[inpmask[s].flatten().argwhere()]
        if not data.numel():
            # empty elementary slice; keep the nan/0 fill value
            continue
        output[s][0] = op(data, *args0, **kwargs)

    if not keepdim:
        # reshape output for the keepdim=False case
        shape = [shape[i] for i in range(len(shape)) if i not in dim_]
        output = output.reshape(shape)
    return output
def apply_masked_normalization_along_dim(op, input, *args, **kwargs):
    """Applies normalization op along given dimension to strided x
    elements that are valid according to mask tensor.

    Unlike reductions, normalizations are shape-preserving: output has
    the same shape as input; masked-out positions are left as zeros.
    """
    mask = kwargs.pop('mask', None)
    dim_pos = kwargs.pop('dim_position', 0)
    if input.ndim == 0:  # scalar input
        return op(input, *args, **kwargs)
    dtype = kwargs.get('dtype', input.dtype)
    # dim is a required positional argument for normalization ops; the
    # 1D elementary slice is always normalized along its only dimension:
    dim = args[dim_pos]
    args0 = args[:dim_pos] + (0,) + args[dim_pos + 1:]
    output = torch.zeros_like(input, dtype=dtype)
    if mask is None:
        inpmask = input.new_ones([], dtype=torch.bool).expand(input.shape)
    else:
        inpmask = torch.masked._input_mask(input, mask=mask)
    dim_ = dim % input.ndim
    left_ranges = tuple(map(range, input.shape[:dim_]))
    right_ranges = tuple(map(range, input.shape[dim_ + 1:]))
    # iterate over all 1D slices along dim and normalize only the
    # masked-in elements of each slice:
    for s in itertools.product(*(left_ranges + ((slice(None),),) + right_ranges)):
        indices = inpmask[s].argwhere()
        output[s][indices] = op(input[s][indices], *args0, **kwargs)
    return output
# python reference implementations of the masked operations under test;
# dim_position tells the helpers where the dim argument sits in the
# wrapped op's positional arguments (e.g. norm(input, ord, dim) -> 1):
reference_functions = dict(
    norm=lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.linalg.vector_norm, *args, **dict(kwargs, dim_position=1)),
    var=lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.var, *args, **dict(kwargs, dim_position=0)),
    std=lambda *args, **kwargs: apply_masked_reduction_along_dim(
        torch.std, *args, **dict(kwargs, dim_position=0)),
    softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.softmax, *args, **kwargs),
    log_softmax=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.log_softmax, *args, **kwargs),
    softmin=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.softmin, *args, **kwargs),
    normalize=lambda *args, **kwargs: apply_masked_normalization_along_dim(
        torch.nn.functional.normalize, *args, **dict(kwargs, dim_position=1)),
)
# OpInfo entries for the masked namespace, plus the subsets that have a
# python reference implementation above or non-strided (sparse) support:
masked_ops = [op for op in op_db if op.name.startswith('masked.')]
masked_ops_with_references = [op for op in masked_ops if op.name.rsplit('.', 1)[-1] in reference_functions]
masked_ops_with_non_strided_support = [op for op in masked_ops if op.supports_sparse or op.supports_sparse_csr]
179
def _tensor_to_strided(obj):
180
# after gh-59958 is resolved, replace the usage of this function
181
# with torch.Tensor.to_dense
182
if torch.is_tensor(obj):
183
if obj.layout == torch.strided:
185
return obj.to_dense()
def to_strided(obj):
    """Convert the tensor content of object to strided tensor content.
    """
    return torch.utils._pytree.tree_map(_tensor_to_strided, obj)
def to_sparse_coo(obj):
    """Convert the tensor content of object to sparse coo tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse, obj)
def to_sparse_csr(obj):
    """Convert the tensor content of object to sparse csr tensor content.
    """
    return torch.utils._pytree.tree_map(torch.Tensor.to_sparse_csr, obj)
class mask_layouts(_TestParametrizer):
    """Decorator class for parametrization of test function with an input
    layout argument and an extra argument of sample inputs generator.

    The sample_inputs generator provides samples with all supported
    layouts for the mask argument.
    """
    def _parametrize_test(self, test, generic_cls, device_cls):
        # NOTE(review): branches lost in extraction (the mask-is-None path
        # and else clauses) were restored from upstream test_masked.py —
        # verify against the original file.

        @wraps(test)
        def wrap(self, layout, device, dtype, op):
            layout_name = str(layout).lstrip('torch.')
            if layout == torch.strided:
                # strided layouts are always supported
                sample_inputs_func = op.sample_inputs
            elif layout == torch.sparse_coo:
                if not op.supports_sparse:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_coo
            elif layout == torch.sparse_csr:
                if not op.supports_sparse_csr:
                    raise unittest.SkipTest(f"{op.name} does not support inputs with {layout_name} layout")
                sample_inputs_func = op.sample_inputs_sparse_csr
            else:
                raise NotImplementedError(f'{layout}')

            def sample_inputs_generator():
                for sample_input in sample_inputs_func(device, dtype):
                    mask = sample_input.kwargs.get('mask')
                    if mask is None:
                        yield sample_input
                    else:
                        # in addition to the sample's own mask layout, yield
                        # the sample with the mask converted to every other
                        # layout the op supports:
                        if layout == sample_input.input.layout:
                            yield sample_input
                        if layout != torch.strided:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_dense())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_coo and op.supports_sparse:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)
                        if layout != torch.sparse_csr and op.supports_sparse_csr and sample_input.input.ndim == 2:
                            sample_input_kwargs = sample_input.kwargs.copy()
                            sample_input_kwargs.update(mask=mask.to_sparse_csr())
                            yield SampleInput(sample_input.input.clone(),
                                              args=sample_input.args,
                                              kwargs=sample_input_kwargs)

            test(self, layout, device, dtype, op, sample_inputs_generator())

        for layout in (torch.strided, torch.sparse_coo, torch.sparse_csr):
            yield (wrap, str(layout).lstrip('torch.'), {'layout': layout}, lambda _: [])
class TestMasked(TestCase):

    def assertEqualMasked(self, actual, expected, mask):
        """Assert that actual equals expected on the masked-in region;
        when mask is given, masked-out elements are zeroed on both sides
        before comparison so their values are irrelevant."""
        strided = to_strided(actual)
        if mask is not None:
            strided = torch.where(mask, strided, strided.new_zeros([]))
            expected = torch.where(mask, expected, expected.new_zeros([]))
        self.assertEqual(strided, expected, exact_device=False)
274
@onlyNativeDeviceTypes
276
@ops(masked_ops_with_references)
277
@precisionOverride({torch.bfloat16: 5e-4, torch.float16: 5e-4})
278
def test_reference_masked(self, device, dtype, op):
279
op_name = op.name.rsplit('.', 1)[-1]
280
ref_op = reference_functions[op_name]
281
sample_inputs = op.sample_inputs(device, dtype)
282
for sample_input in sample_inputs:
283
t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
284
if op_name in {'var', 'std'} and not (t_inp.dtype.is_floating_point or t_inp.dtype.is_complex):
285
# torch.var/torch.std does not support integer inputs
287
actual = op.op(t_inp, *t_args, **t_kwargs)
288
expected = ref_op(t_inp, *t_args, **t_kwargs)
289
if t_kwargs.get('mask') is None:
292
outmask = torch.masked._output_mask(op.op, t_inp, *t_args, **t_kwargs)
293
self.assertEqualMasked(actual, expected, outmask)
296
@onlyNativeDeviceTypes
298
@ops(masked_ops_with_non_strided_support)
299
@precisionOverride({torch.bfloat16: 5e-3, torch.float16: 5e-3})
300
def test_mask_layout(self, layout, device, dtype, op, sample_inputs):
301
for sample in sample_inputs:
302
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
303
actual = op.op(t_inp, *t_args, **t_kwargs)
305
assert actual.layout == layout
307
# check masked invariance:
308
# op(inp, mask).to_dense() == op(inp.to_dense(), mask.to_dense()) at outmask
310
r_inp, r_args, r_kwargs = to_strided((t_inp, t_args, t_kwargs))
311
if r_kwargs.get('mask') is None:
314
outmask = torch.masked._output_mask(op.op, r_inp, *r_args, **r_kwargs)
315
expected = op.op(r_inp, *r_args, **r_kwargs)
316
self.assertEqualMasked(actual, expected, outmask)
# NOTE(review): this method was damaged in extraction — stray line-number
# tokens are interleaved with the code, indentation was stripped, and
# several source lines are missing (the middle rows of the mask/input
# tensor literals, the whole `tmp` expected-result literal, the
# `is_hybrid` branch headers, the definitions of F and Z, and the index
# arguments of the sparse tensor constructors).  The fragment is kept
# byte-identical below; restore it from the upstream test_masked.py
# before running — do not attempt to execute it as-is.
318
    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1992")
319
    @parametrize("sparse_kind,fill_value", [('coo', 0), ('hybrid_coo', 0),
320
                                            ('coo', 123), ('hybrid_coo', 123),
321
                                            ('csr', 0), ('csr', 123)],
322
                 name_fn=lambda sparse_kind, fill_value: f'{sparse_kind}_fill_value_{fill_value}')
323
    def test_where(self, sparse_kind, fill_value):
# tests torch.masked._where on COO, hybrid COO, and CSR inputs; per-kind
# helpers pick the sparse conversion and the in-place value setter
326
        if sparse_kind == 'coo':
328
            def to_sparse(dense):
329
                return dense.to_sparse(2)
331
            def set_values(sparse, index, value):
332
                sparse._values()[index] = value
334
        elif sparse_kind == 'hybrid_coo':
337
            def to_sparse(dense):
338
                return dense.to_sparse(1)
340
            def set_values(sparse, index, value):
341
                sparse._values()[index] = value
343
        elif sparse_kind == 'csr':
345
            def to_sparse(dense):
346
                return dense.to_sparse_csr()
348
            def set_values(sparse, index, value):
349
                sparse.values()[index] = value
352
            assert 0, sparse_kind
# NOTE(review): middle rows of the mask literal are missing here
354
        mask = torch.tensor([[1, 0, 1, 0, 0],
359
                             [1, 1, 0, 0, 0]]).to(dtype=bool)
360
        mask = to_sparse(mask)
361
        # make some specified mask elements as explicit masked-out masks:
363
            set_values(mask, (1, 1), False)
364
            set_values(mask, (-2, -2), False)
366
            set_values(mask, 3, False)
367
            set_values(mask, -3, False)
# NOTE(review): remaining rows of the input literal are missing here
369
        input = torch.tensor([[1, 0, 0, 0, -1],
375
        input = to_sparse(input)
376
        # make specified input elements have zero values:
378
        set_values(input, (1, 1), 0)
379
        set_values(input, (-1, 0), 0)
382
        set_values(input, 3, 0)
383
        set_values(input, -3, 0)
386
        # expected where result:
388
        # Z value corresponds to masked-in elements that are not
389
        # specified in the input and it will be replaced with a zero
# NOTE(review): F and Z definitions and the rest of the tmp literal are
# missing; presumably F == fill_value — confirm against upstream
390
        tmp = torch.tensor([[1, F, Z, F, F],
399
        sparse = torch.masked._where(mask, input,
400
                                     torch.tensor(fill_value, dtype=input.dtype, device=input.device))
# build the expected sparse result and an all-True outmask over the
# specified elements of the computed result
402
        if tmp.layout == torch.sparse_coo:
403
            expected_sparse = torch.sparse_coo_tensor(
405
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
407
            outmask = torch.sparse_coo_tensor(sparse.indices(),
408
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
409
                                              sparse.shape)._coalesced_(True)
410
        elif tmp.layout == torch.sparse_csr:
411
            expected_sparse = torch.sparse_csr_tensor(
414
                torch.where(tmp.values() != Z, tmp.values(), tmp.values().new_full([], 0)),
416
            outmask = torch.sparse_csr_tensor(sparse.crow_indices(), sparse.col_indices(),
417
                                              sparse.values().new_full(sparse.values().shape, 1).to(dtype=bool),
422
        self.assertEqual(sparse, expected_sparse)
425
        # torch.where(mask.to_dense(), input.to_dense(), fill_value)
426
        # == where(mask, input, fill_value).to_dense(fill_value)
427
        expected = torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, F))
428
        dense = torch.where(outmask.to_dense(), sparse.to_dense(), torch.full(sparse.shape, F))
429
        self.assertEqual(dense, expected)
instantiate_device_type_tests(TestMasked, globals(), except_for='meta')

if __name__ == "__main__":
    run_tests()