1
# Owner(s): ["module: mta"]
3
from contextlib import nullcontext
4
from numbers import Number
12
from torch.testing import make_tensor
13
from torch.testing._comparison import default_tolerances
14
from torch.testing._internal.common_cuda import TEST_MULTIGPU
15
from torch.testing._internal.common_utils import \
16
TestCase, run_tests, TEST_WITH_ROCM, skipIfTorchDynamo, parametrize, gradcheck, skipIfRocmVersionLessThan
17
from torch.testing._internal.common_device_type import \
18
(instantiate_device_type_tests, dtypes, onlyCUDA, ops, OpDTypes)
19
from torch.testing._internal.common_methods_invocations import (
20
foreach_unary_op_db, foreach_binary_op_db, foreach_pointwise_op_db,
21
foreach_reduce_op_db, foreach_other_op_db)
22
from torch.testing._internal.common_dtype import (
23
all_types_and_complex_and, floating_types_and, floating_types, integral_types_and,
27
_BOOL_SUB_ERR_MSG = "Subtraction, the `-` operator"
30
class RegularFuncWrapper:
31
def __init__(self, func):
34
def __call__(self, inputs, scalars=None, **kwargs):
35
if scalars is not None:
36
assert len(inputs) == 3
37
# We need to distribute each scalar to the regular func and it needs
38
# special consideration as it is a keyword only argument to the
39
# regular func. (Strangely, it is not a keyword only argument to the
41
return [self.func(*i, value=scalars[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
42
if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
43
# binary op with tensorlist and scalar.
44
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
45
return [self.func(*i, **kwargs) for i in zip(*inputs)]
48
class ForeachFuncWrapper:
49
def __init__(self, func):
51
# Some foreach functions don't have in-place implementations.
52
self.is_inplace = False if func is None else func.__name__.endswith('_')
54
def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
56
zero_size = kwargs.pop("zero_size", False)
59
torch.autograd.kineto_available() and
60
torch.profiler.ProfilerActivity.CUDA in torch.profiler.supported_activities()
62
with torch.profiler.profile() as p:
63
actual = self.func(*inputs, **kwargs)
64
keys = tuple([e.key for e in p.key_averages()])
65
mta_called = any("multi_tensor_apply_kernel" in k for k in keys)
66
assert mta_called == (expect_fastpath and (not zero_size))
68
actual = self.func(*inputs, **kwargs)
69
# note(mkozuki): inplace foreach functions are void functions.
70
return inputs[0] if self.is_inplace else actual
73
class InplaceForeachVersionBumpCheck:
75
def __init__(self, testcase: TestCase, tensorlist: "List[torch.Tensor]") -> None: # noqa: F821
76
self._testcase = testcase
77
self._tensorlist = tensorlist
78
self._orig_version_counts = [t._version for t in tensorlist]
83
def __exit__(self, exc_type, exc_value, traceback):
84
# note(crcrpar): some methods e.g. `_binary_test` could call the given inplace function multiple times
85
self._testcase.assertGreaterEqual([t._version for t in self._tensorlist], self._orig_version_counts)
88
def get_transform_func(num_tensors, dtype, device, is_fastpath):
90
if not torch.is_tensor(t):
92
if torch.is_tensor(t) and t.ndim == 0:
95
(num_tensors, num_tensors), dtype=dtype, device=device,
96
requires_grad=True, noncontiguous=not is_fastpath,
102
# note(crcrpar): `zero_size` is `False` unless (dtype, device) == (torch.float32, "cuda")
103
# as the pair would go through `multi_tensor_apply_kernel` if inputs are not zero size.
104
class TestForeach(TestCase):
107
return self.device_type == 'cuda'
109
def _get_funcs(self, op):
111
ForeachFuncWrapper(op.method_variant),
112
RegularFuncWrapper(op.ref),
113
ForeachFuncWrapper(op.inplace_variant),
114
RegularFuncWrapper(op.ref_inplace),
117
# note(crcrpar): Make sure 0-size tensors are appropriately ignored by `multi_tensor_apply`
118
# which is originally reported in https://github.com/pytorch/pytorch/issues/94865.
120
# - https://github.com/pytorch/pytorch/pull/94655
121
# - https://github.com/pytorch/pytorch/issues/100701
122
# - https://github.com/pytorch/pytorch/pull/100811
125
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
126
dtypes=(torch.float32,)
128
def test_all_zero_size_tensors_do_not_launch_kernel(self, device, dtype, op):
129
wrapped_op, _, inplace_op, _ = self._get_funcs(op)
131
for sample in op.sample_zero_size_inputs(device, dtype):
133
wrapped_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
134
with InplaceForeachVersionBumpCheck(self, sample.input):
135
inplace_op((sample.input, *sample.args), is_cuda=self.is_cuda, expect_fastpath=True, zero_size=True)
137
@skipIfRocmVersionLessThan((6, 0))
139
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
142
"noncontiguous,inplace",
143
[(False, False), (False, True), (True, False), (True, True)],
144
name_fn=lambda x, y: '{}_{}'.format(
145
'fastpath' if not x else 'slowpath', 'inplace' if y else 'outplace'
148
def test_parity(self, device, dtype, op, noncontiguous, inplace):
150
_, _, func, ref = self._get_funcs(op)
152
func, ref, _, _ = self._get_funcs(op)
153
for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
154
ref_kwargs = sample.kwargs
155
# div promotes ints to floats, so we cannot go on the fastpath there
156
div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
157
expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
158
ref_input, ctxmgr = sample.input, nullcontext()
160
with torch.no_grad():
161
ref_input = [t.clone().detach() for t in sample.input]
162
ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
165
actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **sample.kwargs)
166
except Exception as e:
168
self.assertRaisesRegex(type(e), re.escape(str(e)))
169
if not (op.has_no_in_place or not op.supports_out)
170
else self.assertRaises(type(e))
172
ref([ref_input, *sample.ref_args], **ref_kwargs)
174
expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
175
self.assertEqual(expected, actual)
179
dtype, op, ref, inputs, is_fastpath, is_inplace,
181
alpha, scalar_self_arg: bool,
183
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1]] if is_inplace else inputs
185
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
186
actual = op(inputs, self.is_cuda, is_fastpath)
187
except RuntimeError as e:
188
with self.assertRaisesRegex(type(e), re.escape(str(e))):
189
if not scalar_self_arg:
192
[ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
194
expected = ref(ref_inputs) if not scalar_self_arg else [ref.func(ref_inputs[0], t) for t in ref_inputs[1]]
195
self.assertEqual(actual, expected)
196
if alpha is not None and not scalar_self_arg:
197
kwargs = {'alpha': alpha}
201
op_kwargs.update(kwargs)
202
with InplaceForeachVersionBumpCheck(self, inputs[0]) if op.is_inplace else nullcontext():
203
actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
204
except RuntimeError as e:
205
with self.assertRaisesRegex(type(e), re.escape(str(e))):
206
ref(ref_inputs, **kwargs)
208
expected = ref(ref_inputs, **kwargs)
209
if dtype in (torch.float16, torch.bfloat16) and TEST_WITH_ROCM:
210
self.assertEqual(expected, actual, atol=1.e-3, rtol=default_tolerances(dtype)[0])
212
self.assertEqual(expected, actual)
214
@ops(filter(lambda op: op.supports_scalar_self_arg, foreach_binary_op_db))
215
@parametrize("is_fastpath", (True, False))
216
def test_binary_op_with_scalar_self_support(self, device, dtype, op, is_fastpath):
219
if isinstance(arg, (list, tuple)):
220
return [clone(a) for a in arg]
221
if torch.is_tensor(arg):
222
return arg.clone().detach().requires_grad_()
226
scalar_self_arg_test_complete = False
227
for i, sample in enumerate(op.sample_inputs(device, dtype, noncontiguous=not is_fastpath)):
228
(rhs_arg,) = sample.args
229
kwargs = {} or sample.kwargs
230
alpha = kwargs.pop("alpha", None)
231
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
232
if isinstance(rhs_arg, Number) and not scalar_self_arg_test_complete:
233
scalar_self_arg_test_complete = True
235
dtype, wrapped_op, ref, [rhs_arg, sample.input], is_fastpath, False,
236
alpha=alpha, scalar_self_arg=True,
238
if op.supports_autograd and dtype == torch.float32:
239
transformed_sample = sample.transform(
240
get_transform_func(len(sample.input), dtype, device, is_fastpath))
241
tensors = transformed_sample.input
242
(rhs_arg,) = transformed_sample.args
243
ref_tensors, ref_rhs_arg = clone(tensors), clone(rhs_arg)
245
[rhs_arg, tensors], is_cuda=False, expect_fastpath=False
247
sum([ref.func(ref_rhs_arg, t) for t in ref_tensors]).mean().backward()
248
self.assertEqual([t.grad for t in tensors], [t.grad for t in ref_tensors])
250
@ops(foreach_pointwise_op_db)
251
@parametrize("is_fastpath", (True, False))
252
def test_pointwise_op_with_tensor_of_scalarlist_overload(self, device, dtype, op, is_fastpath):
253
for sample in op.sample_inputs(device, dtype, noncontiguous=not is_fastpath):
254
assert isinstance(sample.args, tuple)
255
assert len(sample.args) == 2
256
inputs = [sample.input, *sample.args]
257
kwargs = sample.kwargs.copy()
258
disable_fastpath = sample.disable_fastpath and is_fastpath
259
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
260
scalars = kwargs.pop("scalars", None)
262
if is_fastpath and scalars:
263
sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
264
inputs = [sample.input, *sample.args]
265
tensor_values = torch.tensor(scalars)
266
# 1D Tensor of scalars
267
for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
268
self._pointwise_test(
269
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
270
scalars=tensor_values, **kwargs)
271
self._pointwise_test(
272
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
273
scalars=tensor_values[0],
274
custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
278
self._pointwise_test(
279
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
280
scalars=tensor_values.cuda(),
281
custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
284
self._pointwise_test(
285
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
286
scalars=tensor_values[:2],
287
custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
290
self._pointwise_test(
291
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
292
scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
293
custom_values_err="Expected scalars to be contiguous.",
297
# Tests of implicit broadcasting
298
N = len(sample.input)
300
[make_tensor((N, N), device=device, dtype=dtype, noncontiguous=not is_fastpath) for _ in range(N)],
302
make_tensor((N - i, 1), device=device, dtype=dtype, noncontiguous=not is_fastpath)
306
make_tensor((1, N - i), device=device, dtype=dtype, noncontiguous=not is_fastpath)
310
self._pointwise_test(
311
wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
312
scalars=scalars, **kwargs)
313
self._pointwise_test(
314
inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
315
is_inplace=True, scalars=scalars, **kwargs)
319
op, ref, inputs, is_fastpath, is_inplace,
321
scalars=None, custom_values_err=None, **kwargs
323
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs
325
with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()):
326
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
327
except RuntimeError as e:
328
with self.assertRaisesRegex(type(e), re.escape(str(e))):
329
ref(ref_inputs, **kwargs)
331
expected = ref(ref_inputs, **kwargs)
332
self.assertEqual(expected, actual)
333
if scalars is not None:
334
kwargs = kwargs.copy()
335
kwargs["scalars"] = scalars
337
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
338
except RuntimeError as e:
339
# Match with error messages from regular non-foreach reference if no
340
# custom error message was provided.
341
if custom_values_err is None:
342
with self.assertRaisesRegex(type(e), re.escape(str(e))):
343
ref(ref_inputs, **kwargs)
345
self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
347
expected = ref(ref_inputs, **kwargs)
348
self.assertEqual(expected, actual)
350
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
351
def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype):
352
# TODO: enable empty list case
353
for tensors in [[torch.randn([0], device=device, dtype=dtype)],
354
[torch.empty_strided((0, 1), (0, 0), dtype=dtype, device=device)]]:
355
res = torch._foreach_add(tensors, 1)
356
self.assertEqual(res, tensors)
358
torch._foreach_add_(tensors, 1)
359
self.assertEqual(res, tensors)
361
# Regression test for https://github.com/pytorch/pytorch/issues/113156
362
torch._foreach_mul_(tensors, 1)
365
filter(lambda op: op.supports_out, foreach_binary_op_db),
366
dtypes=OpDTypes.supported,
368
def test_binary_op_scalar_with_overlapping_tensors(self, device, dtype, op):
369
foreach_op, ref = op.method_variant, op.ref
370
tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
372
if ref == torch.sub and dtype == torch.bool:
373
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
374
[ref(t, 1) for t in tensors]
375
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
376
foreach_op(tensors, 1)
379
expected = [ref(t, 1) for t in tensors]
380
res = foreach_op(tensors, 1)
381
self.assertEqual(res, expected)
384
filter(lambda op: op.supports_out, foreach_binary_op_db),
385
allowed_dtypes=[torch.float],
387
def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
388
foreach_op = op.method_variant
390
torch.tensor([1.1], dtype=torch.float, device=device),
391
torch.tensor([1], dtype=torch.long, device=device),
395
foreach_op(tensors, 1)
396
except RuntimeError as e:
398
self.assertIsNone(runtime_error)
400
@skipIfTorchDynamo("Different error msgs, TODO")
402
filter(lambda op: op.supports_out, foreach_binary_op_db),
403
dtypes=OpDTypes.supported,
405
def test_binary_op_list_error_cases(self, device, dtype, op):
406
foreach_op, foreach_op_, ref, ref_ = op.method_variant, op.inplace_variant, op.ref, op.ref_inplace
409
ops_to_test = [foreach_op, foreach_op_]
412
for fop in ops_to_test:
413
with self.assertRaisesRegex(RuntimeError, "There were no tensor arguments to this function"):
414
fop(tensors1, tensors2)
417
tensors1.append(torch.tensor([1], device=device, dtype=dtype))
418
for fop in ops_to_test:
419
with self.assertRaisesRegex(RuntimeError, "Tensor list must have same number of elements as scalar list."):
420
fop(tensors1, tensors2)
422
# Lists have different amount of tensors
423
tensors2.append(torch.tensor([1], device=device))
424
tensors2.append(torch.tensor([1], device=device))
425
for fop in ops_to_test:
426
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 1 and 2"):
427
fop(tensors1, tensors2)
428
with self.assertRaisesRegex(RuntimeError, "Tensor lists must have the same number of tensors, got 2 and 1"):
429
fop(tensors2, tensors1)
431
# Corresponding tensors with different sizes that aren't compatible with broadcast
432
# If sizes are different then foreach chooses slow path, thus error messages are expected
433
# to be the same as torch regular function.
434
tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
435
tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
437
foreach_op(tensors1, tensors2)
438
except RuntimeError as e:
439
with self.assertRaisesRegex(type(e), re.escape(str(e))):
440
[ref(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
442
foreach_op_(tensors1, tensors2)
443
except RuntimeError as e:
444
with self.assertRaisesRegex(type(e), re.escape(str(e))):
445
[ref_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
448
if self.device_type == "cuda" and torch.cuda.device_count() > 1:
449
tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
450
tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
451
if dtype == torch.bool and foreach_op == torch._foreach_sub:
452
for fop in ops_to_test:
453
with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
454
fop([tensor1], [tensor2])
456
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
457
foreach_op([tensor1], [tensor2])
458
if dtype in integral_types_and(torch.bool) and foreach_op == torch._foreach_div:
459
with self.assertRaisesRegex(RuntimeError, "result type"):
460
foreach_op_([tensor1], [tensor2])
462
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
463
foreach_op_([tensor1], [tensor2])
465
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
467
filter(lambda op: op.supports_out, foreach_binary_op_db),
468
dtypes=OpDTypes.supported,
470
def test_binary_op_list_slow_path(self, device, dtype, op):
471
foreach_op, native_op, foreach_op_, native_op_ = self._get_funcs(op)
473
tensor1 = make_tensor((10, 10), dtype=dtype, device=device)
474
tensor2 = make_tensor((1,), device=device, dtype=dtype).expand_as(tensor1)
475
inputs = ([tensor1], [tensor2])
477
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
478
alpha=None, scalar_self_arg=False)
480
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
481
alpha=None, scalar_self_arg=False)
484
tensor1 = torch.zeros(10, 10, device=device, dtype=dtype)
485
tensor2 = torch.ones(10, 10, device=device, dtype=dtype)
486
inputs = ([tensor1], [tensor2.t()])
488
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
489
alpha=None, scalar_self_arg=False)
491
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
492
alpha=None, scalar_self_arg=False)
495
tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
496
tensor2 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype, noncontiguous=True)
497
self.assertFalse(tensor1.is_contiguous())
498
self.assertFalse(tensor2.is_contiguous())
499
inputs = ([tensor1], [tensor2])
501
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
502
alpha=None, scalar_self_arg=False)
504
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
505
alpha=None, scalar_self_arg=False)
508
tensor1 = make_tensor((5, 2, 1, 3), device=device, dtype=dtype)
509
tensor2 = make_tensor((5, 2, 1, 3 * 7), device=device, dtype=dtype)[:, :, :, ::7]
510
inputs = ([tensor1], [tensor2])
512
dtype, foreach_op, native_op, inputs, is_fastpath=False, is_inplace=False,
513
alpha=None, scalar_self_arg=False)
515
dtype, foreach_op_, native_op_, inputs, is_fastpath=False, is_inplace=True,
516
alpha=None, scalar_self_arg=False)
519
filter(lambda op: op.supports_out, foreach_binary_op_db),
520
dtypes=floating_types_and(torch.half, torch.bfloat16),
522
def test_binary_op_float_inf_nan(self, device, dtype, op):
525
torch.tensor([float("inf")], device=device, dtype=dtype),
526
torch.tensor([-float("inf")], device=device, dtype=dtype),
527
torch.tensor([float("nan")], device=device, dtype=dtype),
528
torch.tensor([float("nan")], device=device, dtype=dtype),
531
torch.tensor([-float("inf")], device=device, dtype=dtype),
532
torch.tensor([float("inf")], device=device, dtype=dtype),
533
torch.tensor([float("inf")], device=device, dtype=dtype),
534
torch.tensor([float("nan")], device=device, dtype=dtype),
537
op, ref, inplace_op, inplace_ref = self._get_funcs(op)
538
self._binary_test(dtype, op, ref, inputs, True, False, alpha=None, scalar_self_arg=False)
540
dtype, inplace_op, inplace_ref, inputs, True, True, alpha=None, scalar_self_arg=False
543
# note: Below three tests (postfixed with `_tensors_on_different_devices`)
544
# checks whether foreach works with lists of tensors on different devices
545
# but tensors of the same index are on the same device, e.g., ['cuda', 'cpu].
547
@ops(foreach_unary_op_db)
548
def test_unary_op_tensors_on_different_devices(self, device, dtype, op):
549
method, ref, inplace_method, ref_inplace = self._get_funcs(op)
550
# tensors: ['cuda', 'cpu]
551
tensors = next(iter(op.sample_inputs(device, dtype, num_input_tensors=[2]))).input
552
tensors[1] = tensors[1].to("cpu")
553
if not op.supports_out:
555
actual = method((tensors,), False, False, zero_size=False)
556
except RuntimeError as e:
557
with self.assertRaisesRegex(type(e), str(e)):
560
expected = ref((tensors,))
561
self.assertEqual(expected, actual)
564
inplace_method((tensors,), False, False, zero_size=False)
565
except RuntimeError as e:
566
with self.assertRaisesRegex(type(e), str(e)):
567
ref_inplace((tensors,))
569
if not op.supports_out:
570
self.assertEqual(expected, tensors)
572
self.assertEqual([torch.zeros_like(t) for t in tensors], tensors)
575
@ops(filter(lambda op: op.supports_out, foreach_binary_op_db))
576
def test_binary_op_tensors_on_different_devices(self, device, dtype, op):
577
# `tensors1`: ['cuda', 'cpu']
578
# `tensors2`: ['cuda', 'cpu']
579
_cuda_tensors = next(iter(op.sample_inputs(device, dtype, num_input_tensors=[2], same_size=True))).input
580
_cpu_tensors = next(iter(op.sample_inputs("cpu", dtype, num_input_tensors=[2], same_size=True))).input
581
tensors1, tensors2 = list(zip(_cuda_tensors, _cpu_tensors))
583
foreach_op, foreach_op_ = op.method_variant, op.inplace_variant
584
native_op, native_op_ = op.ref, op.ref_inplace
586
actual = foreach_op(tensors1, tensors2)
587
except RuntimeError as e:
588
with self.assertRaisesRegex(type(e), re.escape(str(e))):
589
[native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
591
expected = [native_op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
592
self.assertEqual(expected, actual)
594
foreach_op_(tensors1, tensors2)
595
except RuntimeError as e:
596
with self.assertRaisesRegex(type(e), re.escape(str(e))):
597
[native_op_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
599
self.assertEqual(actual, tensors1)
602
@ops(foreach_pointwise_op_db, allowed_dtypes=floating_types())
603
def test_pointwise_op_tensors_on_different_devices(self, device, dtype, op):
604
# tensors1: ['cuda', 'cpu]
605
# tensors2: ['cuda', 'cpu]
606
# tensors3: ['cuda', 'cpu]
607
# first tensorlist is zero-size when float32
608
_cuda_tensors = list(
609
op.sample_inputs(device, dtype, num_input_tensors=[3], same_size=True)
610
)[int(dtype == torch.float32)].input
611
_cpu_tensors = next(iter(op.sample_inputs("cpu", dtype, num_input_tensors=[3], same_size=True))).input
612
tensors1, tensors2, tensors3 = list(zip(_cuda_tensors, _cpu_tensors))
614
foreach_op, foreach_op_, native_op = op.method_variant, op.inplace_variant, op.ref
615
actual = foreach_op(tensors1, tensors2, tensors3)
616
expected = [native_op(*_cuda_tensors), native_op(*_cpu_tensors)]
617
self.assertEqual(expected, actual)
619
# note(mkozuki): Limiting dtypes to FP32&FP64, we can safely run inplace ops.
620
foreach_op_(tensors1, tensors2, tensors3)
621
self.assertEqual(expected, tensors1)
623
# note: BFloat16 has the same number of exponent bits as FP32
624
# so if squared L2 norm overflows in BF16, then it also overflows in FP32.
626
@ops(foreach_reduce_op_db, allowed_dtypes=(torch.half, torch.bfloat16))
627
def test_foreach_l2_large_value_input(self, device, dtype, op):
629
max_value = torch.finfo(dtype).max
630
scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
632
t * scaler for t in next(iter(op.sample_inputs(device, dtype, requries_grad=True, num_input_tensors=[N], low=1))).input
634
# make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
635
self.assertTrue(scaler * scaler * N > max_value)
636
fn, ref_fn, *_ = self._get_funcs(op)
637
actual = fn(inputs, is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
638
expect = ref_fn(inputs, ord=ord)
640
if dtype == torch.float16:
641
# making sure the reference L2 norm values are in the range of FP16.
642
self.assertFalse(any(torch.isinf(e) for e in expect))
645
inputs[0][i].numel() == 0 or torch.isinf(e)
646
for i, e in enumerate(expect)))
647
self.assertEqual(expect, actual, equal_nan=False)
650
@ops(foreach_reduce_op_db, allowed_dtypes=floating_types())
651
def test_big_num_tensors(self, device, dtype, op):
653
tensorlist = [make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) for _ in range(N)]
654
fn, ref_fn, *_ = self._get_funcs(op)
657
for ord in (1, 2, math.inf):
658
actual = fn(inputs=[tensorlist], is_cuda=True, expect_fastpath=True, ord=ord, zero_size=False)
659
expect = ref_fn(inputs=[tensorlist], ord=ord)
661
self.assertEqual(expect, actual, equal_nan=True)
664
@ops(foreach_reduce_op_db)
665
def test_foreach_reduce_large_input(self, device, dtype, op):
666
# test inputs larger than kChunkSize = 65536
667
ord, N = 2, 65536 * 2
668
disable_fastpath = True
669
if ord in (1, 2) and dtype in floating_types_and(torch.half, torch.bfloat16):
670
disable_fastpath = False
671
inputs = ([make_tensor((N,), dtype=dtype, device=device, noncontiguous=False)],)
672
wrapped_op, ref, _, _ = self._get_funcs(op)
674
ref(inputs, ord=ord),
675
wrapped_op(inputs, self.is_cuda, not disable_fastpath, ord=ord, zero_size=False),
680
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
681
dtypes=(torch.float,),
683
def test_inplace_foreach_leaf_check_and_grad_fn(self, device, dtype, op):
684
inplace_op = op.inplace_variant
685
if inplace_op is None:
686
self.skipTest("no in-place op available")
688
sample = next(iter(op.sample_inputs(dtype=dtype, device=device, num_input_tensors=[2], same_size=True)))
689
sample.input[0].requires_grad_(True)
690
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
691
inplace_op(sample.input, *sample.args)
692
sample.input[1].requires_grad_(True)
693
with self.assertRaisesRegex(RuntimeError, "a leaf Variable that requires grad"):
694
inplace_op(sample.input, *sample.args)
696
_tensors = [t.clone().detach().requires_grad_(i == 0) for i, t in enumerate(sample.input)]
697
tensors = [t.clone() for t in _tensors]
698
inplace_op(tensors, *sample.args)
699
self.assertIsNotNone(tensors[0].grad_fn)
700
self.assertIsNone(tensors[1].grad_fn)
705
lambda op: op.supports_out,
706
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
708
dtypes=(torch.float,),
710
def test_outplace_with_invalid_grads(self, device, dtype, op):
711
func, *_ = self._get_funcs(op)
712
sample = next(iter(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True)))
713
self.assertTrue(all(t.requires_grad for t in sample.input))
714
(out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
715
out1.backward(torch.ones_like(out1))
716
self.assertIsNotNone(sample.input[0].grad)
717
self.assertIsNone(sample.input[1].grad)
721
lambda op: op.backward_requires_result,
722
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_other_op_db,
724
dtypes=(torch.float32,),
726
def test_lifetime_of_grad_fn_when_result_is_saved(self, device, dtype, op):
728
def get_ref(func, sample):
732
out = func((sample.input, *sample.args), is_cuda=False, expect_fastpath=False, **sample.kwargs)
734
meta_dict = out[0].grad_fn.metadata
736
ref = weakref.ref(foo)
739
def _test(func, sample):
740
out, ref = get_ref(func, sample)
741
self.assertIsNotNone(ref())
743
self.assertIsNone(ref())
745
func = self._get_funcs(op)[0]
746
for sample in op.sample_inputs(device, dtype, requires_grad=True, num_input_tensors=[1]):
747
for key in ("is_fastpath", "disable_fastpath"):
748
if key in sample.kwargs:
749
del sample.kwargs[key]
750
# note: `_foreach_pow.Scalar` and `_foreach_pow.ScalarList` don't depend on `result`
751
# see: https://github.com/pytorch/pytorch/blob/5403c777/tools/autograd/derivatives.yaml#L3048-L3049
752
if op.name == "_foreach_pow":
754
(isinstance(sample.args[0], list) and isinstance(sample.args[0][0], Number))
755
or (isinstance(sample.args[0], Number) and not isinstance(sample.args[0], float))
758
if isinstance(sample.args[0], float):
759
new_args = (sample.input,)
760
sample.input = sample.args[0]
761
sample.args = new_args
764
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
765
def test_tensors_grouping(self):
766
num_tensors_per_list = 10
767
num_devices = torch.cuda.device_count()
768
dtypes = (torch.float16, torch.float32, torch.float64)
772
device=torch.device("cuda", random.randint(0, num_devices - 1)),
773
dtype=dtypes[random.randint(0, 2)],
774
) for i in range(num_tensors_per_list)
776
list2 = [None for _ in list1]
777
list3 = [torch.rand_like(t) for t in list1]
778
nested_tensorlists = [list1, list2, list3]
779
grouped_tensors = torch.utils._foreach_utils._group_tensors_by_device_and_dtype(nested_tensorlists, with_indices=True)
781
for (device, dtype), ([l1, l2, l3], indices) in grouped_tensors.items():
782
for t in itertools.chain(l1, l3):
783
self.assertEqual(t.device, device)
784
self.assertEqual(t.dtype, dtype)
785
num_tensors_seen += 1
786
self.assertEqual(len(l1), len(l2))
787
self.assertTrue(all(p is None for p in l2))
788
for i, index in enumerate(indices):
789
self.assertEqual(l1[i], list1[index])
790
self.assertEqual(l2[i], list2[index])
791
self.assertEqual(l3[i], list3[index])
792
self.assertEqual(num_tensors_seen, 2 * num_tensors_per_list)
795
def test_0dim_tensor_overload_cpu_ok(self):
796
tensors = [torch.ones((), device="cuda", dtype=torch.float32) for _ in range(2)]
797
scalar_cpu_tensor = torch.tensor(4.0, device="cpu")
799
# For mul and div, the scalar is allowed to be on CPU too
800
actual = torch._foreach_mul(tensors, scalar_cpu_tensor)
801
self.assertEqual(actual, [t.mul(scalar_cpu_tensor) for t in tensors])
802
actual = torch._foreach_div(tensors, scalar_cpu_tensor)
803
self.assertEqual(actual, [t.div(scalar_cpu_tensor) for t in tensors])
807
def test_0dim_tensor_overload_exception(self):
808
# check exceptions of fast path
809
tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
810
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
811
torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
813
tensors = [make_tensor((2, 2), dtype=torch.float, device=d) for d in ("cpu", "cuda")]
814
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
815
torch._foreach_mul(tensors, torch.tensor([1.0, 1.0], device="cuda"))
816
with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be 0 dim but"):
817
torch._foreach_add(tensors, torch.tensor([1.0, 1.0], device="cuda"))
820
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
821
def test_foreach_copy_with_multi_device_inputs(self, device, dtype, op):
822
foreach_copy_ = op.inplace_variant
823
copy_ = op.ref_inplace
824
for non_blocking in (False, True):
825
for sample in op.sample_inputs(device, dtype, noncontiguous=False):
826
with torch.no_grad():
827
ref_input = [t.clone().detach() for t in sample.input]
828
foreach_copy_(sample.input, sample.args[0], non_blocking)
829
for t, s in zip(ref_input, sample.args[0]):
830
copy_(t, s, non_blocking)
831
self.assertEqual(sample.input, ref_input)
832
if torch.cuda.device_count() > 1:
833
device = torch.device("cuda", 1)
834
rhs_tensors = [t.to(device) for t in sample.args[0]]
835
foreach_copy_(sample.input, rhs_tensors, non_blocking)
836
for t, s in zip(ref_input, rhs_tensors):
837
copy_(t, s, non_blocking)
838
self.assertEqual(ref_input, sample.input)
840
# Test reverse-mode & forward-mode AD if supported.
843
foreach_unary_op_db + foreach_binary_op_db + foreach_pointwise_op_db + foreach_reduce_op_db + foreach_other_op_db,
844
dtypes=OpDTypes.supported,
845
allowed_dtypes=(torch.float64, torch.complex128),
847
@parametrize("inplace", (False, True), name_fn=lambda x: "inplace" if x else "outplace")
848
def test_autodiff(self, device, dtype, op, inplace):
849
if not (op.supports_autograd or op.supports_forward_ad):
850
self.skipTest("neither reverse mode nor forward mode supported")
851
if (not inplace) and not op.supports_out:
852
self.skipTest("out-of-place not implemented")
853
if inplace and op.has_no_in_place:
854
self.skipTest("in-place not implemented")
856
# note(crcrpar): without this, some unary functions fail, unlike inplace and/or complex.
857
if (not inplace) and dtype == torch.float64 and op.name in (
858
"_foreach_acos", "_foreach_asin", "_foreach_log10", "_foreach_log1p", "_foreach_log2",
859
"_foreach_log", "_foreach_pow", "_foreach_sqrt",
861
value_range = {"low": 0.5, "high": 1.0}
864
for sample in op.sample_inputs(
865
device, dtype, requires_grad=True, num_input_tensors=[5], **value_range,
867
# Skip `_foreach_pow.ScalarAndTensor(Scalar, Tensor[])`
868
if op.name == "_foreach_pow" and isinstance(sample.input, Number):
873
# Call `clone` to avoid inplace modifications likewise
874
# `torch.testing._internal.common_utils.TestGradients._get_safe_inplace`
875
def inplace_func(*tensorlist):
876
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
877
op.inplace_variant(tuple(t.clone() for t in tensorlist), *sample.args, **kwargs)
881
def outplace_func(*tensorlist):
882
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
883
return op.method_variant(tensorlist, *sample.args, **kwargs)
886
working_sample, err_msg_pattern = check_autodiff_sample(op, sample, dtype, inplace)
888
def call_gradcheck():
892
raise_exception=True,
893
check_forward_ad=op.supports_forward_ad,
894
check_batched_forward_grad=False,
895
check_backward_ad=op.supports_autograd,
896
check_batched_grad=False,
899
if not working_sample:
900
if not err_msg_pattern:
901
# lhs of float64 and rhs of complex.
903
with self.assertRaisesRegex(RuntimeError, re.escape(err_msg_pattern)):
908
# Test per-tensor `grad_fn` behavior.
909
if inplace and op.supports_inplace_autograd:
910
# per-tensor `grad_fn` check.
913
def get_grad_fn_hook(i):
915
def hook(grad_inputs, grad_outputs) -> None:
916
hook_buffer.append(i)
920
_inputs = [t.clone().detach().requires_grad_() for t in sample.input]
921
inputs = [t.clone() for t in _inputs]
922
kwargs = {"alpha": sample.kwargs["alpha"]} if "alpha" in sample.kwargs else {}
923
op.inplace_variant(inputs, *sample.args, **kwargs)
925
self.assertEqual(len({t.grad_fn for t in inputs}), len(inputs))
927
for i, t in enumerate(inputs):
928
t.grad_fn.register_hook(get_grad_fn_hook(i))
932
inputs=(_inputs[0],),
933
grad_outputs=(torch.rand_like(inputs[0]),),
936
self.assertEqual(hook_buffer, [0])
939
# tensors have different shapes.
940
sum_of_cloned_tensors = torch.cat([t.view(-1) for t in inputs]).sum()
941
grad_output = torch.rand_like(sum_of_cloned_tensors)
943
sum_of_cloned_tensors,
944
inputs=tuple(_inputs),
945
grad_outputs=(grad_output,),
948
self.assertEqual(hook_buffer, list(reversed(range(len(inputs)))))
951
# TODO(crcrpar): Hide this inside torch/testing/_internal.
952
# would end up adding another layer to `foreach_inputs_sample_func.__call__`
953
# so that we can use this function as something like the first argument of `filter` function.
954
# Even after moving this function to testing, I personally think it'd be better to check the error message.
955
def check_autodiff_sample(op, sample, dtype, is_inplace):
956
if op.name == "_foreach_abs" and is_inplace and dtype == torch.complex128:
957
return False, "In-place abs is not supported for complex tensors."
959
op.name == "_foreach_sub"
961
(isinstance(sample.args[0], list) and any(isinstance(a, bool) for a in sample.args[0]))
962
or isinstance(sample.args[0], bool)
965
return False, _BOOL_SUB_ERR_MSG
966
if op.name == "_foreach_norm" and (not is_inplace):
969
"Trying to set a forward gradient that has a different size than that of the original Tensor, "
970
"this is not supported. Tensor is of size [] while the given forward gradient is of size [1, 1]."
972
rhs_arg_has_complex_number = sample.args and ((
973
isinstance(sample.args[0], list)
974
and any(isinstance(a, complex) for a in sample.args[0])
976
isinstance(sample.args[0], complex)
978
if rhs_arg_has_complex_number and dtype == torch.float64:
979
if op.name in ("_foreach_clamp_max", "_foreach_clamp_min", "_foreach_maximum", "_foreach_minimum"):
980
return False, "clamp is not supported for complex types"
984
if op.name == "_foreach_pow":
985
return False, "Found dtype Double but expected ComplexDouble"
986
if op.name in ("_foreach_add", "_foreach_sub", "_foreach_mul", "_foreach_div"):
987
return False, "result type ComplexDouble can't be cast to the desired output type Double"
991
instantiate_device_type_tests(TestForeach, globals())
994
if __name__ == "__main__":