pytorch

Форк
0
/
test_schema_check.py 
510 строк · 21.3 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
import torch
6
from torch.utils._pytree import tree_map
7
import unittest
8

9
from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO
10
from torch.fx.operator_schemas import normalize_function
11
from torch._subclasses.schema_check_mode import SchemaCheckMode
12
from torch.utils._python_dispatch import TorchDispatchMode
13
from torch.testing._internal.common_methods_invocations import op_db
14
from torch.testing._internal.jit_utils import JitTestCase
15
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
16
# Make the directory one level above this file's directory importable.
_this_file = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(pytorch_test_dir)
18

19
def secretly_aliasing(x):
    """Return a flattened view of ``x`` (the result shares storage with ``x``)."""
    flat_view = x.view(-1)
    return flat_view
21

22
def secretly_mutating(x):
    """Double ``x`` in place, then return a new tensor equal to three times the doubled ``x``."""
    x.mul_(2)  # in-place: the caller's tensor is modified
    tripled = x * 3
    return tripled
25

26
def output_is_input(x):
    """Return the argument object itself, unchanged and uncopied."""
    return x
28

29
# Operator declarations for the "bad_schemas" namespace. The first two schemas
# carry no alias/mutation annotations even though their implementations return
# a view and mutate the input respectively; the third annotates input and
# output with the same alias set while the implementation returns the input
# object directly.
custom_lib = torch.library.Library("bad_schemas", "DEF")  # noqa: TOR901
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")

custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")  # noqa: TOR901
custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")  # noqa: TOR901

# The same Python implementations back both the CPU and the Meta kernels.
_bad_schema_impls = (
    ("secretly_aliasing", secretly_aliasing),
    ("secretly_mutating", secretly_mutating),
    ("output_is_input", output_is_input),
)
for _lib in (custom_lib_cpu, custom_lib_meta):
    for _op_name, _op_impl in _bad_schema_impls:
        _lib.impl(_op_name, _op_impl)
43

44
# This TorchDispatchTensor Subclass is used to simulate an incorrect schema
45
# which is then used to test that SchemaCheckMode behaves as expected
46

47
class IncorrectAliasTensor(torch.Tensor):
    """Wrapper tensor subclass that deliberately misbehaves for a few ops.

    The real data lives in ``elem``. ``__torch_dispatch__`` runs every op on
    the unwrapped tensors and then, for the op names listed in the class-level
    sets below, tampers with the first argument or with the outputs so that
    the observed behaviour no longer matches what the op would normally do.
    """

    # Schema names after which the first argument's ``elem`` is rebound to the op's output.
    ALIAS_ARG_OUT = {"aten::add"}
    # Schema names whose first output is replaced by the second output.
    ALIAS_OUT_OUT = {"aten::aminmax"}
    # Schema names after which the first argument's ``elem`` is replaced with random data.
    MUTATE_ARGS_OUT = {"aten::sub"}

    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the unwrapped arguments, then apply the configured tampering.

        Args:
            func: the operator overload being dispatched; its ``_schema.name``
                selects which tampering (if any) is applied.
            types: unused here.
            args: positional arguments, possibly containing wrapped tensors.
            kwargs: keyword arguments; ``None`` is treated as no keyword arguments.

        Returns:
            The op's output with every tensor re-wrapped in ``cls``.
        """
        # The signature allows kwargs=None, which cannot be unpacked with ``**``.
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        unwrapped_args = tree_map(unwrap, args)
        out = func(*unwrapped_args, **tree_map(unwrap, kwargs))
        if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT:
            args[0].elem = out
        if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT:
            args[0].elem = torch.rand(args[0].elem.shape)
        if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT:
            incorrect_out = list(out)
            incorrect_out[0] = incorrect_out[1]
            return tree_map(wrap, tuple(incorrect_out))

        return tree_map(wrap, out)
94

95
# Tests various schema checking functionalities.
96
class TestSchemaCheck(JitTestCase):
    """Tests for SchemaCheckMode and the schema-related ``torch._C`` bindings used here."""

    def setUp(self):
        if TEST_WITH_TORCHDYNAMO:
            self.skipTest("SchemaCheckMode is ignored by dynamo")
        super().setUp()

    # Tests that SchemaCheckMode records operator order with grad
    def test_schema_check_mode_operator_order(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=True)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records operator order without grad
    def test_schema_check_mode_operator_order_without_grad(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=False)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records mutations and aliases with none expected
    def test_schema_check_mode_mutated_aliasing_none(self):
        # NB: previously requires_grad=True, but this induces a detach for
        # saved variable
        x = torch.rand((3, 3))
        with SchemaCheckMode() as schema_check:
            x.relu().sin()
        self.assertEqual([], schema_check.mutated)
        self.assertEqual([], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with mutation expected
    def test_schema_check_mode_mutated_aliasing_mutation(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.sinh_()
        self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with resize_
    def test_schema_check_mode_mutated_aliasing_resize_(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.resize_(9)
        self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs
    def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
        actual = torch.rand((3, 3))
        y = actual
        with SchemaCheckMode() as schema_check:
            actual.add_(y)
        self.assertEqual(
            [
                ('aten::add_', 'input'),
                ('aten::add_', 'other')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::add_', 'input', 'output_0'),
                ('aten::add_', 'other', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and alias with as_strided
    def test_schema_check_mode_mutated_aliasing_as_strided(self):
        x = torch.rand((3, 6, 4))
        with SchemaCheckMode() as schema_check:
            x.as_strided_([3, 6, 4], [9, 1, 1])
        self.assertEqual(
            [
                ('aten::as_strided_', 'input')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::as_strided_', 'input', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with multiple outputs
    def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
        x = torch.arange(9.)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode() as schema_check:
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa'),
                ('aten::frexp', 'exponent')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa', 'output_0'),
                ('aten::frexp', 'exponent', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs
    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode() as schema_check:
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(
            [
                ('aten::aminmax', 'min'),
                ('aten::aminmax', 'max')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::aminmax', 'min', 'output_0'),
                ('aten::aminmax', 'min', 'output_1'),
                ('aten::aminmax', 'max', 'output_0'),
                ('aten::aminmax', 'max', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode wraps torch.Tensor
    def test_schema_check_mode_functionality(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.relu().sin()
        with SchemaCheckMode():
            actual = x.relu().sin()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden
    def test_schema_check_mode_functionality_default_replaced(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.add(x, alpha=2)
        with SchemaCheckMode():
            actual = x.add(x, alpha=2)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument
    def test_schema_check_mode_functionality_list_input(self):
        a = torch.rand((3, 3))
        b = torch.rand((3, 3))
        c = torch.rand((3, 3))
        expected = torch.linalg.multi_dot([a, b, c])
        with SchemaCheckMode():
            actual = torch.linalg.multi_dot([a, b, c])
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation
    def test_schema_check_mode_functionality_wildcard_after(self):
        x = torch.rand((3, 3))
        expected = x.chunk(6)
        with SchemaCheckMode():
            actual = x.chunk(6)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input
    @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
    def test_schema_check_mode_functionality_kwarg_tensor(self):
        x = torch.rand((3, 5))
        w = torch.rand(4)
        expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        with SchemaCheckMode():
            actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op
    def test_schema_check_mode_functionality_mutable_inputs(self):
        expected = torch.rand((3, 3), requires_grad=False)
        actual = torch.clone(expected)
        expected.sinh_()
        with SchemaCheckMode():
            actual.sinh_()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps Torch.tensor when inputs alias
    def test_schema_check_mode_functionality_aliasing_inputs(self):
        expected = torch.rand((3, 3))
        x = expected
        actual = torch.clone(expected)
        y = actual
        expected.add_(x)
        with SchemaCheckMode():
            actual.add_(y)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps Torch.tensor with multiple tensor outputs
    def test_schema_check_mode_functionality_with_multiple_outputs(self):
        x = torch.arange(9.)
        m_expected, e_expected = torch.frexp(x)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode():
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(m_expected, m_actual)
        self.assertEqual(e_expected, e_actual)

    # Tests that SchemaCheckMode wraps Torch.tensor with aliasing outputs due to aliasing inputs
    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode():
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(torch.amax(x, dim=0), actual)

    # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input
    def test_schema_check_mode_functionality_device_input(self):
        with SchemaCheckMode():
            x = torch.rand((3, 3), device="cpu", dtype=torch.double)
            y = x + x
        self.assertEqual(x + x, y)

    # Tests that SchemaCheckMode wraps Torch.tensor in special training op edge case
    def test_schema_check_mode_functionality_training_op(self):
        x = torch.rand((3, 3), requires_grad=True)
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = batch(x)
        with SchemaCheckMode():
            actual = batch(x)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps Torch.tensor with nested training op edge case
    def test_schema_check_mode_functionality_nested_training_op(self):
        actual = torch.rand((3, 3))
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = torch.clone(actual)
        expected.sinh_()
        expected.tanh_()
        expected.relu_()
        expected = batch(expected)

        with SchemaCheckMode():
            actual.sinh_()
            actual.tanh_()
            actual.relu_()
            actual = batch(actual)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps Torch.tensor with empty list input
    def test_schema_check_mode_empty_list_input(self):
        expected = torch.atleast_1d([])
        with SchemaCheckMode():
            actual = torch.atleast_1d([])
        self.assertEqual(expected, actual)

    # Tests that an exception is raised for a mismatching mutation
    def test_mutation_check_fail(self):
        # Inputs are built outside assertRaisesRegex so only the checked call can satisfy it.
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching mutation over multiple ops
    def test_mutation_check_fail_multiple_operators(self):
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching alias
    def test_alias_check_fail_simple(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.zeros((3, 3), requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a centered mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators_centered(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.zeros((3, 3), requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()

    # Tests that an exception is raised when two outputs alias each other unexpectedly
    def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).aminmax(dim=0)

    # When this file was written, python op registration didn't exist.
    # It's probably worth re-writing the entire file to use it,
    # but instead I just added extra tests.
    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                torch.ops.bad_schemas.secretly_aliasing(x)

    def test_alias_check_fail_custom_ops_secretly_mutating(self):
        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
            with SchemaCheckMode():
                torch.ops.bad_schemas.secretly_mutating(x)

    def test_alias_check_fail_custom_ops_output_is_input(self):
        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
            with SchemaCheckMode():
                torch.ops.bad_schemas.output_is_input(x)

    # Tests that is_alias_of returns as expected
    def test_is_alias_of_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = x.add(x, alpha=2)
        self.assertTrue(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that is_alias_of returns as expected with empty containers
    def test_is_alias_of_empty_container(self):
        x = []
        y = torch.rand((3, 3), requires_grad=True)
        self.assertFalse(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that overlaps returns as expected
    def test_overlaps_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3), requires_grad=True)
        z = [x, y]
        self.assertTrue(torch._C._overlaps(x, x))
        self.assertFalse(torch._C._overlaps(x, y))
        self.assertTrue(torch._C._overlaps(z, x))
        self.assertTrue(torch._C._overlaps(z, y))

    # Tests that overlaps returns correctly with empty containers
    def test_overlaps_empty_container(self):
        x = []
        y = [torch.rand((3, 3), requires_grad=True)]
        # Empty containers return false
        self.assertFalse(torch._C._overlaps(y, x))
        self.assertTrue(torch._C._overlaps(y, y))

    # Tests that SchemaInfo Bindings work as expected
    def test_schema_info_bind_basic(self):
        class SchemaInfoBindTestMode(TorchDispatchMode):
            def __init__(self, test_self):
                super().__init__()
                self.test_self = test_self

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                # The signature allows kwargs=None, which cannot be unpacked with ``**``.
                if kwargs is None:
                    kwargs = {}
                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                # One SchemaInfo is fed values one at a time, the other in a single call.
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)
                first_input = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0)
                second_input = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)

                # Before any argument values are supplied, the two inputs are not reported as aliasing.
                self.test_self.assertFalse(schema_info_value_test.may_alias(first_input, second_input))
                self.test_self.assertFalse(schema_info_values_test.may_alias(first_input, second_input))

                for arg_name, arg_value in named_arg_list.items():
                    schema_info_value_test.add_argument_value(arg_name, arg_value)
                schema_info_values_test.add_argument_values(named_arg_list)

                # With the actual values known (the same tensor is passed twice), they may alias.
                self.test_self.assertTrue(schema_info_value_test.may_alias(first_input, second_input))
                self.test_self.assertTrue(schema_info_values_test.may_alias(first_input, second_input))

                return func(*args, **kwargs)

        x = torch.rand((3, 3))
        with SchemaInfoBindTestMode(self):
            x.add(x)
494

495

496
class TestSchemaCheckModeOpInfo(JitTestCase):
    """Runs every OpInfo sample input through SchemaCheckMode."""

    @ops(op_db, dtypes=OpDTypes.supported)
    def test_schema_correctness(self, device, dtype, op):
        # Currently torch.equal isn't supported with torch.complex32
        # There's also errors with complex64 and complex128
        if dtype == torch.complex32:
            return
        sample_inputs = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in sample_inputs:
            with SchemaCheckMode():
                op(sample_input.input, *sample_input.args, **sample_input.kwargs)
506

507
# Instantiate the OpInfo-based test class into this module's globals,
# restricted to the "cpu" and "cuda" device types.
instantiate_device_type_tests(
    TestSchemaCheckModeOpInfo,
    globals(),
    only_for=("cpu", "cuda"),
)

# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
511

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.