6
from torch.utils._pytree import tree_map
9
from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO
10
from torch.fx.operator_schemas import normalize_function
11
from torch._subclasses.schema_check_mode import SchemaCheckMode
12
from torch.utils._python_dispatch import TorchDispatchMode
13
from torch.testing._internal.common_methods_invocations import op_db
14
from torch.testing._internal.jit_utils import JitTestCase
15
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
# Directory two levels above this file; appended to sys.path so modules that
# live there can be imported. The membership guard keeps a repeated import of
# this module from stacking duplicate entries.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if pytorch_test_dir not in sys.path:
    sys.path.append(pytorch_test_dir)
19
def secretly_aliasing(x):
22
def secretly_mutating(x):
26
def output_is_input(x):
# Operator schemas for the "bad_schemas" namespace. The schema strings are the
# declared contract of each op; the handles are kept in module-level names.
custom_lib = torch.library.Library("bad_schemas", "DEF")
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")

# CPU registrations: each op name is bound to the same-named Python function.
custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
custom_lib_cpu.impl("output_is_input", output_is_input)

# Meta registrations reuse the very same Python functions as the CPU ones.
custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
custom_lib_meta.impl("secretly_mutating", secretly_mutating)
custom_lib_meta.impl("output_is_input", output_is_input)
47
class IncorrectAliasTensor(torch.Tensor):
48
ALIAS_ARG_OUT = {"aten::add"}
49
ALIAS_OUT_OUT = {"aten::aminmax"}
50
MUTATE_ARGS_OUT = {"aten::sub"}
57
def __new__(cls, elem, *args, **kwargs):
61
r = torch.Tensor._make_wrapper_subclass(
63
strides=elem.stride(), storage_offset=elem.storage_offset(),
65
dtype=elem.dtype, layout=elem.layout,
66
device=elem.device, requires_grad=kwargs.get("requires_grad", False)
69
r.elem = elem.detach() if r.requires_grad else elem
73
return super().__repr__(tensor_contents=f"{self.elem}")
76
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
78
return e.elem if isinstance(e, cls) else e
81
return cls(e) if isinstance(e, torch.Tensor) else e
82
unwrapped_args = tree_map(unwrap, args)
83
out = func(*unwrapped_args, **tree_map(unwrap, kwargs))
84
if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT:
86
if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT:
87
args[0].elem = torch.rand(args[0].elem.shape)
88
if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT:
89
incorrect_out = list(out)
90
incorrect_out[0] = incorrect_out[1]
91
return tree_map(wrap, tuple(incorrect_out))
93
return tree_map(wrap, out)
96
class TestSchemaCheck(JitTestCase):
98
if TEST_WITH_TORCHDYNAMO:
99
self.skipTest("SchemaCheckMode is ignored by dynamo")
103
def test_schema_check_mode_operator_order(self):
104
with SchemaCheckMode() as schema_check:
105
x = torch.rand((3, 3), requires_grad=True)
107
self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
110
def test_schema_check_mode_operator_order_without_grad(self):
111
with SchemaCheckMode() as schema_check:
112
x = torch.rand((3, 3), requires_grad=False)
114
self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
117
def test_schema_check_mode_mutated_aliasing_none(self):
120
x = torch.rand((3, 3))
121
with SchemaCheckMode() as schema_check:
122
actual = x.relu().sin()
123
self.assertEqual([], schema_check.mutated)
124
self.assertEqual([], schema_check.aliasing)
127
def test_schema_check_mode_mutated_aliasing_mutation(self):
128
actual = torch.rand((3, 3), requires_grad=False)
129
with SchemaCheckMode() as schema_check:
131
self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
132
self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
135
def test_schema_check_mode_mutated_aliasing_resize_(self):
136
actual = torch.rand((3, 3), requires_grad=False)
137
with SchemaCheckMode() as schema_check:
139
self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
140
self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
143
def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
144
actual = torch.rand((3, 3))
146
with SchemaCheckMode() as schema_check:
150
('aten::add_', 'input'),
151
('aten::add_', 'other')
157
('aten::add_', 'input', 'output_0'),
158
('aten::add_', 'other', 'output_0')
160
schema_check.aliasing
164
def test_schema_check_mode_mutated_aliasing_as_strided(self):
165
x = torch.rand((3, 6, 4))
166
with SchemaCheckMode() as schema_check:
167
x.as_strided_([3, 6, 4], [9, 1, 1])
170
('aten::as_strided_', 'input')
176
('aten::as_strided_', 'input', 'output_0')
178
schema_check.aliasing
182
def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
184
m_actual = torch.arange(9.)
185
e_actual = torch.zeros([9], dtype=torch.int32)
186
with SchemaCheckMode() as schema_check:
187
torch.frexp(x, out=(m_actual, e_actual))
190
('aten::frexp', 'mantissa'),
191
('aten::frexp', 'exponent')
197
('aten::frexp', 'mantissa', 'output_0'),
198
('aten::frexp', 'exponent', 'output_1')
200
schema_check.aliasing
204
def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
205
x = torch.rand((3, 3))
206
actual = torch.zeros(3)
207
with SchemaCheckMode() as schema_check:
208
torch.aminmax(x, dim=0, out=[actual, actual])
211
('aten::aminmax', 'min'),
212
('aten::aminmax', 'max')
218
('aten::aminmax', 'min', 'output_0'),
219
('aten::aminmax', 'min', 'output_1'),
220
('aten::aminmax', 'max', 'output_0'),
221
('aten::aminmax', 'max', 'output_1')
223
schema_check.aliasing
227
def test_schema_check_mode_functionality(self):
228
x = torch.rand((3, 3), requires_grad=True)
229
expected = x.relu().sin()
230
with SchemaCheckMode():
231
actual = x.relu().sin()
232
self.assertEqual(expected, actual)
235
def test_schema_check_mode_functionality_default_replaced(self):
236
x = torch.rand((3, 3), requires_grad=True)
237
expected = x.add(x, alpha=2)
238
with SchemaCheckMode():
239
actual = x.add(x, alpha=2)
240
self.assertEqual(expected, actual)
243
def test_schema_check_mode_functionality_list_input(self):
244
a = torch.rand((3, 3))
245
b = torch.rand((3, 3))
246
c = torch.rand((3, 3))
247
expected = torch.linalg.multi_dot([a, b, c])
248
with SchemaCheckMode():
249
actual = torch.linalg.multi_dot([a, b, c])
250
self.assertEqual(expected, actual)
253
def test_schema_check_mode_functionality_wildcard_after(self):
254
x = torch.rand((3, 3))
255
expected = x.chunk(6)
256
with SchemaCheckMode():
258
self.assertEqual(expected, actual)
261
@unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
262
def test_schema_check_mode_functionality_kwarg_tensor(self):
263
x = torch.rand((3, 5))
265
expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
266
with SchemaCheckMode():
267
actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
268
self.assertEqual(expected, actual)
271
def test_schema_check_mode_functionality_mutable_inputs(self):
272
expected = torch.rand((3, 3), requires_grad=False)
273
actual = torch.clone(expected)
275
with SchemaCheckMode():
277
self.assertEqual(expected, actual)
280
def test_schema_check_mode_functionality_aliasing_inputs(self):
281
expected = torch.rand((3, 3))
283
actual = torch.clone(expected)
286
with SchemaCheckMode():
288
self.assertEqual(expected, actual)
291
def test_schema_check_mode_functionality_with_multiple_outputs(self):
293
m_expected, e_expected = torch.frexp(x)
294
m_actual = torch.arange(9.)
295
e_actual = torch.zeros([9], dtype=torch.int32)
296
with SchemaCheckMode():
297
torch.frexp(x, out=(m_actual, e_actual))
298
self.assertEqual(m_expected, m_actual)
299
self.assertEqual(e_expected, e_actual)
302
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
303
x = torch.rand((3, 3))
304
actual = torch.zeros(3)
305
with SchemaCheckMode():
306
torch.aminmax(x, dim=0, out=[actual, actual])
307
self.assertEqual(torch.amax(x, dim=0), actual)
310
def test_schema_check_mode_functionality_device_input(self):
311
with SchemaCheckMode():
312
x = torch.rand((3, 3), device="cpu", dtype=torch.double)
314
self.assertEqual(x + x, y)
317
def test_schema_check_mode_functionality_training_op(self):
318
x = torch.rand((3, 3), requires_grad=True)
319
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
321
with SchemaCheckMode():
323
self.assertEqual(expected, actual)
326
def test_schema_check_mode_functionality_nested_training_op(self):
327
actual = torch.rand((3, 3))
328
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
329
expected = torch.clone(actual)
333
expected = batch(expected)
335
with SchemaCheckMode():
339
actual = batch(actual)
340
self.assertEqual(expected, actual)
343
def test_schema_check_mode_empty_list_input(self):
344
expected = torch.atleast_1d([])
345
with SchemaCheckMode():
346
actual = torch.atleast_1d([])
347
self.assertEqual(expected, actual)
def test_mutation_check_fail(self):
    """Expect SchemaCheckMode to report an undeclared mutation for ``sub``.

    Both operands are wrapped in ``IncorrectAliasTensor`` and ``sub`` is called
    under ``SchemaCheckMode``; the call must raise a ``RuntimeError`` whose
    message matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked call can satisfy the expectation.
    x = torch.rand((3, 3))
    y = torch.rand((3, 3))
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
        with SchemaCheckMode():
            IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
def test_mutation_check_fail_multiple_operators(self):
    """Expect the undeclared-mutation error when ``sub`` ends an operator chain.

    ``sin`` and ``cos`` are applied to the wrapped tensor before ``sub``; the
    whole chain runs under ``SchemaCheckMode`` and must raise a ``RuntimeError``
    whose message matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked chain can satisfy the expectation.
    x = torch.rand((3, 3))
    y = torch.rand((3, 3))
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
        with SchemaCheckMode():
            IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
def test_alias_check_fail_simple(self):
    """Expect SchemaCheckMode to report an undeclared input/output alias for ``add``.

    Both operands are wrapped in ``IncorrectAliasTensor`` and ``add`` is called
    with ``alpha=2`` under ``SchemaCheckMode``; the call must raise a
    ``RuntimeError`` whose message matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked call can satisfy the expectation.
    x = torch.rand((3, 3), requires_grad=True)
    y = torch.rand((3, 3))
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
        with SchemaCheckMode():
            IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
def test_alias_check_fail_multiple_operators(self):
    """Expect the undeclared-alias error when ``add`` ends an operator chain.

    ``sin`` and ``relu`` are applied to the wrapped tensor before ``add``; the
    whole chain runs under ``SchemaCheckMode`` and must raise a ``RuntimeError``
    whose message matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked chain can satisfy the expectation.
    x = torch.rand((3, 3), requires_grad=True)
    y = torch.zeros((3, 3), requires_grad=True)
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
        with SchemaCheckMode():
            IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)
def test_alias_check_fail_multiple_operators_centered(self):
    """Expect the undeclared-alias error when ``add`` sits mid-chain.

    The chain is ``sin`` -> ``add`` -> ``relu`` on the wrapped tensor; it runs
    under ``SchemaCheckMode`` and must raise a ``RuntimeError`` whose message
    matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked chain can satisfy the expectation.
    x = torch.rand((3, 3), requires_grad=True)
    y = torch.zeros((3, 3), requires_grad=True)
    with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
        with SchemaCheckMode():
            IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()
def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
    """Expect SchemaCheckMode to report two outputs of ``aminmax`` aliasing.

    ``aminmax(dim=0)`` is called on an ``IncorrectAliasTensor`` under
    ``SchemaCheckMode``; the call must raise a ``RuntimeError`` whose message
    matches the pattern below.
    """
    # Input construction stays outside the assertRaisesRegex scope so that only
    # the checked call can satisfy the expectation.
    x = torch.rand((3, 3))
    with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
        # The mode object itself is never inspected, so it is not bound to a name.
        with SchemaCheckMode():
            IncorrectAliasTensor(x).aminmax(dim=0)
399
def test_alias_check_fail_custom_ops_secretly_aliasing(self):
401
return torch.ops.bad_schemas.secretly_aliasing(x)
403
x = torch.rand((3, 3))
404
with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
405
with SchemaCheckMode() as s:
408
def test_alias_check_fail_custom_ops_secretly_mutating(self):
410
return torch.ops.bad_schemas.secretly_mutating(x)
412
x = torch.rand((3, 3))
413
with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
414
with SchemaCheckMode() as s:
417
def test_alias_check_fail_custom_ops_output_is_input(self):
419
return torch.ops.bad_schemas.output_is_input(x)
421
x = torch.rand((3, 3))
422
with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
423
with SchemaCheckMode() as s:
427
def test_is_alias_of_basic(self):
428
x = torch.rand((3, 3), requires_grad=True)
429
y = torch.rand((3, 3), requires_grad=True)
430
y = x.add(x, alpha=2)
431
self.assertTrue(torch._C._is_alias_of(x, x))
432
self.assertFalse(torch._C._is_alias_of(x, y))
435
def test_is_alias_of_empty_container(self):
437
y = torch.rand((3, 3), requires_grad=True)
438
self.assertFalse(torch._C._is_alias_of(x, x))
439
self.assertFalse(torch._C._is_alias_of(x, y))
442
def test_overlaps_basic(self):
443
x = torch.rand((3, 3), requires_grad=True)
444
y = torch.rand((3, 3), requires_grad=True)
446
self.assertTrue(torch._C._overlaps(x, x))
447
self.assertFalse(torch._C._overlaps(x, y))
448
self.assertTrue(torch._C._overlaps(z, x))
449
self.assertTrue(torch._C._overlaps(z, y))
452
def test_overlaps_empty_container(self):
454
y = [torch.rand((3, 3), requires_grad=True)]
456
self.assertFalse(torch._C._overlaps(y, x))
457
self.assertTrue(torch._C._overlaps(y, y))
460
def test_schema_info_bind_basic(self):
461
class SchemaInfoBindTestMode(TorchDispatchMode):
462
def __init__(self, test_self):
463
self.test_self = test_self
465
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
466
named_arg_list = normalize_function(
470
normalize_to_only_use_kwargs=True
472
schema_info_value_test = torch._C._SchemaInfo(func._schema)
473
schema_info_values_test = torch._C._SchemaInfo(func._schema)
474
self.test_self.assertFalse(schema_info_value_test.may_alias(
475
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
476
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
477
self.test_self.assertFalse(schema_info_values_test.may_alias(
478
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
479
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
480
for i in named_arg_list:
481
schema_info_value_test.add_argument_value(i, named_arg_list[i])
482
schema_info_values_test.add_argument_values(named_arg_list)
483
self.test_self.assertTrue(schema_info_value_test.may_alias(
484
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
485
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
486
self.test_self.assertTrue(schema_info_values_test.may_alias(
487
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
488
torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
490
return func(*args, **kwargs)
491
x = torch.rand((3, 3))
492
with SchemaInfoBindTestMode(self) as schemaInfoCheck:
496
class TestSchemaCheckModeOpInfo(JitTestCase):
497
@ops(op_db, dtypes=OpDTypes.supported)
498
def test_schema_correctness(self, device, dtype, op):
501
if (dtype == torch.complex32):
503
for sample in op.sample_inputs(device, dtype, requires_grad=False):
504
with SchemaCheckMode():
505
op(sample.input, *sample.args, **sample.kwargs)
507
instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))
509
if __name__ == '__main__':