# Scraped page header: pytorch · Fork 0 / test_expanded_weights.py
# 1160 lines · 45.1 KB
# Owner(s): ["module: nn"]
import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only
38

39

40
class TestContext:
    """Empty placeholder class with no attributes or behavior of its own.

    Instances are ordinary Python objects, so arbitrary attributes can be
    assigned to them; how this file uses that is outside the visible chunk.
    """
42

43

44
class TestExpandedWeightHelperFunction(TestCase):
45
    def test_forward_helper(self, device):
46
        input = torch.randn(3, 4, device=device)
47
        weight = torch.randn(5, 4, device=device)
48
        bias = torch.randn(5, device=device)
49
        for weight_batched, bias_batched in product([True, False], [True, False]):
50
            maybe_batched_weight = weight
51
            maybe_batched_bias = bias
52
            if weight_batched:
53
                maybe_batched_weight = ExpandedWeight(
54
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
55
                )
56
            if bias_batched:
57
                maybe_batched_bias = ExpandedWeight(
58
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
59
                )
60
            args = (input, maybe_batched_weight, maybe_batched_bias)
61
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
62
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
63
            expected = nn.functional.linear(input, weight, bias)
64
            self.assertEqual(res, expected)
65

66
            self.assertEqual(len(expanded_args), 2)
67
            assert expanded_args[0] is args[0]  # avoids property checks in assertEquals
68
            assert expanded_args[1] is args[1]  # avoids property checks in assertEquals
69
            self.assertEqual(len(expanded_kwargs), 1)
70
            assert (
71
                expanded_kwargs["bias"] is args[2]
72
            )  # avoids property checks in assertEquals
73

74
    def test_forward_helper_failure_args(self, device):
75
        weight = torch.randn(5, 4, device=device)
76
        bias = torch.randn(5, device=device)
77
        with self.assertRaisesRegex(
78
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
79
        ):
80
            input = ExpandedWeight(
81
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
82
            )
83
            expanded_args, expanded_kwargs = standard_kwargs(
84
                ("bias",), (input, weight, bias)
85
            )
86
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
87
        with self.assertRaisesRegex(
88
            RuntimeError, r"requires a Tensor as the first input"
89
        ):
90
            expanded_args, expanded_kwargs = standard_kwargs(
91
                ("bias",), (3, weight, bias)
92
            )
93
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
94
        with self.assertRaisesRegex(
95
            RuntimeError, r"requires a batch dimension but got an input of size 0"
96
        ):
97
            expanded_args, expanded_kwargs = standard_kwargs(
98
                ("bias",), (torch.tensor(3), weight, bias)
99
            )
100
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
101
        with self.assertRaisesRegex(
102
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
103
        ):
104
            expanded_args, expanded_kwargs = standard_kwargs(
105
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
106
            )
107
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
108
        input = torch.randn(3, 4)
109
        for weight_batched, bias_batched in product([True, False], [True, False]):
110
            if not weight_batched and not bias_batched:
111
                continue
112
            maybe_batched_weight = weight
113
            maybe_batched_bias = bias
114
            if weight_batched:
115
                maybe_batched_weight = ExpandedWeight(
116
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
117
                )
118
            if bias_batched:
119
                maybe_batched_bias = ExpandedWeight(
120
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
121
                )
122
            with self.assertRaisesRegex(
123
                RuntimeError,
124
                r"Expected ExpandedWeights to have batch size matching input",
125
            ):
126
                expanded_args, expanded_kwargs = standard_kwargs(
127
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
128
                )
129
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
130

131
    def test_set_grad_sample_if_exists(self, device):
132
        def test_fn(a):
133
            return grad_sample
134

135
        orig_weight = torch.randn(4, device=device, requires_grad=True)
136
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
137
        grad_sample = torch.randn(3)
138
        set_grad_sample_if_exists(expanded_weight, test_fn)
139
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
140
        self.assertEqual(orig_weight.grad_sample, grad_sample)
141

142
        basic_tensor = torch.randn(4, device=device)
143
        set_grad_sample_if_exists(basic_tensor, test_fn)
144
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))
145

146
        non_tensor = 3
147
        set_grad_sample_if_exists(non_tensor, test_fn)
148
        self.assertFalse(hasattr(non_tensor, "grad_sample"))
149

150
    def test_set_grad_sample_if_exists_failure(self, device):
151
        def test_fn(a):
152
            return True
153

154
        grad_tensor = torch.randn(4, requires_grad=True, device=device)
155
        with self.assertRaisesRegex(
156
            RuntimeError,
157
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
158
        ):
159
            set_grad_sample_if_exists(grad_tensor, test_fn)
160

161
    def test_unpack_expanded_weight_or_tensor(self, device):
162
        input = torch.randn(3, requires_grad=True, device=device)
163
        self.assertEqual(
164
            input,
165
            unpack_expanded_weight_or_tensor(
166
                ExpandedWeight(input, 3, loss_reduction="sum")
167
            ),
168
        )
169

170
        input.requires_grad_(False)
171
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
172
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)
173

174
    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
175
        input = torch.randn(3, requires_grad=True, device=device)
176
        self.assertTrue(
177
            unpack_expanded_weight_or_tensor(
178
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
179
            )
180
        )
181

182
        input.requires_grad_(False)
183
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
184
        self.assertTrue(
185
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
186
        )
187

188
    def test_unpack_expanded_weight_or_tensor_failure(self, device):
189
        input = torch.randn(3, requires_grad=True, device=device)
190
        with self.assertRaisesRegex(
191
            RuntimeError,
192
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
193
        ):
194
            unpack_expanded_weight_or_tensor(input)
195

196
        with self.assertRaisesRegex(
197
            RuntimeError,
198
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
199
        ):
200
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)
201

202
    def test_sum_over_all_but_batch_and_last_n(self, device):
203
        input = torch.randn(1, 2, 3, 4, 5, device=device)
204
        res = sum_over_all_but_batch_and_last_n(input, 2)
205
        expected = input.sum((1, 2))
206
        self.assertEqual(res, expected)
207

208
        res = sum_over_all_but_batch_and_last_n(input, 0)
209
        expected = input.sum((1, 2, 3, 4))
210
        self.assertEqual(res, expected)
211

212
        res = sum_over_all_but_batch_and_last_n(input, 4)
213
        self.assertEqual(res, input)
214

215

216
class TestExpandedWeightFunctional(TestCase):
    """Tests for ExpandedWeight per-sample gradients on functional ops and small models.

    Per-sample gradients obtained through ExpandedWeight (directly, or via
    ``call_for_per_sample_grads``) are compared with gradients computed one
    sample at a time.
    """

    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        """Compare ExpandedWeight per-sample grads of ``op`` with a per-sample loop.

        Args:
            op: OpInfo under test (invoked through ``run_op``).
            sample_input: ``SampleInput`` whose args/kwargs are wrapped by
                ``make_expanded_weight``.
            reduction: applied to the op output before ``backward``.
                ``torch.sum`` selects ``loss_reduction="sum"``; anything else
                selects ``"mean"``.
        """
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        # an input of rank <= 1 is treated as a single sample
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeights objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        # ew_input is listed first, so it occupies diff_input_list[0] exactly when it
        # is differentiable; otherwise diff_input_list[0] is a weight.
        input_is_diff = is_diff_tensor(ew_input)
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        reduction(
            result
        ).backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with for loop
        func = partial(run_op, op)

        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean" and input_is_diff:
            # Don't check equality of `input.grad`s since these vanilla tensors won't
            # be scaled. Only drop the leading entry when it really is the input;
            # otherwise it is a weight whose grad must still be compared.
            # NOTE(review): assumes for_loop_per_sample_grad also lists the input's
            # grad first when the input requires grad (as the original slicing did).
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            # embedding flips its argument order for autograd tests; undo that here
            if op.name == "nn.functional.embedding":
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            # embedding flips its argument order for autograd tests; undo that here
            if op.name == "nn.functional.embedding":
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            # embedding flips its argument order for autograd tests; undo that here
            if op.name == "nn.functional.embedding":
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            # only the weights are differentiable in this test
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            # Wrapping, forward and backward all run inside the context, so the
            # expected error is accepted wherever it surfaces.
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                # embedding flips its argument order for autograd tests; undo that here
                if op.name == "nn.functional.embedding":
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input

                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeights objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                result.sum().backward()  # grad doesn't work with ExpandedWeight because it calls __torch_function__

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            # embedding flips its argument order for autograd tests; undo that here
            if op.name == "nn.functional.embedding":
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
                if (
                    "cuda" in device
                    and "max_norm" in sample_input.kwargs
                    and "padding_idx" in sample_input.kwargs
                ):
                    self.skipTest(
                        "embedding is non-determinstic in this case, see issue #74679"
                    )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            # the forward result must not depend on the loss reduction
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)

    def test_expanded_weight_error(self, device):
        # torch.add has no ExpandedWeight rule, so using one with it must raise
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4,), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )

    def _test_embedding_model(self, model, num_embedding, device):
        """Run ``_test_model`` on random indices in ``[0, num_embedding)``."""
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        """Run ``_test_model`` on a random 3-channel input of shape
        ``(32, 3) + (input_size,) * num_dim``; ``num_dim`` is bound into ``model``.
        """
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )

    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        """Compare ``call_for_per_sample_grads`` with a per-sample autograd loop.

        Args:
            model: factory called as ``model(10)`` to build a 10-class classifier.
            batch_size: number of samples along the first dim of ``input``.
            input: batched model input.
            device: device to build the model and targets on.
            loss_reduction: ``"sum"`` or ``"mean"``; used for both
                ``CrossEntropyLoss`` and ``call_for_per_sample_grads``.
            atol, rtol: tolerances for the per-sample gradient comparison.
        """
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        output = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(output, targets)
        loss.backward()
        actual_grads = []
        for param in model.parameters():
            actual_grads.append(param.grad_sample)
            # grad_sample is only needed for the comparison below; drop it from the param
            del param.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        # regroup per-sample tuples into one stacked tensor per parameter
        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(actual_grads, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)

    def _compute_tolerances(self, device):
        """Return ``(atol, rtol)`` for the model tests on ``device``.

        CUDA devices with compute capability 8.6 get a looser ``atol``. The
        capability is queried for ``device`` itself rather than device 0.
        """
        is_cuda_sm86 = device.startswith(
            "cuda"
        ) and torch.cuda.get_device_capability(device) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)

    def _convnet(self, num_classes, num_dim):
        """Build the 2-d CNN classifier shared by the ``test_cnn_model_*`` tests.

        ``num_dim`` is accepted (and ignored) because ``_test_conv_model`` always
        binds it; the layers are fixed to 2-d.
        """
        return nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(128, num_classes, bias=True),
        )

    @tf32_off()
    def test_cnn_model_sum(self, device):
        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            self._convnet, 28, 2, device, atol=atol, rtol=rtol
        )

    @tf32_off()
    def test_cnn_model_mean(self, device):
        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            self._convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                # 5 * 5 indices per sample, each embedded to 15 features -> 375
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm. This checks that it hits the same errors
        # that normal group norm would

        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 is not divisible by 2
610

611

612
class TestExpandedWeightModule(TestCase):
    """Per-sample-gradient tests for whole nn.Modules.

    Each helper computes gradients twice: once through
    ``call_for_per_sample_grads`` (which fills ``param.grad_sample``) and once
    with an explicit loop over the batch using ``torch.autograd.grad``, then
    compares the two.
    """

    def _do_test(
        self,
        module,
        input,
        args=None,
        kwargs=None,
        batch_first=True,
        atol=None,
        rtol=None,
    ):
        """Compare ExpandedWeights per-sample grads with a per-sample loop.

        Args:
            module: module under test; its parameters receive ``grad_sample``.
            input: batched input whose batch dim is 0 if ``batch_first`` else 1.
                Float/double inputs are made to require grad and their
                gradient is compared too.
            args: extra positional forward args. Tensors in it are sliced along
                dim 1 for the per-sample loop.
            kwargs: extra keyword forward args, passed unchanged.
            batch_first: see ``input``.
            atol, rtol: tolerances for the gradient comparison.
        """
        args = args or ()
        kwargs = kwargs or {}

        batch_dim = 0 if batch_first else 1
        batch_size = input.shape[batch_dim]
        # Only floating inputs can carry gradients (Embedding gets integer indices).
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module,
                batch_size=batch_size,
                loss_reduction="sum",
                batch_first=batch_first,
            )(input, *args, **kwargs).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                # Remove it so a later per-sample-grad call starts clean.
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(
                0.0, device=input.device, dtype=actual_res.dtype
            )
            expected_grads = []
            for i in range(batch_size):
                input_slice = input.narrow(batch_dim, i, 1)
                input_slice = input_slice.squeeze(batch_dim)

                # Tensor args (e.g. an RNN hidden state h) are batched along
                # dim 1 regardless of batch_first. Must be contiguous for CUDA.
                sliced_args = tree_map_only(
                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
                )
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(
                    input_slice.unsqueeze(batch_dim).contiguous(),
                    *sliced_args,
                    **kwargs,
                ).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            if diff_input and not batch_first:
                # The last entry is the input grad, stacked with the batch along
                # dim 0; move it to dim 1 to match input.grad. Parameter grads
                # (the only entries when the input is not differentiable) are
                # left alone.
                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
        self.assertEqual(actual_res, expected_res)
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)

    def _do_test_multi_input(self, module, input):
        """Check per-sample grads when ``module`` runs twice in one forward.

        The wrapper returns ``module(input) + module(input)``, so the
        ExpandedWeights grads must equal twice the single-call per-sample grads.
        """

        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager; the
            # wrapper calls the module twice per forward
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
                input
            ).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop, calling the module once
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
        expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
        grad_pairs = list(zip(actual_grads, expected_grads))
        # Explicit check (not a bare `assert`, which `python -O` strips) that
        # something was actually compared.
        self.assertTrue(grad_pairs, "no per-sample gradients were compared")
        for actual, expected in grad_pairs:
            self.assertEqual(actual, 2 * expected)

    def _do_test_rnn_packed_sequence(
        self, module, input, args=None, kwargs=None, atol=None, rtol=None
    ):
        """Compare per-sample grads for an RNN wrapper fed a PackedSequence.

        The reference loop unpacks ``input`` with ``pad_packed_sequence`` and
        runs each sequence (trimmed to its length) through the module alone.
        """
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        # The largest entry of PackedSequence.batch_sizes is used as the batch size.
        batch_size = max(tuple(input.batch_sizes)).item()

        with freeze_rng_state():
            # get per sample grads with ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module, batch_size=batch_size, loss_reduction="sum"
            )(input, *args, **kwargs).data.sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                self.assertEqual(param.grad_sample.shape[0], batch_size)
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            input.data.grad = torch.zeros_like(input.data)

            # compute the per sample grads with a for loop
            expected_res = torch.zeros_like(actual_res)
            expected_grads = []
            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
                input, batch_first=True
            )
            batch_dim = 0 if module.m.batch_first else 1
            for i in range(len(seq_sizes)):
                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
                diff_params = module.parameters()
                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
                expected_res += res
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)

            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            self.assertEqual(actual_res, expected_res)
            for actual, expected in zip(actual_grads, expected_grads):
                self.assertEqual(actual, expected, atol=atol, rtol=rtol)

    @modules(
        filter(
            lambda m_info: m_info.module_cls
            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    @tf32_off()
    def test_module(self, device, dtype, module_info, training):
        class RNNWrapper(torch.nn.Module):
            # Returns only the output sequence of the wrapped RNN (drops hidden state).
            def __init__(self, m_cons, args, kwargs):
                super().__init__()
                self.m = m_cons(*args, **kwargs)

            def forward(self, *inps):
                ret = self.m(*inps)
                assert isinstance(ret, tuple)
                return ret[0]

        def batch_hidden(h):
            # Insert a batch dim at position 1 and repeat it twice.
            new_h_shape = [1] * (len(h.shape) + 1)
            new_h_shape[1] = 2
            return h.unsqueeze(1).repeat(new_h_shape)

        module_cls = module_info.module_cls
        atol, rtol = (
            (1e-4, 1e-5)
            if module_cls == torch.nn.GRU and dtype == torch.float32
            else (None, None)
        )
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
            with_packed_sequence=True,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = RNNWrapper(module_cls, args, kwargs)
            batch_first = m.m.batch_first
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )

            # if the RNN tests use unbatched inputs--batch the inputs
            input = args[0]
            if isinstance(input, torch.Tensor) and input.dim() == 2:
                input = input.detach()
                new_input_shape = [1] * (len(input.shape) + 1)
                if batch_first:
                    new_input_shape[0] = 2
                    input = input.repeat(new_input_shape)
                else:
                    new_input_shape[1] = 2
                    input = input.unsqueeze(1).repeat(new_input_shape)

                h = args[1] if len(args) > 1 else None
                if h is not None:
                    h = (
                        batch_hidden(h)
                        if isinstance(h, torch.Tensor)
                        else tuple(batch_hidden(hx) for hx in h)
                    )
                    args = list(args)
                    args[1] = h

            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
                self._do_test_rnn_packed_sequence(
                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
                )
            else:
                self._do_test(
                    m,
                    input,
                    args[1:],
                    kwargs,
                    batch_first=batch_first,
                    atol=atol,
                    rtol=rtol,
                )

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
            RuntimeError, r"Batch size passed must be None or an integer"
        ):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module)(input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
        with self.assertRaisesRegex(
            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
        ):
            call_for_per_sample_grads(module, loss_reduction="")(input)

    def test_per_sample_api_compute_batch_size(self):
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        module = CustomModule()
        input1 = torch.randn(4, 5)
        input2 = torch.randn(5, 5)

        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(module)(input1, input2)

        input2 = torch.randn(4, 5)
        call_for_per_sample_grads(module)(input1, input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1, input2=input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1=input1, input2=input2)

    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        model = CustomModule()
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(model)(input, "")

        # would prefer for it to error because input is not pytree-able but that's hard to detect
        with self.assertRaisesRegex(
            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
        ):
            call_for_per_sample_grads(model)(input, torch.randn(5))

        model = CustomModule()  # TODO: functional call bug, sam will fix
        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
        model = CustomModule()
        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))
954

955

956
class ContextManagerTests(TestBase):
    """Adapter that runs a legacy ``common_nn`` test entry through the
    per-sample-gradient helpers of a ``TestExpandedWeightModule`` instance."""

    def __init__(self, *args, **kwargs):
        # Both flags default to True. They are read here and also left in
        # kwargs for TestBase.
        self.test_cpu = kwargs.get("test_cpu", True)
        self.test_cuda = kwargs.get("test_cuda", True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    def _skip_if_unsupported(self, input):
        """Raise ``unittest.SkipTest`` for inputs without a usable batch dim."""
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )

    def test_context_manager(self, test_case, device):
        """Run ``test_case._do_test`` on this entry in double precision on ``device``."""
        kwargs = {"device": device, "dtype": torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        if "Embedding" in self.get_name():
            # Embedding consumes integer indices: only the module is cast to double.
            kwargs["dtype"] = torch.long
        input = self._get_input().to(**kwargs)
        self._skip_if_unsupported(input)
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case, device):
        """Run ``test_case._do_test_multi_input`` on this entry on ``device``."""
        module = self.constructor(*self.constructor_args).to(device)
        # Keep the input on the same device as the module. For "cpu" this is a
        # no-op that returns the same tensor.
        input = self._get_input().to(device)
        self._skip_if_unsupported(input)
        test_case._do_test_multi_input(module, input)
994

995

996
def filter_supported_tests(t):
    """Return True if the legacy nn test entry ``t`` targets a supported module.

    ``t`` is a test-description dict; only its ``module_name`` key is read.
    Entries without that key (or naming another module) give False.
    """
    supported_modules = (
        "Linear",
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "Embedding",
        "LayerNorm",
        "GroupNorm",
        "InstanceNorm",
    )
    return t.get("module_name") in supported_modules
1009

1010

1011
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
# These currently use the legacy nn tests
supported_tests = [
    t for t in module_tests + new_module_tests if filter_supported_tests(t)
]
for test_param in supported_tests:
    # Work on a shallow copy: the pops below would otherwise mutate the dicts
    # owned by common_nn's module_tests / new_module_tests lists.
    test_param = dict(test_param)
    if "constructor" not in test_param:
        name = test_param.pop("module_name")
        test_param["constructor"] = getattr(nn, name)
    decorator = test_param.pop("decorator", lambda test: test)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    test_name_multi_input = test_name + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError(
            "Found two tests with the same name: " + test_name_multi_input
        )
    if test.test_cpu:
        # `test=test` binds the current entry; a plain closure would late-bind
        # to the last value of the loop variable.
        setattr(
            TestExpandedWeightModule,
            test_name,
            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
        )
        setattr(
            TestExpandedWeightModule,
            test_name_multi_input,
            decorator(
                lambda self, test=test: test.test_context_manager_multiple_inputs(
                    self, "cpu"
                )
            ),
        )
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(
            TestExpandedWeightModule,
            test_name + "_cuda_double",
            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
        )
1050

1051
# ------------- HELPER FUNCTIONS -----------------
1052

1053

1054
def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative
    of the weight, not the input, which can't be differentiable since its dtype is int. Calls op,
    using the special ordering that Embedding's OpInfo expects for that case.

    For embedding, ``input`` holds the weight and ``args[0]`` the indices; they
    are swapped back, and any further positional args are forwarded after them
    instead of being silently dropped. All other ops are called as
    ``op(input, *args, **kwargs)``.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, *args[1:], **kwargs)
    return op(input, *args, **kwargs)
1064

1065

1066
def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
    """Copy a SampleInput's pieces for a call through ExpandedWeights.

    The primary input is always a plain clone (see ``clone_if_tensor``). Every
    value in ``args``/``kwargs`` for which ``is_diff_tensor`` holds is wrapped in
    an ``ExpandedWeight`` built from a clone; every other value is cloned if it
    is a tensor, else passed through.

    Returns:
        ``(input, args_tuple, kwargs_dict)``.
    """

    def wrap(value):
        if not is_diff_tensor(value):
            return clone_if_tensor(value)
        return ExpandedWeight(torch.clone(value), batch_size, loss_reduction)

    return (
        clone_if_tensor(sample_input.input),
        tuple(map(wrap, sample_input.args)),
        {key: wrap(value) for key, value in sample_input.kwargs.items()},
    )
1079

1080

1081
def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with.

    Args:
        op: OpInfo-like object; only ``op.name`` is read.
        sample_inputs: iterable of SampleInput-like objects (``.input`` tensor,
            ``.args`` tuple).
        supported_inputs: if True, return the supported samples; otherwise
            return exactly the complement (the unsupported ones).

    Returns:
        A list of the selected samples, in their original order.
    """
    # Expected input rank (batch dim included) for each convolution op.
    batched_input_size = {
        "nn.functional.conv1d": 3,
        "nn.functional.conv2d": 4,
        "nn.functional.conv3d": 5,
    }

    def filter_fn(input):
        if op.name == "nn.functional.linear":
            # input of rank 1 means no batch dim
            is_supported_input = input.input.dim() > 1
        elif op.name == "nn.functional.layer_norm":
            # Normalizing over the whole input shape would mix samples across
            # the batch. Compare as tuples: normalized_shape may be an int, a
            # list or a tuple, and a tuple never equals a list or an int.
            normalized_shape = input.args[0]
            if isinstance(normalized_shape, int):
                normalized_shape = (normalized_shape,)
            is_supported_input = tuple(input.input.shape) != tuple(normalized_shape)
        elif op.name in batched_input_size:
            # currently can't deal with padding computation on Python level
            is_supported_input = input.input.dim() == batched_input_size[op.name]
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size
        else:
            is_supported_input = True
        # 0 is not a valid batch size
        is_supported_input = is_supported_input and input.input.shape[0] > 0
        return is_supported_input if supported_inputs else not is_supported_input

    return [input for input in sample_inputs if filter_fn(input)]
1117

1118

1119
def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
    """Reference per-sample gradients, computed one sample at a time.

    For each ``b`` in ``range(batch_size)``, ``func`` is applied to ``input[b]``
    re-batched to size 1 (plus ``args``/``kwargs``), the output is reduced with
    ``reduction``, and gradients are taken w.r.t. every tensor requiring grad
    among the sample, ``args`` and ``kwargs`` values (in that order).

    Returns:
        A tuple with one stacked tensor per differentiated tensor. If
        ``batch_size`` is negative, no sample is processed and the empty list
        is returned instead.
    """
    grads_per_sample = []
    for b in range(batch_size):
        sample = input[b]
        loss = reduction(func(sample.unsqueeze(0), *args, **kwargs))
        wrt = [
            candidate
            for candidate in (sample, *args, *kwargs.values())
            if isinstance(candidate, torch.Tensor) and candidate.requires_grad
        ]
        grads_per_sample.append(
            torch.autograd.grad(loss, wrt, torch.ones_like(loss), allow_unused=True)
        )
    if len(grads_per_sample) != batch_size:
        return grads_per_sample
    return tuple(torch.stack(column) for column in zip(*grads_per_sample))
1139

1140

1141
def is_diff_tensor(t):
    """Return True for an ExpandedWeight or a tensor that requires grad."""
    if isinstance(t, ExpandedWeight):
        return True
    return isinstance(t, torch.Tensor) and t.requires_grad
1145

1146

1147
def clone_if_tensor(t):
    """Return a detached copy of tensor ``t`` that keeps its ``requires_grad``
    flag; any non-tensor value is returned unchanged."""
    if not isinstance(t, torch.Tensor):
        return t
    duplicate = torch.clone(t).detach()
    duplicate.requires_grad_(t.requires_grad)
    return duplicate
1154

1155

1156
# Instantiate the device-parametrized variants of each generic test class.
for _generic_cls in (
    TestExpandedWeightHelperFunction,
    TestExpandedWeightFunctional,
    TestExpandedWeightModule,
):
    instantiate_device_type_tests(_generic_cls, globals())
# Drop the loop variable so test discovery cannot reach a generic class through it.
del _generic_cls

if __name__ == "__main__":
    run_tests()
1161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.