pytorch

Форк
0
/
test_jit_autocast.py 
959 строк · 36.2 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import torch
4
from torch.cuda.amp import autocast
5
from typing import Optional, Tuple
6

7
import unittest
8
from test_jit import JitTestCase
9
from torch.testing._internal.common_cuda import TEST_CUDA
10
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
11
from torch.testing import FileCheck
12
from jit.test_models import MnistNet
13

14
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
15

16
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
    """TorchScript tests for ``autocast`` regions inside and around scripted code.

    Every test runs with the JIT autocast mode switched on (see ``setUp``);
    ``tearDown`` restores whatever mode was active before the test.
    """

    def setUp(self):
        # common input tensors (only allocated when a CUDA device is present;
        # every test that reads them is skipped when CUDA is unavailable)
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        # enable JIT autocast; keep the returned value so tearDown can restore it
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        # restore the JIT autocast mode saved in setUp
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_generic_autocast(self):
        # torch.cuda.amp.autocast() and torch.amp.autocast(device_type='cuda')
        # are expected to behave identically inside TorchScript
        @torch.jit.script
        def fn_cuda_autocast(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y

        @torch.jit.script
        def fn_generic_autocast(a, b):
            with torch.amp.autocast(device_type='cuda'):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
                return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
            g = torch.mm(c.double(), f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
            return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    #   (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsuported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                    f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        # alternate autocast on/off across repeated calls of the same function
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                    return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
                return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        # out= and in-place variants keep the dtype of the tensor they write to
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

    def _test_autocast(self, func, cast_op, *args):
        """Check that scripted ``func`` matches eager ``func`` output dtypes.

        Args:
            func: function or ``nn.Module`` to run eagerly and to script.
            cast_op: graph node kind (e.g.
                ``"aten::_autocast_to_reduced_precision"``) that must appear
                in the optimized graph, or ``None`` to skip the graph check.
            *args: inputs passed unchanged to both eager and scripted calls.
        """
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        # A single-tensor result would otherwise be zipped row by row (and a
        # 0-d result cannot be iterated at all), so compare it as one output.
        if isinstance(o, torch.Tensor):
            o = (o,)
        if isinstance(jit_o, torch.Tensor):
            jit_o = (jit_o,)
        # zip() silently truncates; make sure no output goes unchecked.
        self.assertEqual(len(o), len(jit_o))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "we need to provide dtype argument at this moment")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        # _test_autocast scripts `t` itself, so no separate script call is needed
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        # compile once outside any autocast context; the checks below then run
        # `t` under different eager autocast states (presumably reusing this
        # compilation via TorchScript's function cache — confirm if relevant)
        torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for _ in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        # reset gradients accumulated by the warm-up runs before comparing
        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                    return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        # first call warms up the executor; the second runs the optimized graph
        fn_s(x)
        y = fn_s(x)

        self.assertEqual(y.dtype, torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        # first call warms up the executor; the second runs the optimized graph
        fn_s(x)
        y = fn_s(x)

        self.assertEqual(y.dtype, torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        # with AMP ignored, no autocast cast node may appear in the optimized graph
        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)
768

769
class convbn(torch.nn.Module):
    """Conv2d(3 -> 64 channels, kernel 7, stride 2) followed by BatchNorm2d(64).

    Args:
        bias_enabled: whether the convolution carries a learnable bias.
    """

    def __init__(self, bias_enabled=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=7,
            stride=2,
            bias=bias_enabled,
        )
        self.bn = torch.nn.BatchNorm2d(num_features=64)

    def forward(self, x):
        features = self.conv(x)
        return self.bn(features)

@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
    """Tests for models traced with ``torch.jit.trace`` under CPU autocast,
    plus scripted ``is_autocast_*_enabled`` queries.

    ``setUp`` turns the JIT autocast mode off and pins the default dtype to
    float32; ``tearDown`` restores both previous settings.
    """

    def setUp(self):
        super().setUp()
        # Pin the default dtype for each test; restored in tearDown.
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        # self.models[i] is exercised with self.inputs[i].
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        # Turn the JIT autocast mode off, remembering the previous setting.
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super().tearDown()

    def test_generate_autocast_jit_trace_model(self):
        """Every model can be traced under CPU autocast and then frozen."""
        def trace_and_freeze(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            torch.jit.freeze(traced_model)

        for model, x in zip(self.models, self.inputs):
            trace_and_freeze(model, x)

    def test_nchw_autocast_jit_trace_model(self):
        """A frozen model traced under autocast matches eager autocast (NCHW inputs)."""
        def check_nchw(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            # The traced graph is run outside autocast; eager runs inside it.
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)

        for model, x in zip(self.models, self.inputs):
            check_nchw(model, x)

    def test_nhwc_autocast_jit_trace_model(self):
        """Same as the NCHW test, with model and inputs in channels_last format."""
        def check_nhwc(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)

        for model, x in zip(self.models, self.inputs):
            if x.dim() == 5:
                # NHWC 3D case not supported yet
                continue
            check_nhwc(model, x)

    def test_cat_promote(self):
        """cat of float32 and bfloat16 tensors yields float32, eager and traced.

        The traced graph (frozen or not) must contain an explicit ``aten::to``.
        """
        class TestModel(torch.nn.Module):
            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
            # To avoid the fusion group from TE, we will disable the fuser here.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                # Several runs, presumably so the executor optimizes the graph
                # before graph_for() inspects it below.
                for _ in range(3):
                    c2 = traced(a, b)
                # assertTrue(x, y) treats y as the failure message and passes for
                # any truthy x; assertEqual actually compares the dtypes.
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        """Scripted is_autocast_cpu_enabled() branches the same way as eager."""
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_cpu_enabled" in node.kind() for node in fn_s.graph.nodes()))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        """Scripted is_autocast_enabled() branches the same way as eager."""
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any("is_autocast_enabled" in node.kind() for node in fn_s.graph.nodes()))

    def test_scripted_aliasing(self):
        """torch.is_autocast_enabled should not be able to move inside of the autocast context."""
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        # Moving the query after prim::Enter must be rejected by alias analysis.
        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        """Nested enable/disable contexts are reflected both in the flag and in mm's dtype."""
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            v1, b1, v2, b2, v3, b3 = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    # Run this file's test cases through PyTorch's shared test runner
    # (run_tests from torch.testing._internal.common_utils) when executed directly.
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.