# Owner(s): ["module: nn"]

from itertools import chain, product
from inspect import signature, isgenerator
from copy import deepcopy
import tempfile
from operator import methodcaller

import torch

from torch._subclasses.meta_utils import assert_metadata_eq
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
from torch.testing._internal.common_utils import (
    TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
    gradgradcheck, parametrize, wrapSwapTensorsTest)
from unittest.mock import patch, call


class TestModule(TestCase):
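    """Device/dtype-parametrized tests that run every entry in ``module_db``
    through a set of generic nn.Module checks.

    A typical local invocation (assuming a standard PyTorch checkout) looks
    something like::

        python test/test_modules.py -k test_forward
    """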
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True
    precision = 1e-5
    rel_tol = 1e-5

    def _assert_module_parameters_and_buffer_are(self, module, device, dtype):
        # Check device placement and dtype for created parameters and buffers.
        # Only verify floating point dtypes since that's what the kwarg or methods
        # such as `float()` apply to.
        if not isinstance(device, torch.device):
            device = torch.device(device)

        def _check_module(items, name, device=device, dtype=dtype):
            for item_name, item in items:
                self.assertEqual(
                    item.device, device,
                    f'{name} {item_name} is on device {item.device} instead of the expected device {device}')
                if item.dtype.is_floating_point:
                    self.assertEqual(
                        item.dtype, dtype,
                        f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}')
        _check_module(module.named_parameters(), "Parameter")
        _check_module(module.named_buffers(), "Buffer")

    @modules(module_db)
    def test_forward(self, device, dtype, module_info, training):
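        # Check that forward runs without error, that outputs match the optional
        # reference_fn, and that float()/double() conversion keeps parameters and
        # buffers on the expected device and dtype.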
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        dtype_to_method_caller = {
            torch.float32: methodcaller("float"),
            torch.float64: methodcaller("double"),
        }
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                outputs = m(*args, **kwargs)

                # === Compare outputs to a reference if one is specified. ===
                # TODO: Handle precision
                reference_fn = module_input.reference_fn
                if reference_fn is not None:
                    ref_outputs = reference_fn(m, *args, **kwargs)
                    self.assertEqual(outputs, ref_outputs)

                # === Use the method call and verify the parameters and buffers ===
                if dtype in dtype_to_method_caller:
                    dtype_to_method_caller[dtype](m)
                    m(*args, **kwargs)
                    self._assert_module_parameters_and_buffer_are(m, device, dtype)

    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                    m = module_cls(*args, **kwargs)
                    m.train(training)

                    # Check if a parameter or buffer was created with a tensor not passed to the constructor.
                    constructor_tensors = get_tensors_from(args, kwargs)
                    for mock in [parameter_new.mock, register_buffer.mock]:
                        for call_args, call_kwargs in mock.call_args_list:
                            call_tensors = get_tensors_from(call_args, call_kwargs)
                            if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                                module_creates_params_or_buffers = True
                                break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        m.train(training)
                        uninit_param_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                m.train(training)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

    @onlyCUDA
    @modules(module_db)
    def test_multiple_device_transfer(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                              requires_grad=False, training=training)
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=False, training=training)
        for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
            if module_input_device.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass on GPU ===
                input_device_args = module_input_device.forward_input.args
                input_device_kwargs = module_input_device.forward_input.kwargs
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                # === Move to CPU ===
                input_cpu_args = module_input_cpu.forward_input.args
                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
                m.cpu()
                m(*input_cpu_args, **input_cpu_kwargs)
                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

                # === Move back to GPU and forward pass ===
                m.cuda()
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                if torch.cuda.device_count() >= 2:
                    # === Test that cross-GPU transfer works ===
                    def _to_device1(objs):
                        if isinstance(objs, (tuple, list)):
                            return type(objs)(_to_device1(item) for item in objs)
                        elif isinstance(objs, dict):
                            return {name: _to_device1(item) for name, item in objs.items()}
                        elif isinstance(objs, torch.Tensor):
                            return objs.cuda(1)
                        else:
                            return objs
                    input_device_1_args = _to_device1(input_device_args)
                    input_device_1_kwargs = _to_device1(input_device_kwargs)

                    m.cuda(1)
                    with torch.cuda.device(1):
                        m(*input_device_1_args, **input_device_1_kwargs)
                    self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)

    @modules(module_db)
    def test_repr(self, device, dtype, module_info, training):
        # Test that the module can be represented with repr and str without errors.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Check that these methods do not raise errors
            m.__repr__()
            str(m)

    @modules(module_db)
    def test_pickle(self, device, dtype, module_info, training):
        # Test that the module can be pickled and unpickled.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                output = m(*args, **kwargs)

                # === Check unpickled module gives the same output. ===
                with tempfile.TemporaryFile() as f:
                    torch.save(m, f)
                    f.seek(0)
                    m_copy = torch.load(f)
                    output_from_copy = m_copy(*args, **kwargs)
                    self.assertEqual(output, output_from_copy)

    @skipMeta
    @modules([module_info for module_info in module_db
              if 'inplace' in signature(module_info.module_cls).parameters])
    def test_check_inplace(self, device, dtype, module_info, training):
        # Check if the inplace variant of the module gives the same result as the out of place
        # variant.
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m_op = module_cls(*args, **kwargs, inplace=False)
            m_op.to(device).to(dtype)
            m_op.train(training)
            m_inplace = module_cls(*args, **kwargs, inplace=True)
            m_inplace.to(device).to(dtype)
            m_inplace.train(training)

            # === Inplace modules only support inplace operations on the first argument ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            # === Do not allow the first input to be in input_kwargs ===
            forward_sig = signature(m_op).parameters
            self.assertGreaterEqual(len(forward_sig), 1)
            first_param_name, _ = next(iter(forward_sig.items()))
            self.assertNotIn(first_param_name, input_kwargs)

            # === Out of place operation does not write to original tensor ===
            self.assertGreaterEqual(len(input_args), 1)
            input_version = input_args[0]._version
            with freeze_rng_state():
                output_op = m_op(*input_args, **input_kwargs)
            self.assertEqual(input_args[0]._version, input_version)

            # === Check that the inplace operation gives the same result ===
            input_arg_copy = deepcopy(input_args)
            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
            input_clone_version = input_arg_clone[0]._version
            with freeze_rng_state():
                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
            self.assertGreater(input_arg_clone[0]._version, input_clone_version)
            self.assertEqual(output_op, output_ip)

            # === Check that the gradients are the same ===
            grad = output_op.data.clone().normal_()
            output_op.backward(grad)
            output_ip.backward(grad)
            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)

    def _traverse_obj(self, obj, func):
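        # Recursively walk nested tuples / lists / generators / dicts and apply `func`
        # to every Tensor or Parameter leaf, preserving the container structure.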
        if isinstance(obj, (tuple, list)):
            return type(obj)(self._traverse_obj(o, func) for o in obj)
        elif isgenerator(obj):
            return tuple(self._traverse_obj(o, func) for o in obj)
        elif isinstance(obj, dict):
            return {name: self._traverse_obj(o, func) for name, o in obj.items()}
        elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)):
            return func(obj)
        else:
            return obj

    def _retain_grad(self, obj):
        # Gradients need to be retained to check for grad. This is useful when
        # non-leaf tensors are present in the graph.
        def inner_retain_grad(obj):
            if obj.requires_grad:
                obj.retain_grad()
        self._traverse_obj(obj, inner_retain_grad)

    def _get_grads(self, obj):
        def inner_get_grad(obj):
            if obj.requires_grad:
                return obj.grad
        return self._traverse_obj(obj, inner_get_grad)

    def _zero_grad(self, obj):
        def inner_zero_grad(obj):
            if obj.grad is not None:
                obj.grad = None
        self._traverse_obj(obj, inner_zero_grad)

    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
        # Check that modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)

        def _make_non_contiguous(obj):
            def inner_make_non_contiguous(obj):
                # Scalar tensors cannot be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # Scalar tensors cannot be non-contiguous
            if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                return False
            return True

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input ===
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                        for o in default_output)
                    flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                    flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                        if (o.requires_grad):
                            o.backward(g_o, retain_graph=True)

            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out = torch.utils._pytree.tree_leaves(out)
                        flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            if o.requires_grad:
                                o.backward(g_o, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)

    def _test_gradients_helper(self, device, dtype, module_info, training, check):
        # Check gradients
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        # === Set nondet tol for gradcheck to the user-defined value if on CUDA and cuDNN is enabled ===
        gradcheck_nondet_tol = 0.0
        if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
            gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            params = tuple(m.parameters())

            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    m(*input_args, **input_kwargs)

            # === Perform gradient check on the input_args ===
            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

            def fn_to_gradcheck(*flat_input_and_params):
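                # Note: `flat_spec` is captured from the enclosing scope; it is assigned
                # below (via tree_flatten) before gradcheck first invokes this closure.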
                input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
                new_input_args = input_and_params[:len(input_args)]
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                with freeze_rng_state():
                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
                    output_flattened = torch.utils._pytree.tree_leaves(output)
                    return output_flattened

            # check total derivative
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)

            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

            # check partial derivatives
            old_params_requires_grad = [p.requires_grad for p in params]
            for p in params:
                p.requires_grad = False

            old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
            for (_, obj) in kwarg_tensors:
                obj.requires_grad = False

            for p, old in zip(params, old_params_requires_grad):
                p.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                p.requires_grad = False

            for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                obj.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                obj.requires_grad = False

    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info, training):
        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)

    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info, training):
        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)

    @onlyCUDA
    @with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
    @toleranceOverride({torch.float32: tol(5e-2, 0),
                        torch.float64: tol(4e-4, 0)})
    @modules(module_db)
    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
        # TODO: RNN / GRU / LSTM don't support backward in eval mode for cuDNN; skip this in a
        # nicer way for eval mode only.
        # See https://github.com/pytorch/pytorch/issues/79161
        rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
        if (module_info.module_cls in rnn_modules
                and not training
                and 'cuda' in device
                and torch.backends.cudnn.enabled):
            return

        # Test that cpu and gpu results are the same
        module_cls = module_info.module_cls
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=True, training=training)

        def _to_device(obj):
            if isinstance(obj, torch.Tensor):
                res = obj.detach().to(device=device)
                res.requires_grad = obj.requires_grad
                return res
            elif isinstance(obj, tuple):
                return tuple(_to_device(o) for o in obj)
            elif isinstance(obj, dict):
                return {key: _to_device(o) for key, o in obj.items()}
            else:
                return deepcopy(obj)

        for module_input in module_inputs_cpu:
            # === Move input from cpu to device ===
            cpu_forward_args = module_input.forward_input.args
            cpu_forward_kwargs = module_input.forward_input.kwargs

            gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

            self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

            # === Construct module on cpu and gpu ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
            cpu_module.train(training)
            gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
            gpu_module.train(training)

            # === Lazy modules need to see an input to initialize params ===
            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                gpu_p.data.copy_(cpu_p)

            # === Compare forward output between cpu and gpu ===
            cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
            gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            self.assertEqual(cpu_outputs, gpu_outputs)

            # === Run backwards on CPU and GPU and compare results ===
            def check_backward(cpu_output, gpu_output):
                cpu_grad_output = cpu_output.clone().normal_()
                gpu_grad_output = cpu_grad_output.type_as(gpu_output)

                cpu_output.backward(cpu_grad_output, retain_graph=True)
                gpu_output.backward(gpu_grad_output, retain_graph=True)

                cpu_grad_input = self._get_grads(cpu_forward_args)
                gpu_grad_input = self._get_grads(gpu_forward_args)
                self.assertEqual(cpu_grad_input, gpu_grad_input)

                for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                    self.assertEqual(cpu_p.grad, gpu_p.grad)

                cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
                gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
                self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

            for _ in range(5):
                if isinstance(cpu_outputs, torch.Tensor):
                    check_backward(cpu_outputs, gpu_outputs)
                else:
                    flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
                    flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
                    for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                        if cpu_output.requires_grad:
                            check_backward(cpu_output, gpu_output)

    @with_tf32_off
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info, training):
        is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                     or torch.cuda.get_device_capability(0) == (8, 0))
        # TODO: tighten it to a specific module
        atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                if ((mem_format == torch.channels_last and d != 4)
                   or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj.clone().detach().requires_grad_(obj.requires_grad)
                return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last)
                                or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                    self.assertTrue(output.is_contiguous(memory_format=torch.channels_last))
                elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                                  or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                    self.assertTrue(output.is_contiguous(memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())
            return self._traverse_obj(output, inner_check_out_mem_format)

        def _req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            supports_channels_last = _check_dims(module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)
                # === Do backward pass. ===
                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
                if training and len(ref_diff_outputs) > 0:
                    params = tuple(p for p in m.parameters())
                    ref_diff_inputs = tuple(
                        t
                        for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                        if _req_grad(t)
                    )
                    ref_grad_outputs = tuple(
                        torch.rand_like(t)
                        for t in ref_diff_outputs
                    )
                    ref_grad_inputs = torch.autograd.grad(
                        ref_diff_outputs,
                        ref_diff_inputs,
                        grad_outputs=ref_grad_outputs,
                    )

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                    d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                    # See https://github.com/pytorch/pytorch/issues/107861
                    # When inductor tests are turned on, the setting of requires_grad will be lost
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_args),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                    ):
                        t1.requires_grad_(t2.requires_grad)
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_kwargs),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                    ):
                        t1.requires_grad_(t2.requires_grad)

                    module_input.forward_input.args = d_args
                    module_input.forward_input.kwargs = d_kwargs

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                        # === Do backward pass. ===
                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                        if training and len(diff_outputs) > 0:
                            params = tuple(p for p in m.parameters())
                            diff_inputs = tuple(
                                t
                                for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                                if _req_grad(t)
                            )
                            grad_outputs = tuple(
                                torch.empty_like(t1).copy_(t2)
                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                            )

                            grad_inputs = torch.autograd.grad(
                                diff_outputs,
                                diff_inputs,
                                grad_outputs=grad_outputs,
                            )

                            if (
                                input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format
                            ):
                                self.assertEqual(
                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                                )

                            # === Check mem format of grad_inputs. ===
                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)

    # Test whether train and eval modes differ for each module. Used to verify
    # that the ModuleInfo entry flag is correct.
    @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
    def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)

        # Run forward inputs through to see if the training flag is accessed during forward.
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Remove training attribute and see if forward still works.
            delattr(m, 'training')

            # === Do forward pass. ===
            try:
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                m(*args, **kwargs)
            except AttributeError as e:
                if "'training'" in str(e):
                    self.assertTrue(module_info.train_and_eval_differ,
                                    f"The ModuleInfo entry for {module_info.name} has "
                                    "train_and_eval_differ=False, but the training mode was found to "
                                    "affect the forward pass. Consider setting train_and_eval_differ=True "
                                    "for this ModuleInfo entry.")
                else:
                    raise e


    @onlyCPU
    @modules(module_db)
    def test_device_ctx_init(self, device, dtype, module_info, training):
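        # Check that constructing a module under the torch.device('meta') context manager
        # places newly created parameters and buffers on the meta device while matching
        # the metadata of an equivalent CPU-constructed module.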
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        with torch.device('meta'):
            module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                                requires_grad=False, training=training)

        for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

            c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
            fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs

            m_cpu = module_cls(*c_args, **c_kwargs)

            with torch.device('meta'):
                m = module_cls(*c_args_meta, **c_kwargs_meta)

            for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()),
                                         zip(m.buffers(), m_cpu.buffers())):
                if torch.nn.parameter.is_lazy(p_meta):
                    continue
                self.assertTrue(p_meta.is_meta)
                assert_metadata_eq(self.assertEqual, p_meta, p_cpu)


    @modules([module for module in module_db if module.module_error_inputs_func is not None])
    def test_errors(self, device, dtype, module_info, training):
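        # Verify that the constructor / forward errors declared by the ModuleInfo's
        # module_error_inputs_func are raised with the expected type and message.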
        module_cls = module_info.module_cls
        error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
                                                            requires_grad=False, training=training)
        for error_input in error_inputs:
            module_input = error_input.module_error_input
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    m = module_cls(*c_args, **c_kwargs)
            elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
                m = module_cls(*c_args, **c_kwargs)
                fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    m(*fw_args, **fw_kwargs)
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    @modules([module for module in module_db if not module.is_lazy])
    @parametrize('swap', [True, False])
    @parametrize('set_grad', [True, False])
    @wrapSwapTensorsTest()
    def test_to(self, device, dtype, module_info, training, swap, set_grad):
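        # Exercise nn.Module.to() across device / dtype combinations, with and without
        # torch.__future__.set_swap_module_params_on_conversion, and check whether the
        # Parameter objects are swapped in place or have their .data reassigned.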
        module_cls = module_info.module_cls
        devices = ['cpu']
        if torch.cuda.is_available():
            devices += ['cuda']
        dtypes = module_info.dtypes
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        torch.__future__.set_swap_module_params_on_conversion(swap)

        for module_input in module_inputs:
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            m = module_cls(*c_args, **c_kwargs)

            # Avoid using `module.to()` when constructing module since that is the method we are testing
            def _to(m, set_grad=False):
                for c in m.children():
                    _to(c, set_grad=set_grad)
                for n, p in m.named_parameters(recurse=False):
                    new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
                    setattr(m, n, new_p)
                    if set_grad:
                        new_p.grad = torch.randn_like(new_p)
                for n, b in m.named_buffers(recurse=False):
                    new_b = b.detach().clone().to(device, dtype)
                    setattr(m, n, new_b)
            _to(m, set_grad=set_grad)

            prev_device, prev_dtype = device, dtype
            for device_, dtype_ in product(devices, dtypes):
                # if device/dtype do not change, grad.to(device, dtype) is a no-op so
                # swapping will not change ._cdata
                # parameters will be wrapped in an nn.Parameter before swapping
                # which will cause the ._cdata to change
                g_no_swap = device_ == prev_device and dtype_ == prev_dtype
                prev_device, prev_dtype = device_, dtype_

                p_ids_before = [id(p) for p in m.parameters()]
                p_cdatas_before = [p._cdata for p in m.parameters()]
                if set_grad:
                    g_ids_before = [id(p.grad) for p in m.parameters()]
                    g_cdatas_before = [p.grad._cdata for p in m.parameters()]

                m.to(device=device_, dtype=dtype_)

                self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
                self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
                self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
                p_ids_after = [id(p) for p in m.parameters()]
                p_cdatas_after = [p._cdata for p in m.parameters()]

                if set_grad:
                    self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
                    self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
                    g_ids_after = [id(p.grad) for p in m.parameters()]
                    g_cdatas_after = [p.grad._cdata for p in m.parameters()]

                if swap:
                    # id same, ._cdata differs --> swapped cdata of THPVariable
                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                    self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                    if set_grad:
                        self.assertTrue(
                            all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                else:
                    # id and _cdata remain the same --> .data setting
                    self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                    self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                    if set_grad:
                        self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                        self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))


    @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
    @parametrize('swap', [True, False])
    @wrapSwapTensorsTest()
    def test_to_empty(self, device, dtype, module_info, swap, training):
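        # Modules constructed on the meta device should be materializable on a real
        # device via to_empty(); parameter values are left uninitialized, so only
        # placement, dtype, and (depending on `swap`) object identity are checked.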
        module_cls = module_info.module_cls

        with torch.device("meta"):
            module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
                                                           requires_grad=False, training=training)

        torch.__future__.set_swap_module_params_on_conversion(swap)
        device_ = torch.device(device)

        for module_input in module_inputs:
            c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            with torch.device("meta"):
                m = module_cls(*c_args, **c_kwargs)

            p_ids_before = [id(p) for p in m.parameters()]
            p_cdatas_before = [p._cdata for p in m.parameters()]
            m.to_empty(device=device_)

            self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
            self.assertTrue(all(p.device == device_ for p in m.parameters()))
            self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
            p_ids_after = [id(p) for p in m.parameters()]
            p_cdatas_after = [p._cdata for p in m.parameters()]

            if swap:
                # id same, ._cdata differs --> swapped cdata of THPVariable
                self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
            else:
                # id and ._cdata differ
                # meta and device have different shallow copy types, so this will create a new
                # parameter and assign it to the module
                self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
                self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))


instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == '__main__':
    run_tests()