1
# Owner(s): ["module: nn"]
3
from itertools import chain, product
4
from inspect import signature, isgenerator
5
from copy import deepcopy
7
from operator import methodcaller
11
from torch._subclasses.meta_utils import assert_metadata_eq
12
from torch.testing._internal.common_cuda import with_tf32_off
13
from torch.testing._internal.common_device_type import (
14
instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
15
from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
16
from torch.testing._internal.common_utils import (
17
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
18
gradgradcheck, parametrize, wrapSwapTensorsTest)
19
from unittest.mock import patch, call
22
class TestModule(TestCase):
    # Flags read by the torch.testing._internal.common_utils.TestCase harness:
    # run each test with CUDA memory-leak checking and on a non-default stream.
    # (NOTE(review): semantics assumed from the flag names and the common_utils
    # import above — confirm against the harness.)
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
28
def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
    """Assert that every parameter and buffer of `module` is on `device`,
    and that every floating point parameter/buffer has dtype `dtype`.

    Only floating point dtypes are verified since that's what the dtype
    kwarg or methods such as `float()` applies to.
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)

    def _check_module(items, name, device=device, dtype=dtype):
        for item_name, item in items:
            # NOTE(review): the two assertEqual calls were dropped in the
            # mangled paste (only their message f-strings survived) —
            # restored here; confirm against upstream test_modules.py.
            self.assertEqual(
                item.device, device,
                f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
            if item.dtype.is_floating_point:
                self.assertEqual(
                    item.dtype, dtype,
                    f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}')

    _check_module(module.named_parameters(), "Parameter")
    _check_module(module.named_buffers(), "Buffer")
48
# NOTE(review): decorator restored from a numbering gap in the mangled paste.
@modules(module_db)
def test_forward(self, device, dtype, module_info, training):
    """Instantiate each module, run a forward pass, compare against the
    ModuleInfo reference_fn when provided, and verify that `float()` /
    `double()` conversion moves parameters/buffers to the expected dtype.
    """
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    dtype_to_method_caller = {
        torch.float32: methodcaller("float"),
        torch.float64: methodcaller("double"),
    }
    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            # NOTE(review): restored from numbering gap — confirm upstream.
            m.train(training)

            # === Do forward pass. ===
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            outputs = m(*args, **kwargs)

            # === Compare outputs to a reference if one is specified. ===
            # TODO: Handle precision
            reference_fn = module_input.reference_fn
            if reference_fn is not None:
                ref_outputs = reference_fn(m, *args, **kwargs)
                self.assertEqual(outputs, ref_outputs)

            # === Use the method call and verify the parameters and buffers ===
            if dtype in dtype_to_method_caller:
                dtype_to_method_caller[dtype](m)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)
84
# Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
# They should be applied to any created parameters and buffers.
# NOTE(review): decorator and several dropped statements restored from
# numbering gaps in the mangled paste — confirm against upstream.
@modules(module_db)
def test_factory_kwargs(self, device, dtype, module_info, training):
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    for module_input in module_inputs:
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

        # Check if this module creates parameters or registers buffers.
        # The mock magic here passes through to the real Parameter / register_buffer
        # logic and is only used to check call inputs.
        module_creates_params_or_buffers = False
        parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
        with patch.object(torch.nn.Parameter, '__new__', parameter_new):
            register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
            with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                m = module_cls(*args, **kwargs)
                m.train(training)

                # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                constructor_tensors = get_tensors_from(args, kwargs)
                for mock in [parameter_new.mock, register_buffer.mock]:
                    for call_args, call_kwargs in mock.call_args_list:
                        call_tensors = get_tensors_from(call_args, call_kwargs)
                        if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                            module_creates_params_or_buffers = True
                            break

        if not module_creates_params_or_buffers:
            # Nothing to verify for modules that create no state.
            continue

        # Instantiate module with the factory kwargs.
        kwargs.update({
            'device': device,
            'dtype': dtype,
        })

        if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
            # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
            uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
            with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                    m = module_cls(*args, **kwargs)
                    m.train(training)
                    uninit_param_new.mock.assert_has_calls(
                        [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                    uninit_buffer_new.mock.assert_has_calls(
                        [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
        else:
            # Check device placement and dtype for created parameters and buffers.
            # Only verify floating point dtypes since that's what the kwarg applies to.
            m = module_cls(*args, **kwargs)
            m.train(training)
            self._assert_module_parameters_and_buffer_are(m, device, dtype)
144
# NOTE(review): decorators and device-transfer calls restored from numbering
# gaps in the mangled paste — confirm against upstream test_modules.py.
@onlyCUDA
@modules(module_db)
def test_multiple_device_transfer(self, device, dtype, module_info, training):
    """Forward on GPU, move to CPU and forward, move back to GPU and forward,
    and (with >= 2 GPUs) transfer across devices — checking parameter/buffer
    placement after every move."""
    module_cls = module_info.module_cls
    module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                          requires_grad=False, training=training)
    module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                       requires_grad=False, training=training)
    for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
        if module_input_device.forward_input is None:
            continue

        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # === Do forward pass on GPU ===
            input_device_args = module_input_device.forward_input.args
            input_device_kwargs = module_input_device.forward_input.kwargs
            m(*input_device_args, **input_device_kwargs)
            self._assert_module_parameters_and_buffer_are(m, device, dtype)

            # === Move to CPU ===
            input_cpu_args = module_input_cpu.forward_input.args
            input_cpu_kwargs = module_input_cpu.forward_input.kwargs
            m.cpu()
            m(*input_cpu_args, **input_cpu_kwargs)
            self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

            # === Move back to GPU and forward pass ===
            m.to(device)
            m(*input_device_args, **input_device_kwargs)
            self._assert_module_parameters_and_buffer_are(m, device, dtype)

            if torch.cuda.device_count() >= 2:
                # === test cross-GPU transfer works
                def _to_device1(objs):
                    if isinstance(objs, (tuple, list)):
                        return type(objs)(_to_device1(item) for item in objs)
                    elif isinstance(objs, dict):
                        return {name: _to_device1(item) for name, item in objs.items()}
                    elif isinstance(objs, torch.Tensor):
                        return objs.cuda(1)
                    else:
                        return objs

                input_device_1_args = _to_device1(input_device_args)
                input_device_1_kwargs = _to_device1(input_device_kwargs)

                m.cuda(1)
                with torch.cuda.device(1):
                    m(*input_device_1_args, **input_device_1_kwargs)
                self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)
199
# NOTE(review): decorator and the repr/str calls restored from numbering gaps.
@modules(module_db)
def test_repr(self, device, dtype, module_info, training):
    # Test module can be represented with repr and str without errors.
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    for module_input in module_inputs:
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m = module_cls(*args, **kwargs)
        m.to(device).to(dtype)
        m.train(training)

        # Check that these methods do not raise errors
        m.__repr__()
        str(m)
215
# NOTE(review): decorator and save/seek lines restored from numbering gaps.
@modules(module_db)
def test_pickle(self, device, dtype, module_info, training):
    # Test that module can be pickled and unpickled.
    # Local import: `tempfile` is not among the (visible) top-of-file imports.
    import tempfile

    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # === Do forward pass. ===
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            output = m(*args, **kwargs)

            # === Check unpickled module gives the same output. ===
            with tempfile.TemporaryFile() as f:
                torch.save(m, f)
                f.seek(0)
                m_copy = torch.load(f)
                output_from_copy = m_copy(*args, **kwargs)
                self.assertEqual(output, output_from_copy)
246
@modules([module_info for module_info in module_db
          if 'inplace' in signature(module_info.module_cls).parameters])
def test_check_inplace(self, device, dtype, module_info, training):
    # Check if the inplace variant of the module gives the same result as the
    # out-of-place variant, for outputs and for gradients.
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=True, training=training)
    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        # === Instantiate the module. ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m_op = module_cls(*args, **kwargs, inplace=False)
        m_op.to(device).to(dtype)
        m_op.train(training)
        m_inplace = module_cls(*args, **kwargs, inplace=True)
        m_inplace.to(device).to(dtype)
        m_inplace.train(training)

        # === Inplace modules only supports inplace operations on the first argument ===
        input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

        # === Do not allow the first input to be in input_kwargs ===
        # Inspect `forward` itself: inspecting the module instance picks up the
        # generic Module.__call__(*args, **kwargs) signature, which would make
        # the first-parameter check below meaningless.
        forward_sig = signature(m_op.forward).parameters
        self.assertGreaterEqual(len(forward_sig), 1)
        # Take the parameter *name*; the garbled original took the
        # (name, Parameter) pair, which can never occur in input_kwargs.
        first_param_name = next(iter(forward_sig))
        self.assertNotIn(first_param_name, input_kwargs)

        # === Out of place operation does not write to original tensor ===
        self.assertGreaterEqual(len(input_args), 1)
        input_version = input_args[0]._version
        with freeze_rng_state():
            output_op = m_op(*input_args, **input_kwargs)
        self.assertEqual(input_args[0]._version, input_version)

        # === Check that the inplace operation gives the same result ===
        input_arg_copy = deepcopy(input_args)
        input_arg_clone = tuple(i.clone() for i in input_arg_copy)
        input_clone_version = input_arg_clone[0]._version
        with freeze_rng_state():
            output_ip = m_inplace(*input_arg_clone, **input_kwargs)
        # The in-place variant must actually mutate its first input ...
        self.assertGreater(input_arg_clone[0]._version, input_clone_version)
        # ... while producing the same values as the out-of-place variant.
        self.assertEqual(output_op, output_ip)

        # === Check that the gradients are the same ===
        grad = output_op.data.clone().normal_()
        output_op.backward(grad)
        output_ip.backward(grad)
        self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
298
def _traverse_obj(self, obj, func):
    """Recursively apply `func` to every tensor leaf inside a nested
    structure of tuples/lists/generators/dicts, preserving the container
    structure. Non-tensor leaves map to None (implicit return)."""
    if isinstance(obj, (tuple, list)):
        return type(obj)(self._traverse_obj(o, func) for o in obj)
    elif isgenerator(obj):
        return tuple(self._traverse_obj(o, func) for o in obj)
    elif isinstance(obj, dict):
        return {name: self._traverse_obj(o, func) for name, o in obj.items()}
    elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
        # NOTE(review): leaf callback restored from a numbering gap.
        return func(obj)
310
def _retain_grad(self, obj):
    # gradients needs to be retained to check for grad. This is useful when
    # non-leafs are present in the graph.
    def inner_retain_grad(obj):
        if obj.requires_grad:
            # NOTE(review): restored from a numbering gap.
            obj.retain_grad()
    self._traverse_obj(obj, inner_retain_grad)
318
def _get_grads(self, obj):
    """Return the structure of `obj` with each tensor replaced by its .grad
    (None for tensors that don't require grad)."""
    def inner_get_grad(obj):
        if obj.requires_grad:
            # NOTE(review): restored from a numbering gap.
            return obj.grad
    return self._traverse_obj(obj, inner_get_grad)
324
def _zero_grad(self, obj):
    """Zero the .grad of every tensor in `obj` that already has one."""
    def inner_zero_grad(obj):
        if obj.grad is not None:
            # NOTE(review): restored from a numbering gap.
            obj.grad.zero_()
    self._traverse_obj(obj, inner_zero_grad)
331
# NOTE(review): decorator, helper returns, `continue`s and `else:` branches
# restored from numbering gaps in the mangled paste — confirm upstream.
@modules(module_db)
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
    # Check modules work with non-contiguous tensors
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=True, training=training)

    def _make_non_contiguous(obj):
        def inner_make_non_contiguous(obj):
            # Scalar tensors can not be made non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return obj

            # Duplicate every element along the last dim, then stride over the
            # copy: same values as `obj`, but non-contiguous storage.
            out = torch.repeat_interleave(obj, 2, dim=-1)
            out = out[..., ::2].detach()
            out.requires_grad = obj.requires_grad
            return out
        return self._traverse_obj(obj, inner_make_non_contiguous)

    def _can_be_noncontiguous(obj):
        if isinstance(obj, (tuple, list)):
            return any(_can_be_noncontiguous(o) for o in obj)
        elif isinstance(obj, dict):
            return any(_can_be_noncontiguous(o) for o in obj.values())
        # scalar tensors can not be non-contiguous
        if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
            return False
        return True

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
        if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
            continue

        # === Instantiate the module. ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m = module_cls(*args, **kwargs)
        m.to(device).to(dtype)
        m.train(training)

        self._retain_grad((input_args, input_kwargs))

        # === Forward with default input
        with freeze_rng_state():
            default_output = m(*input_args, **input_kwargs)
            if isinstance(default_output, torch.Tensor):
                grad_output = default_output.clone().detach_().normal_()
                default_output.backward(grad_output, retain_graph=True)
            else:
                grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                    for o in default_output)
                flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                for o, g_o in zip(flattened_default_output, flattened_grad_output):
                    if (o.requires_grad):
                        o.backward(g_o, retain_graph=True)

        default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
        default_param_grad = deepcopy([p.grad for p in m.parameters()])

        # === Construct non-contiguous tensors ===
        nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
        nc_grad_output = _make_non_contiguous(grad_output)

        # === Compare results with non-contiguous and contiguous tensors ===
        inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
        grads = [grad_output, nc_grad_output]

        for (in_args, in_kwargs), g_out in product(inputs, grads):
            g_out_copy = deepcopy(g_out)
            self._zero_grad((in_args, in_kwargs))
            self._zero_grad(m.parameters())

            with freeze_rng_state():
                out = m(*in_args, **in_kwargs)
                if isinstance(out, torch.Tensor):
                    out.backward(g_out_copy, retain_graph=True)
                else:
                    flattened_out = torch.utils._pytree.tree_leaves(out)
                    flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                    for o, g_o in zip(flattened_out, flattened_g_out_copy):
                        if o.requires_grad:
                            o.backward(g_o, retain_graph=True)

            input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
            self.assertEqual(out, default_output)
            self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
            self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

            param_grad = [p.grad for p in m.parameters()]
            self.assertEqual(param_grad, default_param_grad)
426
def _test_gradients_helper(self, device, dtype, module_info, training, check):
    """Shared driver for test_grad / test_gradgrad.

    Runs `check` (gradcheck or gradgradcheck) over the flattened set of
    input args, module parameters, and tensor-valued kwargs — first the
    total derivative, then partial derivatives with all other tensors'
    requires_grad temporarily disabled.
    NOTE(review): several dropped lines (continue, kwarg_tensors/other_kwargs
    init, loop headers) restored from numbering gaps — confirm upstream.
    """
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=True, training=training)
    # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
    gradcheck_nondet_tol = 0.0
    if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
        gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        # === Instantiate the module. ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m = module_cls(*args, **kwargs)
        m.to(device).to(dtype)
        m.train(training)

        params = tuple(m.parameters())

        # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
        input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
        if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
            with torch.no_grad():
                m(*input_args, **input_kwargs)

        # === Perform gradient check on the input_args ===
        # Split kwargs into tensor kwargs (differentiated) and the rest.
        other_kwargs = {}
        kwarg_tensors = []
        for name, obj in input_kwargs.items():
            if isinstance(obj, torch.Tensor):
                kwarg_tensors.append((name, obj))
            else:
                other_kwargs[name] = obj

        def fn_to_gradcheck(*flat_input_and_params):
            input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
            new_input_args = input_and_params[:len(input_args)]
            kwarg_args = input_and_params[-len(kwarg_tensors):]
            new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

            with freeze_rng_state():
                output = m(*new_input_args, **new_kwargs, **other_kwargs)
                output_flattened = torch.utils._pytree.tree_leaves(output)
                return output_flattened

        # check total derivative
        grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
        flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)

        self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

        # check partial derivatives
        old_params_requires_grad = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad = False

        old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
        for (_, obj) in kwarg_tensors:
            obj.requires_grad = False

        # Re-enable one parameter at a time and re-run the check.
        for p, old in zip(params, old_params_requires_grad):
            p.requires_grad = old
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
            p.requires_grad = False

        # Re-enable one tensor kwarg at a time and re-run the check.
        for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
            obj.requires_grad = old
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
            obj.requires_grad = False
503
@modules(module_db, allowed_dtypes=[torch.double])
def test_grad(self, device, dtype, module_info, training):
    """Run gradcheck over every module (double precision for stability)."""
    self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
507
@modules([m for m in module_db if m.supports_gradgrad],
         allowed_dtypes=[torch.double])
def test_gradgrad(self, device, dtype, module_info, training):
    """Run gradgradcheck over modules that support double backward."""
    self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
513
# NOTE(review): @onlyCUDA / @modules decorators, the skip body, the
# `_to_device` header and its fallback branch were restored from numbering
# gaps in the mangled paste — confirm against upstream test_modules.py.
@onlyCUDA
@with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
@toleranceOverride({torch.float32: tol(5e-2, 0),
                    torch.float64: tol(4e-4, 0)})
@modules(module_db)
def test_cpu_gpu_parity(self, device, dtype, module_info, training):
    # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
    # nicer way for eval mode only.
    # See https://github.com/pytorch/pytorch/issues/79161
    rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
    if (module_info.module_cls in rnn_modules
            and not training
            and torch.backends.cudnn.enabled):
        self.skipTest("RNN modules don't support backward in eval mode with cuDNN")

    # Test cpu and gpu results are the same
    module_cls = module_info.module_cls
    module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                       requires_grad=True, training=training)

    def _to_device(obj):
        if isinstance(obj, torch.Tensor):
            res = obj.detach().to(device=device)
            res.requires_grad = obj.requires_grad
            return res
        elif isinstance(obj, tuple):
            return tuple(_to_device(o) for o in obj)
        elif isinstance(obj, dict):
            return {key: _to_device(o) for key, o in obj.items()}
        else:
            return deepcopy(obj)

    for module_input in module_inputs_cpu:
        # === Move input from cpu to device ===
        cpu_forward_args = module_input.forward_input.args
        cpu_forward_kwargs = module_input.forward_input.kwargs

        gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

        self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

        # === Construct module on cpu and gpu ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

        cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
        cpu_module.train(training)
        gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
        gpu_module.train(training)

        # === Lazy modules need to see an input to initialize params ===
        if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
            with torch.no_grad():
                cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
                gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

        # Sync the GPU module's parameters to the CPU module's values.
        for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
            gpu_p.data.copy_(cpu_p)

        # === Compare forward output between cpu and gpu ===
        cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
        gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

        self.assertEqual(cpu_outputs, gpu_outputs)

        # === Run backwards on CPU and GPU and compare results ===
        def check_backward(cpu_output, gpu_output):
            cpu_grad_output = cpu_output.clone().normal_()
            gpu_grad_output = cpu_grad_output.type_as(gpu_output)

            cpu_output.backward(cpu_grad_output, retain_graph=True)
            gpu_output.backward(gpu_grad_output, retain_graph=True)

            cpu_grad_input = self._get_grads(cpu_forward_args)
            gpu_grad_input = self._get_grads(gpu_forward_args)
            self.assertEqual(cpu_grad_input, gpu_grad_input)

            for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                self.assertEqual(cpu_p.grad, gpu_p.grad)

            cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
            gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
            self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

        if isinstance(cpu_outputs, torch.Tensor):
            check_backward(cpu_outputs, gpu_outputs)
        else:
            flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
            flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
            for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                if cpu_output.requires_grad:
                    check_backward(cpu_output, gpu_output)
608
# NOTE(review): decorator, branch keywords (`if channels_last:` / `else:`),
# `d = ...dim()` lines, `_req_grad` header, tuple-comprehension bodies and the
# positional args of torch.autograd.grad were restored from numbering gaps in
# the mangled paste — confirm against upstream test_modules.py.
@modules(module_db)
def test_memory_format(self, device, dtype, module_info, training):
    is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                 or torch.cuda.get_device_capability(0) == (8, 0))
    # TODO tighten it to a specific module
    atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=True, training=training)
    module_memformat_affects_out = module_info.module_memformat_affects_out

    # Return (input mem formats, module mem formats) to sweep over.
    def _get_mem_formats(channels_last=False, channels_last_3d=False):
        if channels_last:
            return ([torch.contiguous_format, torch.channels_last],
                    [torch.preserve_format, torch.contiguous_format, torch.channels_last])
        elif channels_last_3d:
            return ([torch.contiguous_format, torch.channels_last_3d],
                    [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
        else:
            return ([torch.contiguous_format],
                    [torch.preserve_format, torch.contiguous_format])

    # Check that at least one Tensor input has dim == n
    def _check_dims(obj, n):
        if isinstance(obj, torch.Tensor):
            return obj.dim() == n
        elif isinstance(obj, (tuple, list)):
            return any(_check_dims(o, n) for o in obj)
        else:
            return False

    # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
    def _to_mem_format(mem_format, obj):
        def inner_to_mem_format(obj):
            d = obj.dim()
            if ((mem_format == torch.channels_last and d != 4)
                    or (mem_format == torch.channels_last_3d and d != 5)):
                return obj.clone().detach().requires_grad_(obj.requires_grad)
            return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

        return self._traverse_obj(obj, inner_to_mem_format)

    def _check_out_mem_format(output, input_mem_format, module_mem_format):
        def inner_check_out_mem_format(output):
            d = output.dim()
            if (d == 4 and ((input_mem_format == torch.channels_last)
                            or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))
            elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                              or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d))
            else:
                self.assertTrue(output.is_contiguous())
        return self._traverse_obj(output, inner_check_out_mem_format)

    def _req_grad(t):
        return isinstance(t, torch.Tensor) and t.requires_grad

    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        supports_channels_last = _check_dims(module_input.forward_input.args, 4)
        supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
        input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

        with freeze_rng_state():
            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # === Get output in (contiguous, contiguous) configuration. ===
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            desired_outputs = m(*args, **kwargs)

            # === Do backward pass. ===
            ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
            if training and len(ref_diff_outputs) > 0:
                params = tuple(p for p in m.parameters())
                ref_diff_inputs = tuple(
                    t
                    for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                    if _req_grad(t)
                )
                # NOTE(review): random grad_outputs presumed (torch.rand_like)
                # — the exact expression was dropped; confirm upstream.
                ref_grad_outputs = tuple(
                    torch.rand_like(t)
                    for t in ref_diff_outputs
                )
                ref_grad_inputs = torch.autograd.grad(
                    ref_diff_outputs,
                    ref_diff_inputs,
                    grad_outputs=ref_grad_outputs,
                    retain_graph=True,
                )

            for input_mem_format in input_mem_formats:
                # === Change memformat of input. ===
                d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                # See https://github.com/pytorch/pytorch/issues/107861
                # When inductor tests are turned on, the setting of requires_grad will be lost
                for t1, t2 in zip(
                    torch.utils._pytree.tree_leaves(d_args),
                    torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                ):
                    t1.requires_grad_(t2.requires_grad)
                for t1, t2 in zip(
                    torch.utils._pytree.tree_leaves(d_kwargs),
                    torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                ):
                    t1.requires_grad_(t2.requires_grad)

                module_input.forward_input.args = d_args
                module_input.forward_input.kwargs = d_kwargs

                for module_mem_format in module_mem_formats:
                    # === Change memformat of module ===
                    m.to(memory_format=module_mem_format)

                    # === Do forward pass. ===
                    args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                    outputs = m(*args, **kwargs)

                    # === Compare outputs to (contiguous, contiguous) output. ===
                    if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                        self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                    # === Check mem format of output. ===
                    _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                    # === Do backward pass. ===
                    diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                    if training and len(diff_outputs) > 0:
                        params = tuple(p for p in m.parameters())
                        diff_inputs = tuple(
                            t
                            for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                            if _req_grad(t)
                        )
                        # Reuse the reference grad_outputs so grads are comparable.
                        grad_outputs = tuple(
                            torch.empty_like(t1).copy_(t2)
                            for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                        )
                        grad_inputs = torch.autograd.grad(
                            diff_outputs,
                            diff_inputs,
                            grad_outputs=grad_outputs,
                            retain_graph=True,
                        )
                        if (
                            input_mem_format != torch.contiguous_format
                            or module_mem_format != torch.contiguous_format
                        ):
                            self.assertEqual(
                                grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                            )

                        # === Check mem format of grad_inputs. ===
                        _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)
770
# Test whether train and eval modes differ for each module. Use to verify
# that the ModuleInfo entry flag is correct.
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)

    # Run forward inputs through to see if the training flag is accessed during forward.
    for module_input in module_inputs:
        if module_input.forward_input is None:
            continue

        # === Instantiate the module. ===
        args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        m = module_cls(*args, **kwargs)
        m.to(device).to(dtype)
        m.train(training)

        # Remove training attribute and see if forward still works.
        delattr(m, 'training')

        # === Do forward pass. ===
        # NOTE(review): try/except scaffolding restored from numbering gaps.
        try:
            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            m(*args, **kwargs)
        except AttributeError as e:
            if "'training'" in str(e):
                self.assertTrue(module_info.train_and_eval_differ,
                                f"The ModuleInfo entry for {module_info.name} has "
                                "train_and_eval_differ=False, but the training mode was found to "
                                "affect the forward pass. Consider setting train_and_eval_differ=True "
                                "for this ModuleInfo entry.")
            else:
                raise e
809
# NOTE(review): decorators and the lazy-parameter `continue` restored from
# numbering gaps in the mangled paste — confirm against upstream.
@skipMeta
@modules(module_db)
def test_device_ctx_init(self, device, dtype, module_info, training):
    """Constructing a module under `torch.device('meta')` must produce meta
    parameters/buffers whose metadata matches a CPU-constructed module."""
    module_cls = module_info.module_cls
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    with torch.device('meta'):
        module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                            requires_grad=False, training=training)

    for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
        c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

        c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
        fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs

        m_cpu = module_cls(*c_args, **c_kwargs)

        with torch.device('meta'):
            m = module_cls(*c_args_meta, **c_kwargs_meta)

        for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
                                     zip(m.buffers(), m_cpu.buffers())):
            if torch.nn.parameter.is_lazy(p_meta):
                # Uninitialized parameters carry no metadata to compare.
                continue
            self.assertTrue(p_meta.is_meta)
            assert_metadata_eq(self.assertEqual, p_meta, p_cpu)
837
@modules([module for module in module_db if module.module_error_inputs_func is not None])
def test_errors(self, device, dtype, module_info, training):
    """Verify that modules raise the documented error (type + message regex)
    either at construction time or at forward time, per their ModuleInfo
    error-input entries."""
    module_cls = module_info.module_cls
    error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
    for error_input in error_inputs:
        module_input = error_input.module_error_input
        c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                m = module_cls(*c_args, **c_kwargs)
        elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
            m = module_cls(*c_args, **c_kwargs)
            fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                m(*fw_args, **fw_kwargs)
        else:
            # Any new ModuleErrorEnum value must be handled explicitly above.
            raise NotImplementedError(f"Unknown error type {error_input.error_on}")
856
@modules([module for module in module_db if not module.is_lazy])
@parametrize('swap', [True, False])
@parametrize('set_grad', [True, False])
@wrapSwapTensorsTest()
def test_to(self, device, dtype, module_info, training, swap, set_grad):
    """Test nn.Module.to() under both conversion mechanisms.

    With swap_module_params_on_conversion(True), .to() swaps the cdata of
    each THPVariable in place (same Python id, different ._cdata); with it
    off, .to() uses .data setting (id and ._cdata both unchanged). Checked
    for parameters and, when set_grad, for their .grad tensors too.
    """
    module_cls = module_info.module_cls
    devices = ['cpu']
    if torch.cuda.is_available():
        devices += ['cuda']
    dtypes = module_info.dtypes
    module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                   requires_grad=False, training=training)
    torch.__future__.set_swap_module_params_on_conversion(swap)

    for module_input in module_inputs:
        c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

        m = module_cls(*c_args, **c_kwargs)

        # Avoid using `module.to()` when constructing module since that is the method we are testing
        def _to(m, set_grad=False):
            for c in m.children():
                _to(c, set_grad=set_grad)
            for n, p in m.named_parameters(recurse=False):
                new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
                setattr(m, n, new_p)
                if set_grad:
                    new_p.grad = torch.randn_like(new_p)
            for n, b in m.named_buffers(recurse=False):
                new_b = b.detach().clone().to(device, dtype)
                setattr(m, n, new_b)
        _to(m, set_grad=set_grad)

        prev_device, prev_dtype = device, dtype
        for device_, dtype_ in product(devices, dtypes):
            # if device/dtype do not change, grad.to(device, dtype) is a no-op so
            # swapping will not change ._cdata
            # parameters will be wrapped in an nn.Parameter before swapping
            # which will cause the ._cdata to change
            g_no_swap = device_ == prev_device and dtype_ == prev_dtype
            prev_device, prev_dtype = device_, dtype_

            p_ids_before = [id(p) for p in m.parameters()]
            p_cdatas_before = [p._cdata for p in m.parameters()]
            if set_grad:
                g_ids_before = [id(p.grad) for p in m.parameters()]
                g_cdatas_before = [p.grad._cdata for p in m.parameters()]

            m.to(device=device_, dtype=dtype_)

            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
            self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
            self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
            p_ids_after = [id(p) for p in m.parameters()]
            p_cdatas_after = [p._cdata for p in m.parameters()]

            if set_grad:
                self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
                self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
                g_ids_after = [id(p.grad) for p in m.parameters()]
                g_cdatas_after = [p.grad._cdata for p in m.parameters()]

            if swap:
                # id same, ._cdata differs --> swapped cdata of THPVariable
                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                if set_grad:
                    self.assertTrue(
                        all(a == b if g_no_swap else a != b
                            for a, b in zip(g_cdatas_before, g_cdatas_after)))
            else:
                # id and _cdata remain the same --> .data setting
                self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                if set_grad:
                    self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                    self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
934
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()
def test_to_empty(self, device, dtype, module_info, swap, training):
    """Test nn.Module.to_empty() from a meta-constructed module.

    The two assertion branches are mutually exclusive: with swapping on,
    ids are preserved and only ._cdata changes; with swapping off, meta and
    the target device have different shallow-copy types, so a brand-new
    parameter is created and assigned (both id and ._cdata change).
    """
    module_cls = module_info.module_cls
    with torch.device("meta"):
        module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                       requires_grad=False, training=training)
    torch.__future__.set_swap_module_params_on_conversion(swap)
    device_ = torch.device(device)

    for module_input in module_inputs:
        c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

        with torch.device("meta"):
            m = module_cls(*c_args, **c_kwargs)

        p_ids_before = [id(p) for p in m.parameters()]
        p_cdatas_before = [p._cdata for p in m.parameters()]
        m.to_empty(device=device_)

        self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
        self.assertTrue(all(p.device == device_ for p in m.parameters()))
        self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
        p_ids_after = [id(p) for p in m.parameters()]
        p_cdatas_after = [p._cdata for p in m.parameters()]

        if swap:
            # id same, ._cdata differs --> swapped cdata of THPVariable
            self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
            self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
        else:
            # id and ._cdata differ
            # meta and device have different shallow copy types, so this will create a new
            # parameter and assign it to the module
            self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
            self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
975
# Generate per-device variants of every TestModule test (CPU/CUDA/..., and
# MPS since allow_mps=True) into this module's namespace.
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
977
if __name__ == '__main__':