# Owner(s): ["oncall: jit"]

import functools
import itertools
import os
from pathlib import Path
from typing import Sequence
from unittest import skip

import yaml

import torch
import torch._lazy
import torch._lazy.config
import torch._lazy.ir_cache
import torch._lazy.metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase


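# Register the TorchScript-based lazy tensor backend so that the "lazy" device
# is available to the tests in this file.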
torch._lazy.ts_backend.init()


def get_test_device():
    return "cuda" if "LTC_TS_CUDA" in os.environ else "cpu"


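# Drop overload suffixes (e.g. "add.Tensor" -> "add") so op names from the yaml
# file can be compared directly against names reported by the lazy metric counters.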
def remove_suffixes(l):
    return [x.split(".")[0] for x in l]


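# Parse ts_native_functions.yaml to collect the op names supported by the
# TorchScript lazy backend, together with the skip/fallback lists used below.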
def init_lists():
    path_to_script = Path(os.path.abspath(os.path.dirname(__file__)))
    TS_NATIVE_FUNCTIONS_PATH = (
        path_to_script.parent.parent / "aten/src/ATen/native/ts_native_functions.yaml"
    )
    with open(TS_NATIVE_FUNCTIONS_PATH) as f:
        yaml_ts = yaml.load(f, yaml.SafeLoader)
    LAZY_OPS_LIST = set(
        remove_suffixes(
            itertools.chain(
                yaml_ts["full_codegen"], yaml_ts["supported"], yaml_ts["autograd"]
            )
        )
    )
    HAS_SYMINT_SUFFIX = yaml_ts["symint"]
    FALLBACK_LIST = {"clamp"}
    SKIP_RUNTIME_ERROR_LIST = {
        "index_select",  # Empty output_sizes is not supported
        "clone",  # is clone decomposed?
        # General ASAN failure related to generating bool values.
        # https://github.com/pytorch/pytorch/issues/74519
        # https://github.com/pytorch/pytorch/issues/63034
        "nonzero",  # ASAN failure (paste: P501906539)
        "all",  # ASAN failure
        "any",  # ASAN failure
        "logdet",  # ASAN failure
    }
    SKIP_INCORRECT_RESULTS_LIST = {
        "squeeze",  # Value out of range
        "t",  # Value out of range
        "transpose",  # Value out of range
        "bernoulli",  # incorrect results
        "pow",  # incorrect results
        "addcdiv",  # incorrect results (on CI but not locally?)
    }
    # The following ops all show up directly in ts_native_functions.yaml,
    # but run functionalized versions of the composite kernels in core.
    # This means that we don't expect the ops to show up directly in the LTC metrics.
    FUNCTIONAL_DECOMPOSE_LIST = {
        "diag_embed",
        "block_diag",
        "new_empty_strided",
        "narrow_copy",
        "pixel_shuffle",
        "pixel_unshuffle",
        "select_backward",
        "_trilinear",
        "linalg_inv_ex",
        "linalg_pinv.atol_rtol_tensor",
        "logsumexp",
    }
    # For some ops, we don't support all variants. Here we use formatted_name
    # to uniquely identify the variant.
    SKIP_VARIANT_LIST = {"norm_nuc", "min_reduction_with_dim"}

    return (
        LAZY_OPS_LIST,
        FALLBACK_LIST,
        SKIP_RUNTIME_ERROR_LIST,
        SKIP_INCORRECT_RESULTS_LIST,
        FUNCTIONAL_DECOMPOSE_LIST,
        HAS_SYMINT_SUFFIX,
        SKIP_VARIANT_LIST,
    )


(
    LAZY_OPS_LIST,
    FALLBACK_LIST,
    SKIP_RUNTIME_ERROR_LIST,
    SKIP_INCORRECT_RESULTS_LIST,
    FUNCTIONAL_DECOMPOSE_LIST,
    HAS_SYMINT_SUFFIX,
    SKIP_VARIANT_LIST,
) = init_lists()

torch.manual_seed(42)


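# Return a detached copy of `t` on the "lazy" device with requires_grad enabled,
# so lazy results and gradients can be compared against eager mode.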
def clone_move(t):
    dev = "lazy"
    copy_t = t.detach().clone().requires_grad_(True).to(device=dev)
    return copy_t


class TestLazyTensor(JitTestCase):
    @skip("Disable until autograd supports symints")
    def testConvolutionBackward(self):
        test_device = get_test_device()
        inp = torch.rand(1, 3, 128, 128, device=test_device, requires_grad=True)
        inp_copy = clone_move(inp)
        grad = torch.rand(1, 32, 121, 121, device=test_device)  # no requires_grad
        grad_copy = clone_move(grad)
        weight = torch.rand(32, 3, 8, 8, device=test_device, requires_grad=True)
        weight_copy = clone_move(weight)
        bias = torch.rand(32, device=test_device, requires_grad=True)
        bias_copy = clone_move(bias)

        # run eager
        conv_out = torch.nn.functional.conv2d(inp, weight, bias)
        (inp_grad, weight_grad, bias_grad) = torch.autograd.grad(
            [conv_out], [inp, weight, bias], [grad]
        )

        # run lazy
        conv_copy_out = torch.nn.functional.conv2d(inp_copy, weight_copy, bias_copy)
        (inp_copy_grad, weight_copy_grad, bias_copy_grad) = torch.autograd.grad(
            [conv_copy_out], [inp_copy, weight_copy, bias_copy], [grad_copy]
        )

        # check numerics
        torch.testing.assert_close(bias_copy_grad.cpu(), bias_grad.cpu())

        torch.testing.assert_close(weight_copy_grad.cpu(), weight_grad.cpu())
        torch.testing.assert_close(inp_copy_grad.cpu(), inp_grad.cpu())

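    # A mark_step cuts the lazy trace and materializes pending computation; a view
    # created before the step should stay aliased with its base tensor afterwards.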
    def test_view_mark_step_preserved(self):
        test_device = get_test_device()
        inp = torch.rand(4, device=test_device)
        inp_lazy = clone_move(inp)

        def foo(x, *, mark_step):
            y = x.view(2, 2)
            y.add_(1)
            z = x + x

            if mark_step:
                torch._lazy.mark_step()

            # y and x should continue to be aliased after the mark_step call.
            y.add_(1)
            return x

        out_ref = foo(inp, mark_step=False)
        out = foo(inp_lazy, mark_step=True)
        # out will have some pending mutations, which will be synced by the .cpu() call.
        torch.testing.assert_close(out_ref.cpu(), out.cpu())

    def test_tensor_ctr(self):
        test_device = get_test_device()
        inp = torch.tensor([[1, 2, 3, 4, 5]], device=test_device)
        inp_lazy = torch.tensor([[1, 2, 3, 4, 5]], device="lazy")

        def foo(x):
            # Calling a view op to ensure that functionalization wrapping occurs.
            return x.view(-1)

        out_ref = foo(inp)
        out = foo(inp_lazy)
        torch.testing.assert_close(out_ref.cpu(), out.cpu())


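# OpInfo-driven tests: run each op's sample inputs on the "lazy" device, first
# checking that the op is dispatched to the lazy backend and then checking
# numerical correctness against eager mode.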
class TestLazyOpInfo(TestCase):
    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST
            and op.name not in FUNCTIONAL_DECOMPOSE_LIST
            and op.formatted_name not in SKIP_VARIANT_LIST
        ],
        allowed_dtypes=(torch.float,),
    )
    def test_dispatched_to_lazy(self, device, dtype, op):
        def get_name(op):
            l = [op.name]
            if op.variant_test_name != "":
                l.append(op.variant_test_name)
            return ".".join(l)

        global HAS_SYMINT_SUFFIX, FALLBACK_LIST
        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        sample = next(iter(samples))
        args = [sample.input] + list(sample.args)
        kwargs = sample.kwargs
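        # Flush and wait on any pending lazy work, then clear the metric
        # counters so they only reflect the op under test.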
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        torch._lazy.metrics.reset()

        r = op(*args, **kwargs)
        torch._lazy.mark_step()
        torch._lazy.wait_device_ops()
        prefix = "aten" if op.name in FALLBACK_LIST else "lazy"
        symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else ""
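        # Ops lowered by the lazy backend record a "lazy::<op>" counter, while
        # fallback ops show up under "aten::<op>", so the counter names tell us
        # whether dispatch actually reached the lazy backend.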
        found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes(
            torch._lazy.metrics.counter_names()
        )
        # check aliases
        if not found:
            for alias in op.aliases:
                alias_found = (
                    f"{prefix}::{alias.name}{symint_suffix}"
                    in remove_suffixes(torch._lazy.metrics.counter_names())
                )
                found = found or alias_found
                if found:
                    break
        self.assertTrue(found)

    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
        ],
        allowed_dtypes=(torch.float,),
    )  # noqa: B950
    def test_correctness(self, device, dtype, op):
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(
                    torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
                )

            if isinstance(a, Sequence):
                # map() is lazy; iterate explicitly so the recursive checks run.
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
            # Need to run mark step so that all random ops are computed in the right order
            torch._lazy.mark_step()

            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)

            torch._lazy.mark_step()
            assert_allclose_rec((r_actual, r_exp))

    @ops(
        [
            op
            for op in op_db
            if op.name in LAZY_OPS_LIST
            and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST
        ],
        allowed_dtypes=(torch.float,),
    )  # noqa: B950
    def test_correctness_with_reusing_ir(self, device, dtype, op):
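        # Enable IR node reuse for this test so repeated traces can hit the IR
        # cache; the cache is cleared and the flag restored at the end.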
        torch._lazy.config.set_reuse_ir(True)
        test_device = get_test_device()

        def clone_to_device(input, dev):
            if isinstance(input, torch.Tensor):
                return input.detach().clone().to(device=dev)
            if isinstance(input, Sequence) and not isinstance(input, str):
                return tuple(map(functools.partial(clone_to_device, dev=dev), input))
            return input

        def assert_allclose_rec(t):
            a, b = t
            self.assertEqual(type(a), type(b))
            if isinstance(a, torch.Tensor):
                self.assertTrue(
                    torch.allclose(clone_to_device(a, test_device), b, atol=1e-4)
                )

            if isinstance(a, Sequence):
                # map() is lazy; iterate explicitly so the recursive checks run.
                for pair in zip(a, b):
                    assert_allclose_rec(pair)

        samples = op.sample_inputs("lazy", dtype, requires_grad=False)
        for sample in samples:
            # Need to run mark step so that all random ops are computed in the right order
            torch._lazy.mark_step()

            args = [sample.input] + list(sample.args)
            kwargs = sample.kwargs
            copy_args = clone_to_device(args, test_device)

            r_exp = op(*copy_args, **kwargs)
            r_actual = op(*args, **kwargs)

            torch._lazy.mark_step()
            assert_allclose_rec((r_actual, r_exp))

        torch._lazy.ir_cache.reset()
        torch._lazy.config.set_reuse_ir(False)


# TODO: after we move to master, add Lazy as a new Device here:
# https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L532
instantiate_device_type_tests(TestLazyOpInfo, globals(), only_for="cpu")


class TestLazyDynamicOps(TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Set up the dynamic shape mode
        cls.old_ssa_mode = torch._C._lazy._get_symbolic_shape_mode()
        torch._C._lazy._set_symbolic_shape_mode(True)
        return super().setUpClass()

    @classmethod
    def tearDownClass(cls) -> None:
        torch._C._lazy._set_symbolic_shape_mode(cls.old_ssa_mode)
        return super().tearDownClass()

    def test_nonzero_dynamic(self):
        # Test that nonzero gives upper-bound sizes when symbolic shape mode is enabled
        test_device = get_test_device()
        x1 = torch.tensor(
            [[0, 1.0, 2.0], [3.0, 0, 0]], device=test_device, requires_grad=True
        )
        x1_lazy = clone_move(x1)
        x2_lazy = torch.nonzero(x1_lazy)

        # FIXME: Add bindings to get upper bounds
        # self.assertEqual(tuple(x2_lazy.size()), (6, 2))

        # We should still be able to instantiate it and get the actual result
        x2_eager = x2_lazy.cpu()
        self.assertEqual(tuple(x2_eager.size()), (3, 2))

    def test_adaptiveavgpool3d_dynamic(self):
        # Test that adaptive_avg_pool3d gives correct shapes with the lazy backend
        img_cpu = torch.zeros([2, 3, 4, 5, 6], device="cpu")
        out_cpu = torch.nn.AdaptiveAvgPool3d(2).to(device="cpu")(img_cpu)

        test_device = get_test_device()
        img_lazy = torch.zeros([2, 3, 4, 5, 6], device=test_device)
        out_lazy = torch.nn.AdaptiveAvgPool3d(2).to(test_device)(img_lazy)

        self.assertEqual(out_cpu.shape, out_lazy.shape)


if __name__ == "__main__":
    run_tests()
