# Owner(s): ["module: unknown"]

import collections
import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfTorchDynamo
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode

class TestAutocastCPU(TestCase):
    def setUp(self):
        super().setUp()
        self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))

    def tearDown(self):
        del self.autocast_lists
        super().tearDown()

    def _run_autocast_outofplace(
        self,
        op,
        args,
        run_as_type,
        out_type=None,
        module=torch,
        add_kwargs=None,
        amp_dtype=torch.bfloat16,
    ):
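        """Run `op` out-of-place under autocast and check its output dtype.

        Tries both the `module.op(*args)` and `Tensor.op(*args[1:])` variants when
        available, asserts each result's dtype matches `out_type` (defaulting to
        `run_as_type`), checks the two variants agree, and compares the result
        against a control computed with autocast disabled on manually cast inputs.
        """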
        # helper to cast args
        def cast(val, to_type):
            if isinstance(val, torch.Tensor):
                return val.to(to_type) if val.is_floating_point() else val
            elif isinstance(val, collections.abc.Iterable):
                return type(val)(cast(v, to_type) for v in val)
            else:
                return val

        if add_kwargs is None:
            add_kwargs = {}

        self.assertFalse(torch.is_autocast_cpu_enabled())
        with torch.cpu.amp.autocast(dtype=amp_dtype):
            self.assertTrue(torch.is_autocast_cpu_enabled())
            out_type = out_type if out_type is not None else run_as_type
            output = output_method = None

            # Try module.* variant, if requested:
            if module is not None and hasattr(module, op):
                output = getattr(module, op)(*args, **add_kwargs)
                if isinstance(output, torch.Tensor):
                    self.assertTrue(out_type == output.dtype,
                                    f"autocast for torch.{op} produced {output.dtype}, should produce {out_type}")
            # Try Tensor.* variant:
            if hasattr(torch.Tensor, op):
                output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
                if isinstance(output_method, torch.Tensor):
                    self.assertTrue(out_type == output_method.dtype,
                                    f"autocast for torch.{op} produced {output_method.dtype}, should produce {out_type}")

            self.assertTrue((output is not None) or (output_method is not None),
                            f"{op} not found as an attribute on either Tensor or the requested module {module}")

            # Accounts for ops that return Tensors, iterables, and other non-Tensors.
            # For example, lstm_cell returns a tuple and equal returns bool.
            def compare(first, second):
                if isinstance(first, torch.Tensor):
                    return torch.equal(first, second)
                elif isinstance(first, collections.abc.Iterable):
                    return all(compare(f, s) for f, s in zip(first, second))
                else:
                    return first == second

            # If both torch.* and Tensor.* variants were found, check outputs are identical
            if (output is not None) and (output_method is not None):
                self.assertTrue(type(output) == type(output_method))
                comparison = compare(output, output_method)
                self.assertTrue(comparison, f"torch.{op} result did not match Tensor.{op} result")

            # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
            # as the C++-side autocasting, and should be bitwise accurate.
            output_to_compare = output if output is not None else output_method
            with torch.cpu.amp.autocast(enabled=False):
                self.assertFalse(torch.is_autocast_cpu_enabled())

                if module is not None and hasattr(module, op):
                    control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
                else:
                    control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
                self.assertTrue(type(output_to_compare) == type(control))
                comparison = compare(output_to_compare, control)
                self.assertTrue(comparison, f"torch.{op} result did not match control")
            self.assertTrue(torch.is_autocast_cpu_enabled())
        self.assertFalse(torch.is_autocast_cpu_enabled())

    def args_maybe_kwargs(self, op_with_args):
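        # op_with_args is either (op, args) or (op, args, kwargs).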
        if len(op_with_args) == 2:
            return op_with_args[0], op_with_args[1], {}
        else:
            return op_with_args[0], op_with_args[1], op_with_args[2]

    @skipIfTorchDynamo()
    def test_autocast_torch_expect_builtin_promote(self):
        for op, args1, args2, out_type in self.autocast_lists.torch_expect_builtin_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, out_type=out_type)
            self._run_autocast_outofplace(op, args2, torch.float32, out_type=out_type, amp_dtype=torch.float16)

    @skipIfTorchDynamo()
    def test_autocast_methods_expect_builtin_promote(self):
        for op, args1, args2, out_type in self.autocast_lists.methods_expect_builtin_promote:
            self._run_autocast_outofplace(op, args1, torch.float32, module=None, out_type=out_type)
            self._run_autocast_outofplace(op, args2, torch.float32, module=None, out_type=out_type, amp_dtype=torch.float16)

    @skipIfTorchDynamo()
    def test_autocast_torch_16(self):
        for op_with_args in self.autocast_lists.torch_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
            self._run_autocast_outofplace(op, args, torch.float16, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

    @skipIfTorchDynamo()
    def test_autocast_nn_16(self):
        for op_with_args in self.autocast_lists.nn_16:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float16,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_fp32(self):
        for op_with_args in self.autocast_lists.torch_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
            self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs, amp_dtype=torch.float16)

    @skipIfTorchDynamo()
    def test_autocast_nn_fp32(self):
        for op_with_args in self.autocast_lists.nn_fp32:
            op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
            self._run_autocast_outofplace(
                op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs
            )
            self._run_autocast_outofplace(
                op,
                args,
                torch.float32,
                module=torch._C._nn,
                add_kwargs=maybe_kwargs,
                amp_dtype=torch.float16,
            )

    @skipIfTorchDynamo()
    def test_autocast_torch_need_autocast_promote(self):
        for op, args1, args2 in self.autocast_lists.torch_need_autocast_promote:
            self._run_autocast_outofplace(op, args1, torch.float32)
            self._run_autocast_outofplace(op, args2, torch.float32, amp_dtype=torch.float16)

    @unittest.skipIf(IS_WINDOWS, "Limit support for bf16 path")
    def test_autocast_rnn(self):
        if torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported():
            x = torch.randn(1, 2, 1)
            hx = torch.randn(2, 2, 1)
            cx = torch.randn(2, 2, 1)

            m = torch.nn.LSTM(1, 1, 2).to(torch.bfloat16)

            # Raises a ValueError when autocast is not enabled
            with self.assertRaisesRegex(ValueError, "input must have the type"):
                m(x, (hx, cx))

            # Should be able to run the below case with autocast
            with torch.cpu.amp.autocast():
                m(x, (hx, cx))

    def test_autocast_disabled_with_fp32_dtype(self):
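        # Entering a disabled autocast context with dtype=torch.float32 should not raise.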
        with torch.autocast(device_type='cpu', dtype=torch.float32, enabled=False):
            _ = torch.ones(10)

class CustomLinear(torch.autograd.Function):
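    """Custom linear autograd Function that re-enters autocast in backward.

    Used by the cache tests below: backward runs its matmuls under an explicit
    torch.autocast block, so the weight is cast to float16 again there unless
    the cast from the forward pass was cached.
    """
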
    @staticmethod
    def forward(ctx, x, w_t):
        ctx.save_for_backward(x, w_t)
        return torch.nn.functional.linear(x, w_t)

    @staticmethod
    def backward(ctx, grad_output):
        x, w_t = ctx.saved_tensors
        with torch.autocast(device_type='cuda'):
            dL_dX = torch.matmul(grad_output, w_t)
            dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
        return dL_dX, dL_dW

class WeightDTypeCastCounterMode(TorchDispatchMode):
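    """TorchDispatchMode that counts float16 casts of a given weight.

    Increments `dtype_cast_counter` whenever aten._to_copy is called on `weight`
    with dtype=torch.float16, and mocks out torch.clear_autocast_cache while the
    mode is active so the autocast cache survives across forward and backward.
    """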

    def __init__(self, weight):
        super().__init__()
        self.dtype_cast_counter = 0
        self.weight = weight

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if (
            func is torch.ops.aten._to_copy.default and
            args[0] is self.weight and
            kwargs['dtype'] is torch.float16
        ):
            self.dtype_cast_counter += 1
        return func(*args, **kwargs)

    def __enter__(self):
        self.old_clear_cache = torch.clear_autocast_cache
        torch.clear_autocast_cache = lambda: None
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.clear_autocast_cache = self.old_clear_cache
        return super().__exit__(exc_type, exc_val, exc_tb)

@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        with WeightDTypeCastCounterMode(weight) as mode:
            with torch.autocast(device_type='cuda'):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            s.backward()

        self.assertEqual(mode.dtype_cast_counter, 1)

    def test_cache_disabled(self):
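        # The weight is registered as a cached tensor below, which should keep
        # autocast from caching its float16 copy, so the cast is expected to
        # happen in both the forward and the backward pass (counter == 2).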

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())

        try:
            torch._C._set_cached_tensors_enabled(True)
            torch._C._add_cached_tensor(weight)

            with WeightDTypeCastCounterMode(weight) as mode:
                with torch.autocast(device_type='cuda'):
                    output = CustomLinear.apply(data, weight)
                    s = output.sum()
                s.backward()

            # we should not have cached the conversion of the weight
            self.assertEqual(mode.dtype_cast_counter, 2)

        finally:
            torch._C._set_cached_tensors_enabled(False)

class TestTorchAutocast(TestCase):
271
    def test_autocast_fast_dtype(self):
272
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
273
        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
274
        self.assertEqual(gpu_fast_dtype, torch.half)
275
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)
276

277
    def test_invalid_device(self):
278
        dev = 'not a real device'
279
        msg = f'unsupported autocast device_type \'{dev}\''
280
        with self.assertRaisesRegex(RuntimeError, msg):
281
            with torch.autocast(device_type=dev):
282
                _ = torch.tensor(1)
283

284

285
if __name__ == '__main__':
286
    run_tests()
287
