# Owner(s): ["module: unknown"]

import collections
import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfTorchDynamo
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode


class TestAutocastCPU(TestCase):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    def _run_autocast_outofplace(
        self,
        op,
        args,
        run_as_type,
        out_type=None,
        module=torch,
        add_kwargs=None,
        amp_dtype=torch.bfloat16,
    ):
        # Cast floating-point tensors (and iterables of them) to the target dtype.
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_cpu_enabled())
        with torch.cpu.amp.autocast(dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_cpu_enabled())
            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(out_type == output.dtype,
                                    f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}")
            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(out_type == output_method.dtype,
                                    f"autocast for Tensor.{op} produced {output_method.dtype}, should produce {out_type}")

            self.assertTrue((output is not None) or (output_method is not None),
                            f"{op} not found as an attribute on either Tensor or the requested module {module}")

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(comparison, f"torch.{op} result did not match Tensor.{op} result")

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.cpu.amp.autocast(enabled=False):
                self.assertFalse(torch.is_autocast_cpu_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
                else:
                    control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_cpu_enabled())
        self.assertFalse(torch.is_autocast_cpu_enabled())
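
    # For reference, a call such as (arguments here are purely illustrative and
    # not taken from the real autocast test lists):
    #     self._run_autocast_outofplace("mm", (torch.randn(3, 3), torch.randn(3, 3)), torch.bfloat16)
    # checks that torch.mm under CPU autocast returns bfloat16, that the torch.mm
    # and Tensor.mm variants agree, and that the result bitwise matches an
    # explicit bfloat16 control run performed outside autocast.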

    def args_maybe_kwargs(self, op_with_args):
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        else:
            return op_with_args[0], op_with_args[1], op_with_args[2]
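
    # Entries in the autocast test lists come in two shapes, e.g. (hypothetical
    # values, for illustration only):
    #     ("mm", (torch.randn(3, 3), torch.randn(3, 3)))            # (op, args)
    #     ("softmax", (torch.randn(3, 3),), {"dim": 0})             # (op, args, kwargs)
    # args_maybe_kwargs normalizes both shapes to (op, args, kwargs).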

    def test_autocast_torch_expect_builtin_promote(self):
        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)

    def test_autocast_methods_expect_builtin_promote(self):
        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)

    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32)
            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_autocast_rnn(self):
        if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported():
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raise ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # Should be able to run the below case with autocast
            with torch.cpu.amp.autocast():
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
        # Entering autocast with enabled=False and an fp32 dtype should be a no-op.
        with torch.autocast(device_type='cpu', dtype=torch.float32, enabled=False):
            _ = torch.ones(10)


class CustomLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type='cuda'):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW


class WeightDTypeCastCounterMode(TorchDispatchMode):

    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default and
            args[0] is self.weight and
            kwargs['dtype'] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)
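

# Usage sketch for WeightDTypeCastCounterMode (illustrative only; the tests below
# exercise the same pattern). The mode counts every aten._to_copy of `weight` to
# float16, and patches torch.clear_autocast_cache to a no-op so the autocast
# cast-cache is not cleared when the autocast context exits:
#
#     weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
#     with WeightDTypeCastCounterMode(weight) as mode:
#         with torch.autocast(device_type='cuda'):
#             out = torch.nn.functional.linear(torch.randn(2, 3).cuda(), weight)
#     # mode.dtype_cast_counter now records how many times `weight` was cast.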


@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type='cuda'):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

    def test_cache_disabled(self):
        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type='cuda'):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # The weight is registered as a cached tensor, so we should not have
            # cached its conversion: it gets cast once in the forward and once
            # more in the backward's autocast region.
            self.assertEqual(mode.dtype_cast_counter, 2)
        finally:
            torch._C._set_cached_tensors_enabled(False)


class TestTorchAutocast(TestCase):
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)

    def test_invalid_device(self):
        dev = 'not a real device'
        msg = f'unsupported autocast device_type \'{dev}\''
        with self.assertRaisesRegex(RuntimeError, msg):
            with torch.autocast(device_type=dev):
                _ = torch.ones(10)


if __name__ == '__main__':
    run_tests()