import unittest
from dataclasses import dataclass
from functools import partial
from itertools import chain, product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.utils._expanded_weights import ExpandedWeight
from torch.nn.utils._expanded_weights.expanded_weights_utils import (
    forward_helper,
    set_grad_sample_if_exists,
    standard_kwargs,
    sum_over_all_but_batch_and_last_n,
    unpack_expanded_weight_or_tensor,
)
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, SampleInput
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase
from torch.testing._internal.common_utils import (
    freeze_rng_state,
    make_tensor,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.utils._pytree import tree_map_only
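
# The tests below exercise per-sample gradients computed through ExpandedWeight and
# call_for_per_sample_grads.  A minimal usage sketch (illustrative only; the Linear
# module and shapes here are made up for the example, not taken from any test below):
#
#   model = nn.Linear(4, 2)
#   batched_input = torch.randn(8, 4)
#   out = call_for_per_sample_grads(model, loss_reduction="sum")(batched_input)
#   out.sum().backward()
#   # each parameter now carries a .grad_sample of shape (8, *param.shape)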


class TestExpandedWeightHelperFunction(TestCase):
    def test_forward_helper(self, device):
        input = torch.randn(3, 4, device=device)
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
                )
            args = (input, maybe_batched_weight, maybe_batched_bias)
            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
            expected = nn.functional.linear(input, weight, bias)
            self.assertEqual(res, expected)

            # positional args stay positional and the bias moves into the kwargs
            self.assertEqual(len(expanded_args), 2)
            assert expanded_args[0] is args[0]
            assert expanded_args[1] is args[1]
            self.assertEqual(len(expanded_kwargs), 1)
            assert expanded_kwargs["bias"] is args[2]

    def test_forward_helper_failure_args(self, device):
        weight = torch.randn(5, 4, device=device)
        bias = torch.randn(5, device=device)
        with self.assertRaisesRegex(
            RuntimeError, r"do not support inputs that are also ExpandedWeights."
        ):
            input = ExpandedWeight(
                torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum"
            )
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (input, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a Tensor as the first input"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (3, weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"requires a batch dimension but got an input of size 0"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.tensor(3), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        with self.assertRaisesRegex(
            RuntimeError, r"0 is not a valid batch size for Expanded Weights"
        ):
            expanded_args, expanded_kwargs = standard_kwargs(
                ("bias",), (torch.randn(0, 1, 2), weight, bias)
            )
            forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
        input = torch.randn(3, 4)
        for weight_batched, bias_batched in product([True, False], [True, False]):
            if not weight_batched and not bias_batched:
                continue
            maybe_batched_weight = weight
            maybe_batched_bias = bias
            if weight_batched:
                maybe_batched_weight = ExpandedWeight(
                    weight.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            if bias_batched:
                maybe_batched_bias = ExpandedWeight(
                    bias.clone().requires_grad_(), 4, loss_reduction="sum"
                )
            with self.assertRaisesRegex(
                RuntimeError,
                r"Expected ExpandedWeights to have batch size matching input",
            ):
                expanded_args, expanded_kwargs = standard_kwargs(
                    ("bias",), (input, maybe_batched_weight, maybe_batched_bias)
                )
                forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)

    def test_set_grad_sample_if_exists(self, device):
        def test_fn(_):
            return grad_sample

        orig_weight = torch.randn(4, device=device, requires_grad=True)
        expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
        grad_sample = torch.randn(3)
        set_grad_sample_if_exists(expanded_weight, test_fn)
        self.assertTrue(hasattr(orig_weight, "grad_sample"))
        self.assertEqual(orig_weight.grad_sample, grad_sample)

        basic_tensor = torch.randn(4, device=device)
        set_grad_sample_if_exists(basic_tensor, test_fn)
        self.assertFalse(hasattr(basic_tensor, "grad_sample"))

        non_tensor = 3
        set_grad_sample_if_exists(non_tensor, test_fn)
        self.assertFalse(hasattr(non_tensor, "grad_sample"))

    def test_set_grad_sample_if_exists_failure(self, device):
        def test_fn(_):
            return True

        grad_tensor = torch.randn(4, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            set_grad_sample_if_exists(grad_tensor, test_fn)

    def test_unpack_expanded_weight_or_tensor(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertEqual(
            input,
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum")
            ),
        )

        input.requires_grad_(False)
        self.assertEqual(input, unpack_expanded_weight_or_tensor(input))
        self.assertTrue(unpack_expanded_weight_or_tensor(4) is None)

    def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        self.assertTrue(
            unpack_expanded_weight_or_tensor(
                ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input
            )
        )

        input.requires_grad_(False)
        self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input))
        self.assertTrue(
            unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None
        )

    def test_unpack_expanded_weight_or_tensor_failure(self, device):
        input = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input)

        with self.assertRaisesRegex(
            RuntimeError,
            r"does not support a mixture of ExpandedWeight parameters and normal Parameters",
        ):
            unpack_expanded_weight_or_tensor(input, lambda x: x is input)

    def test_sum_over_all_but_batch_and_last_n(self, device):
        input = torch.randn(1, 2, 3, 4, 5, device=device)
        res = sum_over_all_but_batch_and_last_n(input, 2)
        expected = input.sum((1, 2))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 0)
        expected = input.sum((1, 2, 3, 4))
        self.assertEqual(res, expected)

        res = sum_over_all_but_batch_and_last_n(input, 4)
        self.assertEqual(res, input)
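
        # Worked example for the assertions above: with an input of shape
        # (1, 2, 3, 4, 5), keeping the batch dim plus the last 2 dims means summing
        # over dims (1, 2); keeping the last 0 dims sums over dims (1, 2, 3, 4); and
        # keeping the last 4 dims (everything but the batch dim) is a no-op.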


class TestExpandedWeightFunctional(TestCase):
    def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction):
        input = sample_input.input
        args = sample_input.args
        kwargs = sample_input.kwargs
        batch_size = input.shape[0] if len(input.shape) > 1 else 1

        # get per sample grads with ExpandedWeight objects
        loss_reduction = "sum" if reduction == torch.sum else "mean"
        (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
            sample_input, batch_size, loss_reduction
        )
        diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
        diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
        diff_input_list = [
            i.orig_weight if isinstance(i, ExpandedWeight) else i
            for i in diff_input_list
        ]
        if not diff_input_list:
            return
        result = run_op(op, ew_input, *ew_args, **ew_kwargs)
        # autograd.grad doesn't work with ExpandedWeight because it calls __torch_function__
        reduction(result).backward()
        expanded_weight_grad = tuple(
            i.grad_sample if hasattr(i, "grad_sample") else i.grad
            for i in diff_input_list
        )

        # get per sample grads with a for loop
        func = partial(run_op, op)
        per_sample_grad = for_loop_per_sample_grad(
            batch_size, reduction, input, func, *args, **kwargs
        )

        # check equality
        self.assertEqual(len(per_sample_grad), len(expanded_weight_grad))
        if loss_reduction == "mean":
            # don't check equality of the input grads since those vanilla tensors won't be scaled
            expanded_weight_grad = expanded_weight_grad[1:]
            per_sample_grad = per_sample_grad[1:]
        for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad):
            self.assertEqual(result_grad, expected_grad)
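
    # The helper above compares two routes to the same per-sample gradients:
    # ExpandedWeight populates .grad_sample on the original parameters during a single
    # batched backward, while for_loop_per_sample_grad (defined near the bottom of
    # this file) recomputes the gradient one sample at a time and stacks the results.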

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0],
                    args=(sample_input.input,),
                    kwargs=sample_input.kwargs,
                )
            sample_input.input.requires_grad_(False)

            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)

    @skipIfTorchDynamo("Checking error message doesn't work with dynamo")
    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db),
        dtypes=OpDTypes.supported,
        allowed_dtypes=(torch.double,),
    )
    def test_unsupported_expand_weights(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
        unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False)
        for sample_input in unsupported_inputs:
            with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"):
                if (
                    op.name == "nn.functional.embedding"
                ):  # embedding flips its argument order for autograd tests
                    sample_input = SampleInput(
                        sample_input.args[0],
                        args=(sample_input.input,),
                        kwargs=sample_input.kwargs,
                    )
                input = sample_input.input
                batch_size = input.shape[0] if len(input.shape) > 1 else 1

                # get per sample grads with ExpandedWeight objects
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size
                )
                result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                diff_input_list = (
                    (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values())
                )
                diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)]
                diff_input_list = [
                    i.orig_weight if isinstance(i, ExpandedWeight) else i
                    for i in diff_input_list
                ]
                result.sum().backward()

    @ops(
        filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported
    )
    def test_expanded_weight_forward(self, device, dtype, op):
        sample_inputs = op.sample_inputs(device, dtype)
        for sample_input in supported_inputs(op, sample_inputs):
            if (
                op.name == "nn.functional.embedding"
            ):  # embedding flips its argument order for autograd tests
                sample_input = SampleInput(
                    sample_input.args[0].clone(),
                    args=(sample_input.input.clone(),),
                    kwargs=sample_input.kwargs,
                )
                if (
                    "cuda" in device
                    and "max_norm" in sample_input.kwargs
                    and "padding_idx" in sample_input.kwargs
                ):
                    self.skipTest(
                        "embedding is non-deterministic in this case, see issue #74679"
                    )
            batch_size = (
                sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1
            )
            for loss_reduction in ["sum", "mean"]:
                (ew_input, ew_args, ew_kwargs) = make_expanded_weight(
                    sample_input, batch_size, loss_reduction
                )
                expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs)
                normal_result = run_op(
                    op, sample_input.input, *sample_input.args, **sample_input.kwargs
                )
                self.assertEqual(expanded_weight_result, normal_result)
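
    # Forward results must be unaffected by ExpandedWeight: wrapping the parameters
    # only changes how gradients are accumulated, so the op's output should match the
    # unwrapped call for both "sum" and "mean" loss reductions.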

    def test_expanded_weight_error(self, device):
        batch_size = 3
        sample_input = make_tensor(
            (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True
        )
        sample_weight = make_tensor(
            (4), dtype=torch.float32, device=device, requires_grad=True
        )
        with self.assertRaisesRegex(
            RuntimeError, r"Expanded Weights encountered but cannot handle function"
        ):
            torch.add(
                sample_input,
                ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"),
            )

    def _test_embedding_model(self, model, num_embedding, device):
        batch_size = 32
        input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device)
        return self._test_model(
            partial(model, num_embedding=num_embedding), batch_size, input, device
        )

    def _test_conv_model(
        self,
        model,
        input_size,
        num_dim,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        batch_size = 32
        input_ending = [input_size] * num_dim
        input = torch.randn([batch_size, 3] + input_ending, device=device)
        return self._test_model(
            partial(model, num_dim=num_dim),
            batch_size,
            input,
            device,
            loss_reduction,
            atol,
            rtol,
        )

    def _test_model(
        self,
        model,
        batch_size,
        input,
        device,
        loss_reduction="sum",
        atol=1e-4,
        rtol=5e-5,
    ):
        model = model(10).to(device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        criterion = CrossEntropyLoss(reduction=loss_reduction)
        result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input)
        loss = criterion(result, targets)
        loss.backward()
        result = []
        for weight in model.parameters():
            result.append(weight.grad_sample)
            del weight.grad_sample

        expected = []
        for i in range(batch_size):
            loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0))
            expected.append(
                torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))
            )

        expected = [torch.stack(grad) for grad in zip(*expected)]
        for res, exp in zip(result, expected):
            self.assertEqual(res, exp, atol=atol, rtol=rtol)
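
    # Reference gradients in _test_model come from re-running the model on each sample
    # individually; the per-sample grads produced by call_for_per_sample_grads must
    # match them up to (atol, rtol).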

    def _compute_tolerances(self, device):
        is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(
            0
        ) == (8, 6)
        return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5)
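
    # _compute_tolerances returns an (atol, rtol) pair for the conv-model tests below;
    # the looser atol on sm_86 GPUs is presumably an empirical allowance for
    # reduced-precision accumulation there (an assumption; the code only records the
    # numbers).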

    @tf32_off()
    def test_cnn_model_sum(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol)

    @tf32_off()
    def test_cnn_model_mean(self, device):
        def convnet(num_classes, num_dim):
            return nn.Sequential(
                nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(128, num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_instance_norm_model(self, num_dim, device):
        def instance_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            norm_layer = (
                nn.InstanceNorm1d
                if num_dim == 1
                else nn.InstanceNorm2d
                if num_dim == 2
                else nn.InstanceNorm3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                norm_layer(32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_group_norm_model(self, num_dim, device):
        def group_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.GroupNorm(8, 32, affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    @parametrize("num_dim", [1, 2, 3])
    @tf32_off()
    def test_layer_norm_model(self, num_dim, device):
        def layer_norm_model(num_classes, num_dim):
            conv_layer = (
                nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d
            )
            normalized_shape = [7] * num_dim
            return nn.Sequential(
                conv_layer(3, 32, kernel_size=3, stride=1, padding=1),
                nn.LayerNorm(normalized_shape, elementwise_affine=True),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(32 * (7**num_dim), num_classes, bias=True),
            )

        atol, rtol = self._compute_tolerances(device)
        return self._test_conv_model(
            layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol
        )

    def test_embedding_model(self, device):
        def embedding_model(num_classes, num_embedding):
            return nn.Sequential(
                nn.Embedding(num_embedding, 15),
                nn.Flatten(start_dim=1, end_dim=-1),
                nn.Linear(375, num_classes, bias=True),
            )

        return self._test_embedding_model(embedding_model, 16, device)

    def test_group_norm_error(self, device):
        # group norm has to call native_group_norm; check that it hits the same
        # errors that a normal group norm call would
        N = 3
        C = 5
        inp = torch.randn(N, C)
        with self.assertRaisesRegex(
            RuntimeError, r"Expected number of channels in input to be divisible"
        ):
            F.group_norm(inp, 2)  # 5 channels is not divisible by 2 groups


class TestExpandedWeightModule(TestCase):
    def _do_test(
        self,
        module,
        input,
        args=None,
        kwargs=None,
        batch_first=True,
        atol=None,
        rtol=None,
    ):
        args = args or ()
        kwargs = kwargs or {}

        batch_dim = 0 if batch_first else 1
        batch_size = input.shape[batch_dim]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()

        with freeze_rng_state():
            # get per sample grads with the ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module,
                batch_size=batch_size,
                loss_reduction="sum",
                batch_first=batch_first,
            )(input, *args, **kwargs).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop
            expected_res = torch.tensor(
                0.0, device=input.device, dtype=actual_res.dtype
            )
            expected_grads = []
            for i in range(batch_size):
                input_slice = input.narrow(batch_dim, i, 1)
                input_slice = input_slice.squeeze(batch_dim)

                # extra tensor args (e.g. hidden states) carry their batch dim at dim 1
                sliced_args = tree_map_only(
                    torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args
                )
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(
                    input_slice.unsqueeze(batch_dim).contiguous(),
                    *sliced_args,
                    **kwargs,
                ).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
                expected_res += res
            expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
            if not batch_first:
                expected_grads[-1] = expected_grads[-1].transpose(0, 1)
        self.assertEqual(actual_res, expected_res)
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)

    def _do_test_multi_input(self, module, input):
        class TestModule(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, input):
                return self.module(input) + self.module(input)

        batch_size = input.shape[0]
        diff_input = input.dtype == torch.float or input.dtype == torch.double
        if diff_input:
            input.requires_grad_()
        with freeze_rng_state():
            # get per sample grads with the ExpandedWeights context manager,
            # using a module that runs the wrapped module twice
            test_module = TestModule(module)
            actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(
                input
            ).sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                actual_grads.append(param.grad_sample)
                del param.grad_sample
            if diff_input:
                actual_grads.append(input.grad.clone())
                input.grad = torch.zeros_like(input.grad)

            # get per sample grads with a for loop over a single application
            expected_grads = []
            for i in range(batch_size):
                input_slice = input[i]
                diff_params = module.parameters()
                if diff_input:
                    diff_params = chain(diff_params, (input_slice,))
                res = module(input_slice.unsqueeze(0)).sum()
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)
        expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads))
        expected_grads = tuple(
            expected_grad
            for expected_grad in expected_grads
            if expected_grad is not None
        )
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, 2 * expected)
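
    # Because TestModule applies the wrapped module to the same input twice, each
    # per-sample gradient should be exactly twice the single-application gradient,
    # which is what the 2 * expected comparison above checks.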

    def _do_test_rnn_packed_sequence(
        self, module, input, args=None, kwargs=None, atol=None, rtol=None
    ):
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        batch_size = max(tuple(input.batch_sizes)).item()

        with freeze_rng_state():
            # get per sample grads with the ExpandedWeights context manager
            actual_res = call_for_per_sample_grads(
                module, batch_size=batch_size, loss_reduction="sum"
            )(input, *args, **kwargs).data.sum()
            actual_res.backward()
            actual_grads = []
            for param in module.parameters():
                self.assertEqual(param.grad_sample.shape[0], batch_size)
                actual_grads.append(param.grad_sample)
                del param.grad_sample

            input.data.grad = torch.zeros_like(input.data)

            # compute the per sample grads with a for loop over the unpacked sequences
            expected_res = torch.zeros_like(actual_res)
            expected_grads = []
            padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(
                input, batch_first=True
            )
            for i in range(len(seq_sizes)):
                input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
                diff_params = module.parameters()
                batch_dim = 0 if module.m.batch_first else 1
                res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
                expected_res += res
                out_grads = torch.autograd.grad(
                    res, diff_params, torch.ones_like(res), allow_unused=True
                )
                expected_grads.append(out_grads)

        expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
        self.assertEqual(actual_res, expected_res)
        for actual, expected in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
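
    # For PackedSequence inputs the batch size is the largest step size recorded in
    # input.batch_sizes; the reference loop pads the sequence back out and takes
    # gradients for one (possibly shorter) sequence at a time.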

    @modules(
        filter(
            lambda m_info: m_info.module_cls
            in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    @tf32_off()
    def test_module(self, device, dtype, module_info, training):
        class RNNWrapper(torch.nn.Module):
            def __init__(self, m_cons, args, kwargs):
                super().__init__()
                self.m = m_cons(*args, **kwargs)

            def forward(self, *inps):
                ret = self.m(*inps)
                assert isinstance(ret, tuple)
                return ret[0]

        def batch_hidden(h):
            new_h_shape = [1] * (len(h.shape) + 1)
            new_h_shape[1] = 2
            return h.unsqueeze(1).repeat(new_h_shape)

        module_cls = module_info.module_cls
        # GRU in float32 needs looser tolerances; the exact values here are an
        # assumption made while reconstructing this test
        atol, rtol = (
            (1e-4, 1e-5)
            if module_cls == torch.nn.GRU and dtype == torch.float32
            else (None, None)
        )
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
            with_packed_sequence=True,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = RNNWrapper(module_cls, args, kwargs)
            batch_first = m.m.batch_first
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )

            # if the RNN tests use unbatched inputs, batch them (and any hidden state)
            input = args[0]
            if isinstance(input, torch.Tensor) and input.dim() == 2:
                input = input.detach()
                new_input_shape = [1] * (len(input.shape) + 1)
                if batch_first:
                    new_input_shape[0] = 2
                    input = input.repeat(new_input_shape)
                else:
                    new_input_shape[1] = 2
                    input = input.unsqueeze(1).repeat(new_input_shape)

                h = args[1] if len(args) > 1 else None
                if h is not None:
                    h = (
                        batch_hidden(h)
                        if isinstance(h, torch.Tensor)
                        else tuple(batch_hidden(hx) for hx in h)
                    )
                    args = list(args)
                    args[1] = h

            if isinstance(input, torch.nn.utils.rnn.PackedSequence):
                self._do_test_rnn_packed_sequence(
                    m, input, args[1:], kwargs, atol=atol, rtol=rtol
                )
            else:
                self._do_test(
                    m,
                    input,
                    args[1:],
                    kwargs,
                    batch_first=batch_first,
                    atol=atol,
                    rtol=rtol,
                )

    def test_per_sample_api_failing(self):
        module = nn.Linear(10, 10)
        input = torch.randn(64, 10)
        with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"):
            call_for_per_sample_grads("fail")(input)
        with self.assertRaisesRegex(
            RuntimeError, r"Batch size passed must be None or an integer"
        ):
            call_for_per_sample_grads(module, batch_size=6.4)(input)
        with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"):
            call_for_per_sample_grads(module, batch_size=-64)(input)
        with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"):
            loss = call_for_per_sample_grads(module)(input).sum()
            loss.backward()  # populate grad_sample fields
            call_for_per_sample_grads(module)(input)

        module = nn.Linear(10, 10)  # reset to not have grad_sample fields
        with self.assertRaisesRegex(
            RuntimeError, r"Expected loss_reduction argument to be sum or mean"
        ):
            call_for_per_sample_grads(module, loss_reduction="")(input)

    def test_per_sample_api_compute_batch_size(self):
        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1) + self.linear(input2)

        module = CustomModule()
        input1 = torch.randn(4, 5)
        input2 = torch.randn(5, 5)

        with self.assertRaisesRegex(
            RuntimeError,
            "found at least one input with batch size 4 and one with batch size 5",
        ):
            call_for_per_sample_grads(module)(input1, input2)

        input2 = torch.randn(4, 5)
        call_for_per_sample_grads(module)(input1, input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1, input2=input2)

        module = CustomModule()
        call_for_per_sample_grads(module)(input1=input1, input2=input2)

    def test_per_sample_api_compute_batch_size_not_pytreeable(self):
        @dataclass
        class NonPytreeableTuple:
            elem1: torch.Tensor
            elem2: torch.Tensor

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, input1, input2):
                return self.linear(input1.elem1) + self.linear(input1.elem2)

        input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5))
        model = CustomModule()
        with self.assertRaisesRegex(
            RuntimeError,
            "ExpandedWeights cannot compute the batch size from the inputs",
        ):
            call_for_per_sample_grads(model)(input, "")

        # would prefer this to error because the input is not pytree-able, but that's hard to detect
        with self.assertRaisesRegex(
            RuntimeError, "Expected ExpandedWeights to have batch size matching input"
        ):
            call_for_per_sample_grads(model)(input, torch.randn(5))

        model = CustomModule()
        call_for_per_sample_grads(model)(input, torch.randn(4, 5))
        model = CustomModule()
        call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5))


class ContextManagerTests(TestBase):
    def __init__(self, *args, **kwargs):
        self.test_cpu = kwargs.get("test_cpu", True)
        self.test_cuda = kwargs.get("test_cuda", True)
        super().__init__(*args, **kwargs)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    def test_context_manager(self, test_case, device):
        kwargs = {"device": device, "dtype": torch.double}
        module = self.constructor(*self.constructor_args).to(**kwargs)
        if "Embedding" in self.get_name():
            kwargs["dtype"] = torch.long
        input = self._get_input().to(**kwargs)
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test(module, input)

    def test_context_manager_multiple_inputs(self, test_case, device):
        module = self.constructor(*self.constructor_args).to(device)
        input = self._get_input()
        if len(input.shape) == 0 or input.shape[0] == 0:
            raise unittest.SkipTest(
                "Can't get per sample gradients when no batch dim or batch dim is 0"
            )
        if self.constructor == torch.nn.Linear and len(input.shape) == 1:
            raise unittest.SkipTest(
                "Can't get per sample gradients for input of rank 1"
            )
        test_case._do_test_multi_input(module, input)
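
# The loop below converts the legacy module_tests/new_module_tests entries into
# concrete test methods on TestExpandedWeightModule: one CPU variant per entry and,
# when CUDA is available, a double-precision CUDA variant.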


def filter_supported_tests(t):
    supported_modules = [
        "Linear",
        "Conv1d",
        "Conv2d",
        "Conv3d",
        "Embedding",
        "LayerNorm",
        "GroupNorm",
        "InstanceNorm",
    ]
    if "module_name" in t and t["module_name"] in supported_modules:
        return True


# these still use the legacy common_nn test descriptions
supported_tests = [
    t for t in module_tests + new_module_tests if filter_supported_tests(t)
]
for test_param in supported_tests:
    if "constructor" not in test_param:
        name = test_param.pop("module_name")
        test_param["constructor"] = getattr(nn, name)
    decorator = test_param.pop("decorator", lambda test: test)
    test = ContextManagerTests(**test_param)
    test_name = test.get_name()
    if hasattr(TestExpandedWeightModule, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError(
            "Found two tests with the same name: " + test_name_multi_input
        )
    if test.test_cpu:
        setattr(
            TestExpandedWeightModule,
            test_name,
            decorator(lambda self, test=test: test.test_context_manager(self, "cpu")),
        )
        setattr(
            TestExpandedWeightModule,
            test_name_multi_input,
            decorator(
                lambda self, test=test: test.test_context_manager_multiple_inputs(
                    self, "cpu"
                )
            ),
        )
    if TEST_CUDA and test.test_cuda:
        # since this checks derivatives, only use double for precision
        setattr(
            TestExpandedWeightModule,
            test_name + "_cuda_double",
            decorator(lambda self, test=test: test.test_context_manager(self, "cuda")),
        )


def run_op(op, input, *args, **kwargs):
    r"""
    OpInfo for Embedding switches the input and weight so autograd tests will only check the
    derivative of the weight, not the input, which can't be differentiable since its dtype is int.
    Calls op, using the special ordering that Embedding's OpInfo expects for that case.
    """
    if op.name == "nn.functional.embedding":
        return op(args[0], input, **kwargs)
    else:
        return op(input, *args, **kwargs)
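
# Example of the argument swap run_op performs (illustrative, not tied to a specific
# OpInfo sample): for embedding, `input` is the weight matrix and args[0] holds the
# integer indices, so the call becomes op(indices, weight), matching
# F.embedding(indices, weight).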


def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"):
    def expanded_weight_or_clone(arg):
        if is_diff_tensor(arg):
            return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction)
        return clone_if_tensor(arg)

    ew_input = clone_if_tensor(sample_input.input)
    ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args)
    ew_kwargs = {
        name: expanded_weight_or_clone(arg)
        for (name, arg) in sample_input.kwargs.items()
    }
    return ew_input, ew_args, ew_kwargs


def supported_inputs(op, sample_inputs, supported_inputs=True):
    r"""
    ExpandedWeights currently does not support some use cases when there's no batch dimension or
    operations that would cause inter-batch operations. Removes all of the cases it cannot deal with.
    """

    def filter_fn(input):
        convolutions = [
            "nn.functional.conv1d",
            "nn.functional.conv2d",
            "nn.functional.conv3d",
        ]
        batched_input_size = dict(zip(convolutions, [3, 4, 5]))
        if op.name == "nn.functional.linear":
            is_supported_input = (
                input.input.dim() > 1
            )  # input of rank 1 means no batch dim
        elif op.name == "nn.functional.layer_norm":
            normalized_shape = input.args[0]
            is_supported_input = (
                input.input.shape != normalized_shape
            )  # would cause inter-batch operations
        elif op.name in convolutions:
            is_supported_input = input.input.dim() == batched_input_size[op.name]
        elif op.name == "nn.functional.embedding":
            idx = input.args[0]
            is_supported_input = len(idx.shape) > 1  # there's no batch size otherwise
        else:
            is_supported_input = True
        is_supported_input = (
            is_supported_input and input.input.shape[0] > 0
        )  # 0 is not a valid batch size
        return is_supported_input if supported_inputs else not is_supported_input

    return [input for input in sample_inputs if filter_fn(input)]


def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs):
    # get per sample grads by taking the derivative for each input in a for loop
    per_sample_grad = []
    for i in range(batch_size):
        per_sample_input = input[i]
        result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs))
        diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values())
        diff_input_list = [
            i
            for i in diff_input_list
            if isinstance(i, torch.Tensor) and i.requires_grad
        ]
        per_sample_grad.append(
            torch.autograd.grad(
                result, diff_input_list, torch.ones_like(result), allow_unused=True
            )
        )
    if len(per_sample_grad) == batch_size:
        per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad))
    return per_sample_grad
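
# Note: for_loop_per_sample_grad is the brute-force reference the ExpandedWeight
# results are checked against; it runs one backward pass per sample and only stacks
# the grads when every sample produced one.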


def is_diff_tensor(t):
    return isinstance(t, ExpandedWeight) or (
        isinstance(t, torch.Tensor) and t.requires_grad
    )


def clone_if_tensor(t):
    if isinstance(t, torch.Tensor):
        res = torch.clone(t).detach()
        res.requires_grad_(t.requires_grad)
        return res
    else:
        return t


instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals())
instantiate_device_type_tests(TestExpandedWeightFunctional, globals())
instantiate_device_type_tests(TestExpandedWeightModule, globals())
if __name__ == "__main__":
    run_tests()