# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing import FileCheck
from jit.test_models import MnistNet

TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
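
# These tests exercise TorchScript's handling of autocast (automatic mixed
# precision): scripted and traced functions run under torch.autocast /
# torch.cuda.amp.autocast / torch.cpu.amp.autocast, the per-op casting
# policies, and the interaction with freezing and the JIT executor.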


@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    def setUp(self):
        # common input tensors
        self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
        self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
        self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
        self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
        self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        def fn_cuda_autocast(a, b):

        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):

        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        with autocast(dtype=torch.bfloat16):

        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        with autocast(enabled=False):

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        e = torch.mm(a.double(), b.double()).float()
        f = torch.mm(c, d).double()
        g = torch.mm(c.double(), f)
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        with autocast(enabled=True):

        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        with autocast(enabled=True):

        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        f = torch.addcmul(e, c, d, value=0.1)
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        with autocast(enabled=True):
            return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
                return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        f = torch.mm(d, e) * x
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        f = torch.mm(a, b).float()
        e = torch.mm(c, d).float()
        return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        autocast_on = autocast(enabled=True)
        autocast_off = autocast(enabled=False)
        return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        with autocast(enabled=False):
            with autocast(enabled=True):
                with autocast(enabled=False):

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        with autocast(enabled=False), autocast(enabled=True):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        autocast_instance = autocast(enabled=True)
        with autocast_instance:
            with autocast_instance:

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    # (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        with autocast(enabled=True) as autocast_instance:
            with autocast_instance:

        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        with autocast(enabled=True):
            tmp = helper(tmp, tmp)
            tmp = helper(tmp, tmp)
            tmp = helper(tmp, tmp)
            return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        with autocast(enabled=False):

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        with autocast(enabled=True):

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        return torch.mm(a, b)
        use_autocast = (i % 2 == 0)
        expected_dtype = torch.float16 if use_autocast else torch.float32
        with autocast(enabled=use_autocast):
            result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        with autocast(enabled=True):

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        with autocast(enabled=True):

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        return torch.mm(a, b)
        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        return torch.mm(a, b)
        with autocast(enabled=True):
            traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # (this is equivalent to running scripted functions inside autocast)
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        with autocast(enabled=True):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, out=a)

        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)
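
    # Shared helper: scripts `func`, runs it both eagerly and via the JIT,
    # optionally checks that `cast_op` shows up in the optimized JIT graph,
    # and compares the output dtypes of the two runs.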
    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):
        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "we need to provide dtype argument at this moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):
        def t_autocast_cpu(x, y):
            # omitting the dtype argument is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # omitting the dtype argument is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):
        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):
        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)
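
    # The autodiff test below runs a scripted function under CUDA autocast,
    # takes gradients, and checks that both the forward output and the
    # gradients (values and dtypes) match an eager-mode reference.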
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        with torch.autocast("cuda", torch.float16):
            jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(torch.nn.Module):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting is off
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        y = fn_s(x)
        self.assertTrue(y.dtype == torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        y = fn_s(x)
        self.assertTrue(y.dtype == torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)


class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))
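

# TestJitTraceAutocast covers trace-based (torch.jit.trace) autocast on CPU:
# MnistNet and the small conv+bn module above are traced and frozen under
# torch.cpu.amp.autocast, and their outputs are compared against eager runs.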
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)

    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
                traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
                traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
                traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # NHWC 3D case is not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this test case we check whether cat has done the promotion in AMP with mixed dtype inputs.
            # To avoid TE fusion groups, the fuser is disabled here.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))
        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))
        self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))
        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))
        self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()))

    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():

            with torch.cuda.amp.autocast(enabled=True):

        fn_s = torch.jit.script(fn)

        graph = fn_s.graph
        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=False):
                b3 = torch.is_autocast_cpu_enabled()
                v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()