1
# Owner(s): ["module: meta tensors"]
12
from unittest.mock import patch
17
import torch._functorch.config
18
import torch._prims as prims
19
import torch.testing._internal.optests as optests
20
import torch.utils._pytree as pytree
22
from torch import distributed as dist
23
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
24
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
25
from torch._guards import tracing, TracingContext
26
from torch._subclasses.fake_tensor import (
27
DynamicOutputShapeException,
28
extract_tensor_metadata,
32
unset_fake_temporarily,
33
UnsupportedOperatorException,
36
from torch.fx.experimental.proxy_tensor import make_fx
37
from torch.fx.experimental.symbolic_shapes import (
42
StatelessSymbolicContext,
43
statically_known_true,
45
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
46
from torch.testing import FileCheck
47
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
48
from torch.testing._internal.common_device_type import (
49
instantiate_device_type_tests,
53
from torch.testing._internal.common_utils import (
54
instantiate_parametrized_tests,
61
TEST_WITH_TORCHDYNAMO,
65
from torch.testing._internal.inductor_utils import GPU_TYPE
66
from torch.testing._internal.custom_op_db import custom_op_db
67
from torch.testing._internal.jit_utils import RUN_CUDA
68
from torch.utils._mode_utils import no_dispatch
69
from torch.utils._python_dispatch import TorchDispatchMode
73
torch._dynamo.config.fake_tensor_cache_enabled = True
74
torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True
77
def expectedFailurePropagateRealTensors(fn):
    """Mark ``fn`` as an expected failure for the propagate-real-tensors variant.

    Sets ``fn._expected_failure_propagate_real_tensors = True``. That attribute
    name is the ``xfail_prop`` that ``make_propagate_real_tensors_cls`` hands to
    ``make_test_cls_with_patches``.

    Returns:
        ``fn`` itself, so this can be used as a decorator. Without the return,
        decorating a test method would replace it with ``None``.
    """
    fn._expected_failure_propagate_real_tensors = True
    return fn
82
# Tests for FakeTensorMode / FakeTensor behavior.
# NOTE(review): this copy has lost its indentation, and the bare integer lines
# appear to be original line numbers. Gaps in that numbering line up with
# statements that are missing here: several tests read names such as `z`,
# `out` or `x` that are never assigned in the visible text, and some `def`
# lines and decorator openings are absent. Confirm against the upstream file
# before relying on any individual test below.
class FakeTensorTest(TestCase):
83
# Helper: assert `t` is a FakeTensor whose device type is `device_str` and
# whose size, converted to a list, equals `size`.
def checkType(self, t, device_str, size):
84
self.assertTrue(isinstance(t, FakeTensor))
85
self.assertEqual(t.device.type, device_str)
86
self.assertEqual(list(t.size()), size)
88
# Build CUDA tensors (one with requires_grad=True) under FakeTensorMode and run
# mm/square/sum on them; skipped when CUDA is unavailable.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
89
def test_cuda_initialized(self):
91
with FakeTensorMode():
92
p = torch.randn(4, 2, requires_grad=True, device="cuda")
93
x = torch.randn(8, 4, device="cuda")
94
y = torch.mm(x, p).square().sum()
98
# NOTE(review): the numbering jumps from 94 to 98 here. The following lines
# probably belong to a separate test whose `def` line is missing. They convert
# two CPU tensors into fake tensors and expect a (4, 2, 2) CPU FakeTensor `z`,
# but `z` is never assigned in the visible text.
x = torch.empty(2, 2, device="cpu")
99
y = torch.empty(4, 2, 2, device="cpu")
100
with FakeTensorMode() as mode:
101
x = mode.from_tensor(x)
102
y = mode.from_tensor(y)
104
self.assertEqual(z.shape, (4, 2, 2))
105
self.assertEqual(z.device, torch.device("cpu"))
106
self.assertTrue(isinstance(z, FakeTensor))
108
# Define a custom op `my_test_op::foo` with a CPU impl. Calling it on a fake
# tensor under FakeTensorMode(allow_fallback_kernels=True) is expected to raise
# UnsupportedOperatorException mentioning "my_test_op.foo.default".
# NOTE(review): the function decorated by `@impl(test_lib, "foo", "CPU")` and
# the closing of the `assertRaisesRegex(` call are not present in this text.
def test_custom_op_fallback(self):
109
from torch.library import impl, Library
112
test_lib = Library("my_test_op", "DEF") # noqa: TOR901
113
test_lib.define("foo(Tensor self) -> Tensor")
115
@impl(test_lib, "foo", "CPU")
119
x = torch.empty(2, 2, device="cpu")
120
with self.assertRaisesRegex(
121
UnsupportedOperatorException, "my_test_op.foo.default"
123
with FakeTensorMode(allow_fallback_kernels=True) as mode:
124
x = mode.from_tensor(x)
125
torch.ops.my_test_op.foo(x)
130
# Constructing nn.Parameter under FakeTensorMode should yield an nn.Parameter.
# NOTE(review): `x` is not assigned in the visible lines of this test.
def test_parameter_instantiation(self):
131
with FakeTensorMode():
133
y = torch.nn.parameter.Parameter(x)
134
self.assertTrue(isinstance(y, torch.nn.Parameter))
136
# FSDP's FlatParameter created under FakeTensorMode should be a FlatParameter,
# an nn.Parameter and a FakeTensor at once; skipped without torch.distributed.
@unittest.skipIf(not dist.is_available(), "requires distributed")
137
def test_fsdp_flat_param(self):
138
from torch.distributed.fsdp._flat_param import FlatParameter
140
with FakeTensorMode() as m:
141
data = torch.randn(2, 2)
142
param = FlatParameter(data, requires_grad=True)
143
self.assertIsInstance(param, FlatParameter)
144
self.assertIsInstance(param, torch.nn.Parameter)
145
self.assertIsInstance(param, FakeTensor)
147
# mode.from_tensor must preserve requires_grad of a plain (non-Parameter) tensor.
def test_non_parameter_grad(self):
148
mode = FakeTensorMode()
149
t = torch.rand([4], requires_grad=True)
150
fake_t = mode.from_tensor(t)
151
self.assertEqual(fake_t.requires_grad, t.requires_grad)
154
# NOTE(review): the next line is the tail of a decorator call (likely
# `@unittest.skipIf(`) whose opening and closing lines are missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
156
# Indexing a CUDA fake tensor with a CPU int64 index tensor should give a CUDA
# FakeTensor of size [36].
@unittest.skipIf(not RUN_CUDA, "requires cuda")
157
def test_index_cuda_with_cpu(self):
158
with FakeTensorMode():
159
x = torch.rand([2048], device="cuda")
160
out = x[torch.zeros([36], dtype=torch.int64)]
161
self.checkType(out, "cuda", [36])
163
# resize_as_ of a CPU tensor to a CUDA tensor's shape returns an (8, 8)
# FakeTensor that stays on the CPU.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
164
def test_shape_take_not_device(self):
165
with FakeTensorMode():
166
x = torch.empty(1, device="cpu")
167
y = torch.empty(8, 8, device="cuda")
168
out = x.resize_as_(y)
169
self.assertEqual(out.shape, (8, 8))
170
self.assertEqual(out.device.type, "cpu")
171
self.assertTrue(isinstance(out, FakeTensor))
174
# NOTE(review): the `def` line for the following statements appears to be
# missing. They check repr() of fake tensors: "FakeTensor(..., size=(2, 2))"
# for CPU and "FakeTensor(..., device='meta', size=(2, 2))" for meta.
with FakeTensorMode():
175
x = torch.empty(2, 2, device="cpu")
176
self.assertEqual(repr(x), "FakeTensor(..., size=(2, 2))")
177
x = torch.empty(2, 2, device="meta")
178
self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")
180
# Expects `out` to be a (4, 4) FakeTensor on y's (CUDA) device.
# NOTE(review): the line assigning `out` (presumably an op on the 0-dim CPU
# tensor `x` and the CUDA tensor `y`) is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
181
def test_zero_dim(self):
182
with FakeTensorMode() as mode:
183
x = torch.tensor(0.0)
184
y = torch.rand([4, 4], device="cuda")
186
self.assertEqual(out.shape, (4, 4))
187
self.assertEqual(out.device, y.device)
188
self.assertTrue(isinstance(out, FakeTensor))
190
# nan_to_num (with nan=None and with a positional 0.0) preserves the input
# dtype for float16 and float32.
def test_nan_to_num(self):
191
with FakeTensorMode():
192
for dtype in [torch.float16, torch.float32]:
193
x = torch.rand([4], dtype=dtype)
194
y = torch.nan_to_num(x, nan=None)
195
z = torch.nan_to_num(x, 0.0)
196
self.assertEqual(dtype, y.dtype)
197
self.assertEqual(dtype, z.dtype)
199
# torch.lerp on a fake 0-dim CPU tensor, a CUDA tensor and a CPU tensor must
# raise.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
200
def test_throw(self):
201
x = torch.tensor(0.0) # TODO: tensor() errors
202
with FakeTensorMode() as mode:
203
x_conv = mode.from_tensor(x)
204
y = torch.rand([4, 4], device="cuda")
205
z = torch.rand([4, 4], device="cpu")
206
self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))
208
# Expects `out` to be a CUDA FakeTensor.
# NOTE(review): the line assigning `out` (presumably a type_as between `x` and
# `y`) is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
209
def test_type_as(self):
210
with FakeTensorMode():
211
x = torch.rand([16, 1], device="cpu")
212
y = torch.rand([4, 4], device="cuda")
214
self.assertEqual(out.device.type, "cuda")
215
self.assertTrue(isinstance(out, FakeTensor))
217
# Creates a [16, 1] fake tensor on each of "cpu" and "cuda".
# NOTE(review): the rest of this test's body is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
218
def test_setitem(self):
219
for device in ["cpu", "cuda"]:
220
with FakeTensorMode():
221
x = torch.rand([16, 1], device=device)
224
# copy_ between CPU and CUDA fake tensors keeps the destination's device.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
225
def test_device_inplace_copy(self):
226
with FakeTensorMode():
227
x = torch.rand([8, 8], device="cpu")
228
y = torch.rand([8, 8], device="cuda")
229
assert x.copy_(y).device.type == "cpu"
230
assert y.copy_(x).device.type == "cuda"
232
# Checks the dispatch key sets of fake tensors with FileCheck. A tensor `y`
# (presumably created under torch.inference_mode()) must have "CPU" and
# "AutocastCPU" but no "ADInplaceOrView"/"Autograd" keys.
# NOTE(review): the definitions of `x` and `y` and the start of the FileCheck
# chain assigned to `f` are missing here.
def test_fake_dispatch_keys(self):
233
with FakeTensorMode():
238
.check("ADInplaceOrView")
239
.check("AutogradCPU")
240
.check("AutocastCPU")
242
f.run(torch._C._dispatch_key_set(x))
244
with torch.inference_mode():
247
FileCheck().check("CPU").check("AutocastCPU").run(
248
torch._C._dispatch_key_set(y)
250
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
251
torch._C._dispatch_key_set(y)
254
# Converting functorch batched tensors (one and two levels of _add_batch_dim)
# with from_tensor preserves tensor metadata, strides included. Unwrapping the
# fake result level by level ends in a plain FakeTensor.
def test_batch_tensor(self):
255
x = torch.rand((3, 4, 5))
256
b = _add_batch_dim(x, 0, 0)
257
mode = FakeTensorMode()
258
fake_b = mode.from_tensor(b)
259
prims.utils.compare_tensor_meta(b, fake_b, check_strides=True)
261
b1 = _add_batch_dim(x, 1, 1)
262
b2 = _add_batch_dim(b1, 0, 2)
263
fake_b2 = mode.from_tensor(b2)
264
prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True)
265
self.assertTrue(is_batchedtensor(fake_b2))
266
fake_b1 = get_unwrapped(fake_b2)
267
self.assertTrue(is_batchedtensor(fake_b1))
268
fake_tensor = get_unwrapped(fake_b1)
269
self.assertIsInstance(fake_tensor, FakeTensor)
271
# torch.rand(..., device="cpu") under FakeTensorMode produces a CPU FakeTensor.
def test_constructor(self):
272
with FakeTensorMode():
273
x = torch.rand([4, 4], device="cpu")
275
self.assertTrue(isinstance(x, FakeTensor))
276
self.assertTrue(x.device.type == "cpu")
279
# NOTE(review): the `def` lines for the following statements appear to be
# missing, and `out` is never assigned. The statements build a CPU fake tensor
# and expect `out` to be a FakeTensor, then call torch.full((4, 4), 1) under
# CrossRefFakeMode.
with FakeTensorMode():
280
y = torch.rand([4], device="cpu")
283
self.assertTrue(isinstance(out, FakeTensor))
286
# Test torch.full returns tensor with correct dtype
287
with torch._subclasses.CrossRefFakeMode():
288
y = torch.full((4, 4), 1)
290
# Helper: walk the pytree leaves of `out` and `out_fake` pairwise. A non-tensor
# leaf in `out` must pair with a non-tensor leaf in `out_fake`; tensor leaves
# are compared with compare_tensor_meta (strides included).
# NOTE(review): the lines that call `fn` to produce `out`/`out_fake`, and the
# branch between the two checks, are missing here.
def check_function_with_fake(self, fn):
292
with torch._subclasses.FakeTensorMode():
295
for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)):
296
if not isinstance(a, torch.Tensor):
297
self.assertTrue(not isinstance(b, torch.Tensor))
300
prims.utils.compare_tensor_meta(a, b, check_strides=True)
302
# .to() with a positional torch.device: moving a CPU fake tensor to
# torch.device("cuda") gives a CUDA tensor.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
303
def test_non_kwarg_device(self):
304
with FakeTensorMode():
305
x = torch.rand([16, 1], device="cpu")
306
y = x.to(torch.device("cpu"))
308
z = x.to(torch.device("cuda"))
309
self.assertEqual(z.device.type, "cuda")
311
# Runs check_function_with_fake on an input created with a 0 stride in dim 0.
# NOTE(review): the definition of `foo` is missing here.
def test_non_overlapping_stride_zero(self):
313
x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
316
self.check_function_with_fake(foo)
318
# Expects an error containing "Please convert all Tensors" when the real
# tensor `x` is used inside FakeTensorMode.
# NOTE(review): the statement inside the FakeTensorMode block is missing here.
def test_fake_mode_error(self):
319
x = torch.rand([4, 4])
321
with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
322
with FakeTensorMode():
326
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
328
# from_tensor also converts `.grad`: the fake tensor and its grad match the
# real ones' metadata, and the fake grad is a FakeTensor.
def test_fake_grad_copy(self):
329
x = torch.rand([4, 4], requires_grad=True)
330
x.grad = torch.rand([4, 4])
331
mode = FakeTensorMode()
332
fake_x = mode.from_tensor(x)
333
prims.utils.compare_tensor_meta(fake_x, x)
334
prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
336
self.assertTrue(isinstance(fake_x.grad, FakeTensor))
338
# index_put / index_put_ with a CUDA `x`, CUDA indices and CPU values, looping
# over a null context and the fake mode; some calls must raise RuntimeError.
# NOTE(review): several lines of this test are missing here (e.g. the body of
# the first `assertRaises` block), so which calls must raise cannot be read
# off this text.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
339
def test_index_put_error(self):
340
mode = FakeTensorMode()
341
for context in [contextlib.nullcontext, lambda: mode]:
343
y = torch.randn(2, 2, 3)
344
x = torch.randn(2, 2, 3).to("cuda")
345
with self.assertRaises(RuntimeError):
348
with self.assertRaises(RuntimeError):
349
torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)
352
torch.ops.aten.index_put(
353
x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
355
torch.ops.aten.index_put_(
356
x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.0)
359
# ones_like on a CPU fake tensor stays on the CPU, and device="cuda" overrides
# that; both results are FakeTensors.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
360
def test_like_constructor(self):
361
with FakeTensorMode():
362
x = torch.rand([4, 4])
363
y = torch.ones_like(x)
364
self.assertTrue(isinstance(y, FakeTensor))
365
self.assertEqual(y.device.type, "cpu")
366
z = torch.ones_like(x, device="cuda")
367
self.assertTrue(isinstance(z, FakeTensor))
368
self.assertEqual(z.device.type, "cuda")
370
# Expects `out`, presumably computed from a float and an int64 tensor, to be a
# float CPU tensor.
# NOTE(review): the line assigning `out` is missing here.
def test_binary_op_type_promotion(self):
371
with FakeTensorMode():
372
x = torch.empty([2, 2], dtype=torch.float)
373
y = torch.empty([2, 2], dtype=torch.int64)
375
self.assertEqual(out.dtype, torch.float)
376
self.assertEqual(out.device.type, "cpu")
379
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
381
# torch.tensor(<numpy array>) under FakeTensorMode gives a [4, 4] CPU FakeTensor.
def test_from_numpy(self):
382
with FakeTensorMode():
383
x = torch.tensor(np.zeros([4, 4]))
384
self.checkType(x, "cpu", [4, 4])
386
# randperm under FakeTensorMode matches the real result's metadata, with and
# without an explicit device.
def test_randperm(self):
387
x = torch.randperm(10)
388
y = torch.randperm(5, device="cpu")
389
with FakeTensorMode():
390
x1 = torch.randperm(10)
391
prims.utils.compare_tensor_meta(x, x1)
392
y1 = torch.randperm(5, device="cpu")
393
prims.utils.compare_tensor_meta(y, y1)
395
# Text captured as `out` under FakeTensorMode must not contain "FakeTensor".
# NOTE(review): the lines producing `out` are missing here.
def test_print_in_fake_mode(self):
398
with FakeTensorMode():
400
assert "FakeTensor" not in out
402
# upsample_bilinear2d on a (3, 427, 640) CUDA input with strides (1, 1920, 3),
# run under a null context and under FakeTensorMode. out[1] must be contiguous
# and match out[0]'s metadata.
# NOTE(review): the code that collects results into `out` is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
403
def test_upsample_bilinear_small_channels(self):
405
mode = FakeTensorMode()
406
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
408
arg0_1 = torch.empty_strided(
409
(3, 427, 640), (1, 1920, 3), dtype=torch.float32, device="cuda"
411
unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
413
torch.ops.aten.upsample_bilinear2d.default(
414
unsqueeze, [800, 1199], False
418
self.assertTrue(out[1].is_contiguous())
419
self.checkMetaProps(out[0], out[1])
421
# conv2d on CUDA fake tensors has three cases:
# - valid inputs with allow_fallback_kernels=False give a CUDA [1, 8, 5, 5] output;
# - channel-mismatched inputs with fallback enabled raise RuntimeError;
# - valid inputs with fallback enabled again give a CUDA [1, 8, 5, 5] output.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
422
def test_cpu_fallback(self):
423
with FakeTensorMode(allow_fallback_kernels=False):
424
filters = torch.randn(8, 4, 3, 3).cuda()
425
inputs = torch.randn(1, 4, 5, 5).cuda()
426
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
427
self.assertEqual(out.device.type, "cuda")
428
self.assertEqual(list(out.size()), [1, 8, 5, 5])
430
with FakeTensorMode(allow_fallback_kernels=True):
431
# intentionally bad inputs
432
filters = torch.randn(8, 20, 3, 3).cuda()
433
inputs = torch.randn(1, 7, 10, 5).cuda()
434
with self.assertRaises(RuntimeError):
435
torch.nn.functional.conv2d(inputs, filters, padding=1)
437
with FakeTensorMode(allow_fallback_kernels=True):
438
filters = torch.randn(8, 4, 3, 3).cuda()
439
inputs = torch.randn(1, 4, 5, 5).cuda()
441
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
442
self.assertEqual(out.device.type, "cuda")
443
self.assertEqual(list(out.size()), [1, 8, 5, 5])
445
# Expects errors matching "found.+two.+devices" when CPU and CUDA tensors are
# mixed.
# NOTE(review): the statements inside both assertRaisesRegex blocks, and any
# CPU tensor they use, are missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
446
def test_out_multi_device(self):
447
with FakeTensorMode():
449
y = torch.rand([4], device="cuda")
451
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
454
with self.assertRaisesRegex(Exception, "found.+two.+devices"):
458
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
460
# Uses tensors created with device "cuda" and f"cuda:{current_device}".
# `out` is expected to be a CUDA FakeTensor of size [1].
# NOTE(review): the line assigning `out` is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
461
def test_normalize_device(self):
462
with FakeTensorMode():
463
x = torch.empty(1, device="cuda")
464
y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
466
self.checkType(out, "cuda", [1])
468
# Sets mode.in_kernel_invocation = True and checks it reads back as True.
# NOTE(review): surrounding lines of this test are missing here.
def test_recursive_invocation(self):
469
mode = FakeTensorMode()
472
mode.in_kernel_invocation = True
474
self.assertTrue(mode.in_kernel_invocation)
477
# NOTE(review): the next lines are fragments of decorator calls (a skipIf and,
# judging by the arguments, a parametrization over `allow_fallback_kernels`)
# whose opening and closing lines are missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
481
"allow_fallback_kernels",
483
lambda a: "with_fallback" if a else "without_fallback",
485
# Calls aten._cudnn_rnn on CUDA inputs (a second input list has its last entry,
# `cx`, set to None). Runs under a null context and under
# FakeTensorMode(allow_fallback_kernels=...). out[4] must be the same object as
# inps[-3], and the checked tensors `ten` must be CUDA FakeTensors.
# NOTE(review): most of this test's body is missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
486
def test_cudnn_rnn(self, allow_fallback_kernels):
527
return torch.ops.aten._cudnn_rnn(
546
mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
547
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
550
torch.randn([92, 8, 2048]).cuda(),
551
torch.randn([8192, 2048]).cuda(),
552
torch.randn([8192, 2048]).cuda(),
553
torch.randn([8192]).cuda(),
554
torch.randn([8192]).cuda(),
555
torch.randn([8192, 2048]).cuda(),
556
torch.randn([8192, 2048]).cuda(),
557
torch.randn([8192]).cuda(),
558
torch.randn([8192]).cuda(),
559
torch.randn([8192, 4096]).cuda(),
560
torch.randn([8192, 2048]).cuda(),
561
torch.randn([8192]).cuda(),
562
torch.randn([8192]).cuda(),
563
torch.randn([8192, 4096]).cuda(),
564
torch.randn([8192, 2048]).cuda(),
565
torch.randn([8192]).cuda(),
566
torch.randn([8192]).cuda(),
567
torch.randn([167837696]).cuda(),
568
torch.randn([4, 8, 2048]).cuda(),
569
torch.randn([4, 8, 2048]).cuda(),
572
inps2[len(inps2) - 1] = None # argument `cx` can be None
574
for inps in [inps1, inps2]:
576
self.assertIs(out[4], inps[-3])
579
self.assertTrue(isinstance(ten, FakeTensor))
580
self.assertEqual(ten.device.type, "cuda")
582
# nn.LSTM forward + backward on CUDA fake tensors with cuDNN disabled; checks
# the shapes of output, h_n and c_n.
# NOTE(review): the loops/assignments defining bidir, proj_size, hidden_size,
# num_layers, N, L and H_in are missing here.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
583
def test_cuda_lstm(self):
584
# Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
585
with torch.backends.cudnn.flags(enabled=False):
586
fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
587
with fake_tensor_mode:
595
D = 2 if bidir else 1
596
H_out = proj_size if proj_size > 0 else hidden_size
598
lstm = torch.nn.LSTM(
600
hidden_size=hidden_size,
601
num_layers=num_layers,
609
h_0 = torch.randn((num_layers * D, N, H_out), device="cuda")
610
c_0 = torch.randn((num_layers * D, N, hidden_size), device="cuda")
611
inp = torch.randn((L, N, H_in), device="cuda")
612
(output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
613
output.sum().backward()
615
self.assertEqual(output.shape, (L, N, D * H_out))
616
self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
617
self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))
619
# torch.nonzero on a fake tensor with allow_fallback_kernels=False must raise
# DynamicOutputShapeException.
def test_data_dependent_operator(self):
620
with FakeTensorMode(allow_fallback_kernels=False):
621
x = torch.rand([10, 10])
623
self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
625
# from_tensor keeps the nn.Parameter type for a Parameter but not for `x_view`.
# NOTE(review): the line defining `x_view` (presumably a view of `x`) is
# missing here.
def test_parameter_view(self):
626
x = torch.nn.Parameter(torch.randn(4))
628
mode = FakeTensorMode()
629
fake_x_view = mode.from_tensor(x_view)
630
fake_x = mode.from_tensor(x)
631
self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter))
632
self.assertTrue(isinstance(fake_x, torch.nn.Parameter))
634
# NOTE(review): only the setup (a ShapeEnv-backed FakeTensorMode) of this test
# is present here; its body is missing.
def test_tolist(self):
635
shape_env = ShapeEnv()
636
with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env):
640
# Re-fakeifying a fake tensor (dim 0 marked DYNAMIC) in a second FakeTensorMode
# that shares the ShapeEnv gives a new tensor owned by mode2. Its size(0) uses
# the same ShapeEnv and prints the same as t1.size(0).
# NOTE(review): the first argument of the mode1.from_tensor call is missing here.
# Propagate real tensors doesn't work with fake-on-fake
641
@expectedFailurePropagateRealTensors
642
def test_same_shape_env_preserved(self):
643
shape_env = ShapeEnv()
644
mode1 = FakeTensorMode(shape_env=shape_env)
645
t1 = mode1.from_tensor(
647
symbolic_context=StatelessSymbolicContext(
648
dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None]
651
mode2 = FakeTensorMode(shape_env=shape_env)
652
t2 = mode2.from_tensor(t1)
653
# t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
654
self.assertIsNot(t2, t1)
655
self.assertIs(t1.fake_mode, mode1)
656
self.assertIs(t2.fake_mode, mode2)
657
self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
658
self.assertEqual(str(t2.size(0)), str(t1.size(0)))
660
# A jagged nested tensor built with jagged_from_list is fakeified by mode1 and
# re-fakeified by mode2, which shares the ShapeEnv. The result has free
# symbols, is a new object, its offsets belong to the respective modes, and
# size(1) shares the ShapeEnv.
# NOTE(review): the definitions of S0, S1, S2, D and `offsets` are missing here.
# TODO: Support NJT. There's also some funny business with dynamic shapes
661
# which would need to be dealt with as well
662
@expectedFailurePropagateRealTensors
663
def test_jagged_fake_to_fake_preserved(self):
664
from torch.nested._internal.nested_tensor import jagged_from_list
668
a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
669
b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
670
c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
672
jt, _ = jagged_from_list([a, b, c], offsets)
673
shape_env = ShapeEnv()
674
mode1 = FakeTensorMode(shape_env=shape_env)
675
t1 = mode1.from_tensor(jt)
676
mode2 = FakeTensorMode(shape_env=shape_env)
677
t2 = mode2.from_tensor(t1)
678
# It's not obvious that the invocation above makes it dynamic but it
680
self.assertTrue(free_symbols(t1.size()))
681
self.assertIsNot(t2, t1)
682
self.assertIs(t1.offsets().fake_mode, mode1)
683
self.assertIs(t2.offsets().fake_mode, mode2)
684
self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
685
self.assertEqual(str(t2.size(1)), str(t1.size(1)))
687
# Helper: assert t1 and t2 have matching tensor metadata, strides included.
def checkMetaProps(self, t1, t2):
688
prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
691
# deepcopy under FakeCopyMode:
# - copying a BatchNorm2d keeps each parameter/buffer's metadata, its
#   FakeTensor type, Parameter-ness and requires_grad;
# - copying a module whose attributes `a` and `b` alias keeps them the same
#   object with the same storage.
# NOTE(review): parts of ModuleNew and the line creating the second `mod`
# (presumably a ModuleNew instance) are missing here.
def test_deepcopy(self):
692
with FakeTensorMode() as mode:
694
mod = torch.nn.BatchNorm2d(10)
695
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
696
mod_copied = copy.deepcopy(mod)
698
def check_copy(mod, mod_copied):
699
for name, param in itertools.chain(
700
mod.named_parameters(), mod.named_buffers()
702
param_copied = getattr(mod_copied, name)
703
self.checkMetaProps(param, param_copied)
704
self.assertTrue(isinstance(param_copied, FakeTensor))
706
isinstance(param, torch.nn.Parameter),
707
isinstance(param_copied, torch.nn.Parameter),
709
self.assertEqual(param.requires_grad, param_copied.requires_grad)
711
check_copy(mod, mod_copied)
713
class ModuleNew(torch.nn.Module):
714
def __init__(self) -> None:
716
self.a = torch.rand([10, 2])
721
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
722
mod_copied = copy.deepcopy(mod)
724
self.assertIs(mod_copied.a, mod_copied.b)
725
self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)
728
# NOTE(review): tail of a decorator call whose opening line is missing. The
# `def` line of the test below, which checks the device and shape of
# Tensor.new results, also appears to be missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
730
@unittest.skipIf(not RUN_CUDA, "requires cuda")
732
with FakeTensorMode():
733
a = torch.rand([16, 1])
734
self.checkType(a.new(10, 10), "cpu", [10, 10])
735
self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
736
b = torch.rand([4, 4], device="cuda")
737
self.checkType(b.new(device="cuda"), "cuda", [0])
738
self.checkType(a.new(torch.rand([1])), "cpu", [1])
741
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
743
# torch.div(3, 2) gives a 0-dim CPU FakeTensor, and an int32 tensor times 2.0
# gives a float CPU FakeTensor of size [2].
def test_scalar_inputs(self):
744
with FakeTensorMode():
745
self.checkType(torch.div(3, 2), "cpu", [])
746
ten = torch.zeros(2, dtype=torch.int32) * 2.0
747
self.assertEqual(ten.dtype, torch.float)
748
self.checkType(ten, "cpu", [2])
751
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
753
# `run_meta()` must return a meta FakeTensor of size [4] by default, and must
# raise once torch._functorch.config.fake_tensor_allow_meta is patched to False.
# NOTE(review): the definition of `run_meta` is missing here.
def test_allow_meta(self):
755
with FakeTensorMode():
756
x = torch.rand([4], device="meta")
759
self.checkType(run_meta(), "meta", [4])
761
with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
762
self.assertRaises(Exception, run_meta)
764
# An EmbeddingBag built on the meta device is called with CPU index/offset
# tensors; `real_out` and `fake_out` must match element-wise in size and device.
# NOTE(review): the enclosing helper's `def` line and the lines computing
# `real_out` / `fake_out` are missing here.
def test_embedding_bag_meta(self):
766
# This behavior was originally unintentional but we see people
768
embedding = torch.nn.EmbeddingBag(10, 3, mode="sum", device="meta")
769
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
770
offsets = torch.tensor([0, 4], dtype=torch.long)
771
return embedding(input, offsets)
774
with FakeTensorMode():
777
for r, f in zip(real_out, fake_out):
778
self.assertEqual(r.size(), f.size())
779
self.assertEqual(r.device, f.device)
782
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
784
# A module called under FakeTensorMode(allow_non_fake_inputs=True) with a real
# input tensor is expected to produce a (1, 1, 3, 3) CPU FakeTensor.
# NOTE(review): the line creating `mod` and the end of `forward` are missing here.
def test_mixed_real_and_fake_inputs(self):
785
class _TestPattern(torch.nn.Module):
786
def __init__(self) -> None:
788
self.conv = torch.nn.Conv2d(1, 1, 1)
789
self.bn = torch.nn.BatchNorm2d(1)
791
def forward(self, input):
792
running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
793
scale_factor = self.bn.weight / running_std
794
weight_shape = [1] * len(self.conv.weight.shape)
796
bias_shape = [1] * len(self.conv.weight.shape)
798
scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
799
zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
800
conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
801
conv_orig = conv / scale_factor.reshape(bias_shape)
802
conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
803
conv = self.bn(conv_orig)
806
example_inputs = (torch.randn(1, 1, 3, 3),)
808
with FakeTensorMode(allow_non_fake_inputs=True):
809
out = mod(torch.randn(1, 1, 3, 3))
810
self.checkType(out, "cpu", (1, 1, 3, 3))
813
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
815
# aten.copy and aten.copy.out across CPU and CUDA: each result is on the device
# of the first argument (or of `out`).
@unittest.skipIf(not RUN_CUDA, "requires cuda")
816
def test_aten_copy_multi_device(self):
817
with FakeTensorMode():
818
x1 = torch.rand(4, device="cpu")
819
x2 = torch.rand(4, device="cuda")
820
copy1 = torch.ops.aten.copy.default(x1, x2)
821
copy2 = torch.ops.aten.copy.default(x2, x1)
822
out = torch.empty(4, device="cpu")
823
torch.ops.aten.copy.out(x1, x2, out=out)
824
self.checkType(copy1, "cpu", (4,))
825
self.checkType(copy2, "cuda", (4,))
826
self.checkType(out, "cpu", (4,))
829
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
831
# aten.index / aten.index_put where the indices are on a different device than
# `self`: each checked result is on self's device.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
832
def test_aten_index_multi_device(self):
833
with FakeTensorMode():
834
x1 = torch.rand(4, 4, device="cpu")
835
x2 = torch.rand(4, 4, device="cuda")
836
i1 = torch.tensor([0, 1], device="cuda")
837
i2 = torch.tensor([0, 1], device="cpu")
838
# NB: This one does not work: cuda indices not allowed on cpu
840
# r1 = torch.ops.aten.index(x1, i1)
841
r2 = torch.ops.aten.index(x2, i2)
843
y1 = torch.rand(4, device="cpu")
844
y2 = torch.rand(4, device="cuda")
845
j1 = torch.tensor([2], device="cuda")
846
j2 = torch.tensor([2], device="cpu")
847
r3 = torch.ops.aten.index_put.default(x1, j1, y1)
848
r4 = torch.ops.aten.index_put.default(x2, j2, y2)
849
# self.checkType(r1, "cpu", ())
850
self.checkType(r2, "cuda", ())
851
self.checkType(r3, "cpu", (4, 4))
852
self.checkType(r4, "cuda", (4, 4))
855
# NOTE(review): tail of a decorator call whose opening line is missing.
TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
857
# aten.slice_scatter (functional and .out) with `src` on another device: each
# result is on self's (or out's) device.
@unittest.skipIf(not RUN_CUDA, "requires cuda")
858
def test_aten_slice_scatter_multi_device(self):
859
with FakeTensorMode():
860
x1 = torch.rand(4, 4, device="cpu")
861
y1 = torch.rand(2, 4, device="cuda")
862
x2 = torch.rand(4, 4, device="cuda")
863
y2 = torch.rand(2, 4, device="cpu")
864
out = torch.empty(4, 4, device="cpu")
865
r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
866
r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
867
r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
868
self.checkType(r1, "cpu", (4, 4))
869
self.checkType(r2, "cuda", (4, 4))
870
self.checkType(r3, "cpu", (4, 4))
871
self.checkType(out, "cpu", (4, 4))
873
# _adaptive_avg_pool2d_backward with a channels_last `inp` is expected to give
# a channels_last grad_in.
# NOTE(review): the opening of the assertion wrapping the comparison below is
# missing here.
def test__adaptive_avg_pool2d_backward(self):
874
with FakeTensorMode():
875
grad_out = torch.rand(2, 3, 4, 4)
876
inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
877
grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
879
torch._prims_common.suggest_memory_format(grad_in)
880
== torch.channels_last
883
# torch.export.export of a module whose forward converts its input to numpy
# still returns an ExportedProgram when run under FakeTensorMode.
def test_export_numpy(self):
884
class MyNumpyModel(torch.nn.Module):
885
def forward(self, input):
886
input = input.numpy()
887
return input + np.random.randn(*input.shape)
889
with FakeTensorMode():
890
ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
891
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))
893
# unsqueeze_copy(t, 1) on a fake tensor built with a StatelessSymbolicContext
# keeps dim 0 unchanged.
# NOTE(review): part of the from_tensor call is missing here.
def test_unsqueeze_copy(self):
894
shape_env = ShapeEnv()
895
t1 = torch.ones(2, 2, 768)
896
with FakeTensorMode(shape_env=shape_env) as fake_mode:
897
t = fake_mode.from_tensor(
899
symbolic_context=StatelessSymbolicContext(
908
self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0])
910
# Under FakeTensorMode with a forward-AD dual level and a dual tensor, `r` must
# be a FakeTensor of size [3].
# NOTE(review): the lines computing `r` are missing here.
def test_alias_call(self):
911
fwAD = torch.autograd.forward_ad
916
with torch._subclasses.fake_tensor.FakeTensorMode():
917
with fwAD.dual_level():
918
x = torch.randn(3, device="cpu")
919
y = torch.ones_like(x)
920
dual = fwAD.make_dual(x, y)
923
self.assertIsInstance(r, FakeTensor)
924
self.assertEqual(r.size(), [3])
927
# Generate concrete test methods for FakeTensorTest's parametrized tests
# (e.g. the `allow_fallback_kernels` variants of test_cudnn_rnn).
instantiate_parametrized_tests(FakeTensorTest)
930
def make_propagate_real_tensors_cls(cls):
931
cls = make_test_cls_with_patches(
933
"PropagateRealTensors",
934
"_propagate_real_tensors",
935
(torch._functorch.config, "fake_tensor_propagate_real_tensors", True),
936
xfail_prop="_expected_failure_propagate_real_tensors",
937
decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"),
939
cls.__file__ = __file__
940
cls.__module__ = __name__
941
globals()[cls.__name__] = cls
944
# Also generate and register a PropagateRealTensors variant of FakeTensorTest.
make_propagate_real_tensors_cls(FakeTensorTest)
947
# Tests for the `.constant` attribute that fake tensors carry when created
# from constant data.
# NOTE(review): indentation and several lines are missing from this copy (bare
# integers look like original line numbers). Some tests read names such as `y`
# or `arg` that are not assigned in the visible text.
class FakeTensorConstHandling(TestCase):
948
# Helper: assert each tensor in `args` has a non-None `.constant`.
# NOTE(review): the loop header binding `arg` is missing here.
def assertConst(self, *args):
950
self.assertTrue(arg.constant is not None)
952
# Helper: assert each tensor in `args` has `.constant` equal to None.
# NOTE(review): the loop header binding `arg` is missing here.
def assertNotConst(self, *args):
954
self.assertTrue(arg.constant is None)
956
# torch.tensor(4.0) under FakeTensorMode still reports .item() == 4.0.
def test_simple(self):
957
with FakeTensorMode():
958
x = torch.tensor(4.0)
959
self.assertEqual(x.item(), 4.0)
961
# After an in-place update, x and y both read 5.0 and both are still constants.
# NOTE(review): the lines defining `y` and performing the update are missing here.
def test_inplace_add(self):
962
with FakeTensorMode():
963
x = torch.tensor(4.0)
965
self.assertEqual(x.item(), 5.0)
966
self.assertEqual(y.item(), 5.0)
967
self.assertConst(x, y)
969
# x and y share a storage, and so do their `.constant` tensors.
# NOTE(review): the line creating `y` is missing here.
def test_shared_storages(self):
970
with FakeTensorMode():
971
x = torch.tensor([4.0])
974
self.assertEqual(x.storage()._cdata, y.storage()._cdata)
975
self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata)
977
# Some operation on `x` (missing from this text) must clear its constant.
def test_constant_invalidation(self):
978
with FakeTensorMode():
979
x = torch.tensor([1.0])
983
self.assertNotConst(x)
985
# An operation missing from this text (presumably an in-place view op, going
# by the test name) leaves x with size(0) == 2 and clears its constant.
def test_inplace_view_invalidation(self):
986
with FakeTensorMode():
987
x = torch.tensor([1])
990
self.assertEqual(x.size(0), 2)
991
self.assertNotConst(x)
993
# Building a new_full shape that includes the tensor `max_size` must raise
# DataDependentOutputException under FakeTensorMode.
# NOTE(review): the helper's `def` line (which takes `tensors`) and the call
# inside the FakeTensorMode block are missing here.
def test_fake_tensor_in_intlist_repro(self):
995
max_size = torch.tensor([800, 1216], dtype=torch.int64)
996
batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
997
return tensors[0].new_full(batch_shape, 0.0)
999
with self.assertRaises(
1000
torch._subclasses.fake_tensor.DataDependentOutputException
1002
with torch._subclasses.fake_tensor.FakeTensorMode():
1003
a = torch.randn(3, 800, 1199)
1004
b = torch.randn(3, 800, 800)
1008
# Runs a Sequential starting with BatchNorm2d(10) on a [2, 10, 8, 8] input
# under CrossRefFakeMode.
# NOTE(review): the remaining Sequential entries are missing here.
def test_fake_tensor_batch_norm_cpu(self):
1009
with torch._subclasses.CrossRefFakeMode():
1010
m = torch.nn.Sequential(
1011
torch.nn.BatchNorm2d(10),
1015
out = m(torch.randn([2, 10, 8, 8]))
1017
# x and y (presumably sharing storage) start as constants; an in-place add of
# random data into y clears the constant on both.
# NOTE(review): the line creating `y` is missing here.
def test_shared_storage_invalidation(self):
1018
with FakeTensorMode():
1019
x = torch.tensor([1.0])
1021
self.assertConst(x, y)
1022
y.add_(torch.rand([1]))
1023
self.assertNotConst(x, y)
1025
# After a write that is missing from this text (presumably through an alias),
# neither y nor x is a constant.
def test_aliased_const_write(self):
1026
with FakeTensorMode():
1027
x = torch.tensor([1])
1029
self.assertNotConst(y)
1031
self.assertNotConst(x)
1033
# Constant handling through torch.div(4, 4, rounding_mode="trunc").
# NOTE(review): the assertions of this test are missing here.
def test_constant_propagate_through_functions(self):
1034
with FakeTensorMode():
1035
y = torch.div(4, 4, rounding_mode="trunc")
1039
# Also generate and register a PropagateRealTensors variant of FakeTensorConstHandling.
make_propagate_real_tensors_cls(FakeTensorConstHandling)
1042
def contains_type(type: torch.Type, maybe_contained_type: torch.Type) -> bool:
    """Return whether ``maybe_contained_type`` occurs anywhere inside ``type``.

    True when ``maybe_contained_type`` is a subtype of ``type`` itself, or of
    any type reachable by recursing through ``containedTypes()`` (e.g. the
    element type of a list type).

    Note: the parameter is named ``type`` (shadowing the builtin) to keep the
    existing keyword interface.
    """
    return maybe_contained_type.isSubtypeOf(type) or any(
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
    )
1048
class FakeTensorOpInfoTest(TestCase):
    """Run ``optests.fake_check`` on the sample inputs of ops in ``custom_op_db``."""

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        # Each sample contributes its primary input followed by its extra
        # positional args; its kwargs are forwarded unchanged.
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            positional = (sample.input,) + sample.args
            optests.fake_check(op, positional, sample.kwargs)
1058
# Register the PropagateRealTensors variant, then instantiate device-specific
# copies. The base tests run on CPU and CUDA; the real-tensor-propagation
# variant runs on CPU only.
make_propagate_real_tensors_cls(FakeTensorOpInfoTest)
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda"))
instantiate_device_type_tests(
    PropagateRealTensorsFakeTensorOpInfoTest, globals(), only_for=("cpu",)  # noqa: F821
)
1065
# Tests for FakeTensorConverter memoization, storage sharing and object lifetime.
# NOTE(review): indentation and several lines are missing from this copy (bare
# integers look like original line numbers). Some tests read names such as `y`,
# `t` or `out` that are not assigned in the visible text.
class FakeTensorConverterTest(TestCase):
1066
# from_tensor returns the same fake object for repeated conversions of `x`.
def test_memoized_conversion_to_meta(self):
1067
x = torch.rand(2, 2, 2)
1068
mode = FakeTensorMode()
1069
self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))
1071
# from_meta_and_device returns the same object for the same meta tensor and
# device.
# NOTE(review): the opening of the assertion around this `is` comparison is
# missing here.
def test_memoized_conversion_from_meta(self):
1072
x = torch.rand(2, 2).to(device="meta")
1073
mode = FakeTensorMode()
1074
converter = mode.fake_tensor_converter
1076
converter.from_meta_and_device(mode, x, "cpu")
1077
is converter.from_meta_and_device(mode, x, "cpu")
1080
# Converting `x` and `y` (presumably a view of x) gives fake tensors with the
# same storage id.
# NOTE(review): the line creating `y` is missing here.
def test_separate_tensor_storages_view(self):
1081
x = torch.rand(2, 2, 2)
1083
mode = FakeTensorMode()
1084
converter = mode.fake_tensor_converter
1085
x_conv = converter.from_real_tensor(mode, x)
1086
y_conv = converter.from_real_tensor(mode, y)
1087
self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))
1089
# The fake versions of x and y end up with the same storage id. The
# converter's tensor_memo and storage_memo first hold one entry each and later
# none.
# NOTE(review): the lines that make x and y share storage, and the lines that
# presumably release the tensors between the two memo checks, are missing here.
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1090
def test_separate_tensor_storages_non_view(self):
1091
x = torch.rand(2, 2, 2)
1092
y = torch.rand(4, 2)
1094
mode = FakeTensorMode()
1095
converter = mode.fake_tensor_converter
1096
x_conv = converter.from_real_tensor(mode, x)
1097
y_conv = converter.from_real_tensor(mode, y)
1098
stor_id = torch._C._storage_id(x_conv)
1099
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
1102
self.assertEqual(len(converter.tensor_memo), 1)
1103
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
1106
self.assertEqual(len(converter.tensor_memo), 0)
1107
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
1109
# Once `x` is no longer in tensor_memo, converting `y` still reuses x_conv's
# untyped storage.
# NOTE(review): the lines defining `y` and removing `x` are missing here.
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1110
def test_dead_weak_ref(self):
1111
x = torch.rand(2, 2, 2)
1113
mode = FakeTensorMode()
1114
converter = FakeTensorConverter()
1115
x_conv = converter.from_real_tensor(mode, x)
1116
x_conv_storage = x_conv.untyped_storage()
1118
self.assertFalse(x in converter.tensor_memo)
1119
y_conv = converter.from_real_tensor(mode, y)
1120
self.assertIs(x_conv_storage, y_conv.untyped_storage())
1122
# Converting `x` twice returns the same object and memoizes one entry; the
# memo is later expected to be empty.
# NOTE(review): the lines between those checks are missing here.
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1123
def test_dead_key(self):
1124
x = torch.rand(2, 2, 2)
1125
mode = FakeTensorMode()
1126
converter = FakeTensorConverter()
1127
x_conv = converter.from_real_tensor(mode, x)
1128
self.assertEqual(len(converter.tensor_memo), 1)
1129
x_conv2 = converter.from_real_tensor(mode, x)
1130
assert x_conv2 is x_conv
1134
self.assertEqual(len(converter.tensor_memo), 0)
1136
# `out`, presumably computed from x and y after the `with` block exits, still
# reports `mode` as its fake_mode and is a CPU FakeTensor.
# NOTE(review): the line assigning `out` is missing here.
def test_no_active_mode(self):
1137
with FakeTensorMode() as mode:
1138
x = torch.empty(2, 2, device="cpu")
1139
y = torch.empty(2, 2, device="cpu")
1142
self.assertEqual(mode, out.fake_mode)
1143
self.assertTrue(isinstance(out, FakeTensor))
1144
self.assertEqual(out.device.type, "cpu")
1146
# Combining fake tensors from two different FakeTensorModes must raise
# "Mixing fake modes".
# NOTE(review): the definition of `t` and the raising statement are missing here.
def test_multiple_modes(self):
1148
t2 = torch.rand([4])
1149
with FakeTensorMode() as m:
1150
with FakeTensorMode() as m2:
1151
t_fake = m.from_tensor(t)
1152
t2_fake = m2.from_tensor(t2)
1154
with self.assertRaisesRegex(Exception, "Mixing fake modes"):
1157
# Presumably meant to check that combining fake tensors from two separate
# modes raises.
# NOTE(review): `assertRaises(Exception, lambda: x, y)` calls the zero-argument
# lambda with `y` as an argument. That raises TypeError no matter what the
# modes do, so the assertion always passes; `lambda: x + y` is probably what
# was meant.
def test_separate_mode_error(self):
1158
with FakeTensorMode():
1159
x = torch.empty(2, 2, device="cpu")
1160
with FakeTensorMode():
1161
y = torch.empty(2, 2, device="cpu")
1162
self.assertRaises(Exception, lambda: x, y)
1164
# Checks that the mode and its fake tensor are freed once references are
# dropped, i.e. that there is no reference cycle.
# NOTE(review): `y_weak` is created from `mode`, not `y`, so the final
# assertion re-checks the mode instead of the fake tensor. The definition of
# `x` and the lines dropping references are missing here.
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
1165
def test_no_ref_cycle(self):
1167
mode = FakeTensorMode()
1168
y = mode.from_tensor(x)
1169
self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
1170
mode_weak = weakref.ref(mode)
1171
y_weak = weakref.ref(mode)
1174
assert mode_weak() is None
1175
assert y_weak() is None
1178
make_propagate_real_tensors_cls(FakeTensorConverterTest)
1181
class FakeTensorOperatorInvariants(TestCase):
1182
def get_aten_op(self, schema):
1183
namespace, name = schema.name.split("::")
1184
overload = schema.overload_name if schema.overload_name else "default"
1185
assert namespace == "aten"
1186
return getattr(getattr(torch.ops.aten, name), overload)
1188
def get_all_aten_schemas(self):
1189
for schema in torch._C._jit_get_all_schemas():
1190
namespace = schema.name.split("::")[0]
1191
if namespace != "aten":
1195
def test_non_kwarg_only_device(self):
1196
for schema in self.get_all_aten_schemas():
1197
ten_type = torch._C.TensorType.get()
1199
contains_type(arg.type, ten_type)
1200
for arg in itertools.chain(schema.arguments, schema.returns)
1204
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1205
has_non_kwarg_device = any(
1206
not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1207
for arg in schema.arguments
1209
if has_non_kwarg_device:
1211
self.get_aten_op(schema)
1212
in torch._subclasses.fake_tensor._device_not_kwarg_ops
1215
def test_tensor_constructors_all_have_kwarg_device(self):
1216
for schema in self.get_all_aten_schemas():
1217
op = self.get_aten_op(schema)
1218
if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
1221
opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
1222
has_kwarg_device = any(
1223
arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
1224
for arg in schema.arguments
1228
has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
1231
@unittest.expectedFailure
1232
def test_sparse_new(self):
1233
with FakeTensorMode():
1234
indices = torch.randn(1, 1, dtype=torch.int64)
1235
values = torch.randn(1)
1237
sparse = torch.randn(1).to_sparse()
1238
# This used to segfault, now it does not, but it still raises an
1240
sparse2 = sparse.new(indices, values, extra)
1242
def test_tensor_new(self):
1243
with FakeTensorMode():
1244
x = torch.Tensor([1, 2, 3])
1245
self.assertIsInstance(x, FakeTensor)
1247
def test_like_ops(self):
1248
for schema in self.get_all_aten_schemas():
1249
if "_like" == schema.name[-5:]:
1250
op = self.get_aten_op(schema)
1252
op, torch._subclasses.fake_tensor._like_tensor_constructors
1255
def test_str_storage(self):
1257
with FakeTensorMode() as m:
1258
y = m.from_tensor(x)
1259
self.assertExpectedInline(
1265
[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]""",
1267
self.assertExpectedInline(
1271
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
1274
self.assertExpectedInline(
1278
[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""",
1281
# at::_embedding_bag has no op info,
1282
# and returns extra tensors that at::embedding bag throws away
1283
def test_embedding_bag_private(self):
1286
torch.ones(6, dtype=torch.int64),
1287
torch.arange(2, dtype=torch.int64),
1292
ref_out = torch.ops.aten._embedding_bag(*args)
1293
with FakeTensorMode() as m:
1295
m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
1297
meta_out = torch.ops.aten._embedding_bag(*meta_args)
1299
self.assertEqual(len(ref_out), len(meta_out))
1300
for ref_o, meta_o in zip(ref_out, meta_out):
1301
self.assertEqual(ref_o.size(), meta_o.size())
1303
def test_cross_entropy_loss(self):
1304
inp = torch.randn(3, 5)
1305
target = torch.randint(5, (3,), dtype=torch.long)
1306
weight = torch.rand(5)
1307
fn = torch.nn.functional.cross_entropy
1308
for w in (weight, None):
1309
args = (inp, target, w)
1311
with FakeTensorMode() as m:
1313
m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args
1315
meta_out = torch.nn.functional.cross_entropy(
1316
*meta_args, label_smoothing=0.5
1319
self.assertEqual(ref.size(), meta_out.size())
1323
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
1324
"Does not support SDPA or pre-SM80 hardware",
1326
def test_flash_attention(self):
1327
class Repro(torch.nn.Module):
1328
def __init__(self) -> None:
1331
def forward(self, arg1, arg2, arg3):
1332
torch.ops.aten._scaled_dot_product_flash_attention(
1333
arg1, arg2, arg3, scale=0.17677669529663687
1338
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1339
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1340
((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"),
1343
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1344
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1345
((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"),
1348
for args_list in args_new:
1350
rand_strided(bsz, num_heads, seq_len, head_dim)
1351
for (bsz, num_heads, seq_len, head_dim) in args_list
1354
with torch._subclasses.CrossRefFakeMode():
1356
except RuntimeError as e:
1357
# We expect the cross ref to succed for the first output to fail
1358
# for the rng state, see Note [Seed and Offset]
1359
self.assertTrue("output[0]" not in str(e))
1361
"found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!"
1365
# IMPORTANT!!! Always run even if CUDA is not available
1366
def test_fake_gpu_no_init(self):
1367
# Skip this test, we will try to run CUDA operations to real prop so
1368
# it clearly will not work on CPU runner
1369
if torch._functorch.config.fake_tensor_propagate_real_tensors:
1371
with FakeTensorMode():
1372
torch.empty(10, device=GPU_TYPE)
1373
torch.ones(10, device=GPU_TYPE)
1374
torch.zeros(10, device=GPU_TYPE)
1375
torch.rand(10, device=GPU_TYPE)
1376
torch.tensor(3.14, device=GPU_TYPE)
1377
torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)
1380
@unittest.skipIf(not RUN_CUDA, "requires cuda")
1381
def test_conv_c1_backward(self):
1382
class Repro(torch.nn.Module):
1383
def __init__(self) -> None:
1386
def forward(self, arg1, arg2, arg3):
1387
torch.ops.aten.convolution_backward.default(
1398
[True, True, False],
1402
((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
1403
((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
1404
((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
1406
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]
1408
with torch._subclasses.CrossRefFakeMode():
1411
def test_no_dispatch_with_like_function(self):
1412
class CountingMode(TorchDispatchMode):
1413
def __init__(self) -> None:
1416
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1418
return func(*args, **kwargs)
1420
with FakeTensorMode():
1422
with CountingMode() as mode:
1426
self.assertEqual(mode.count, 0)
1429
make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)
1432
class FakeTensorPropTest(TestCase):
1433
def test_fake_tensor_prop_on_nn_module(self):
1434
class ToyNnModuleWithParameters(torch.nn.Module):
1435
def __init__(self) -> None:
1437
self.layer1 = torch.nn.Linear(4, 3)
1438
self.layer2 = torch.nn.Linear(3, 2)
1440
def forward(self, value):
1441
value = self.layer1(value)
1442
value = torch.relu(value)
1443
value = self.layer2(value)
1446
model = ToyNnModuleWithParameters()
1447
value = torch.randn(5, 4)
1448
# Convert nn.Module to GraphModule so that FakeTensorProp runs.
1449
graph_model = torch.fx.symbolic_trace(model, (value,))
1450
# The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode
1452
# TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
1453
# with parameters and buffers.
1454
with FakeTensorMode() as fake_tensor_mode:
1456
def to_fake_tensor(x):
1457
if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
1458
return fake_tensor_mode.from_tensor(x)
1461
fake_parameters_and_buffers = {
1462
k: to_fake_tensor(v)
1463
for k, v in itertools.chain(
1464
graph_model.named_parameters(), graph_model.named_buffers()
1467
with torch.nn.utils.stateless._reparametrize_module(
1468
graph_model, fake_parameters_and_buffers
1470
# This case uses the **same** fake tensor mode to
1471
# 1. create fake parameters and fake buffers, and
1472
# 2. run FakeTensorProp
1473
# The result should be correct.
1474
result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
1475
self.assertTrue(isinstance(result, FakeTensor))
1476
self.assertEqual(result.shape, (5, 2))
1477
# This case uses the **different** fake tensor modes to
1478
# 1. create fake parameters and fake buffers, and
1479
# 2. run FakeTensorProp
1480
# The following code should fail.
1483
FakeTensorProp(graph_model).propagate(value)
1484
except AssertionError:
1485
# AssertionError: tensor's device must be `meta`, got cpu instead
1487
self.assertTrue(failed)
1489
def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
1490
class OptionalArgumentInBetween(torch.nn.Module):
1491
def __init__(self) -> None:
1493
self.layer1 = torch.nn.Linear(4, 3)
1494
self.layer2 = torch.nn.Linear(3, 2)
1496
def forward(self, value, another_value=None, another_optional_value=None):
1497
# Mimic huggingface's `forward` methods which have several optional arguments.
1498
# For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
1499
# To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
1500
if another_value is None:
1501
another_value = torch.rand_like(value)
1502
if another_optional_value is None:
1503
another_optional_value = torch.rand_like(value)
1504
value = value + another_value + another_optional_value
1505
return value * value
1507
fake_mode = FakeTensorMode(
1508
allow_non_fake_inputs=True, allow_fallback_kernels=False
1511
model = OptionalArgumentInBetween()
1512
value = torch.randn(5, 4)
1513
another_optional_value = torch.randn(5, 4)
1514
graph_model = torch.fx.symbolic_trace(
1515
model, (value, None, another_optional_value)
1517
FakeTensorProp(graph_model, fake_mode).propagate(
1518
value, None, another_optional_value
1521
def test_unbacked_shape_realloc(self):
1525
shape_env = ShapeEnv()
1526
fake_mode = FakeTensorMode(shape_env=shape_env)
1528
value = torch.randn(5)
1529
gm = make_fx(f)(value)
1531
n for n in gm.graph.nodes if n.target is torch.ops.aten.nonzero.default
1533
self.assertEqual(len(nonzero_nodes), 1)
1534
self.assertIsInstance(nonzero_nodes[0].meta["val"].shape[0], torch.SymInt)
1535
u0 = nonzero_nodes[0].meta["val"].shape[0]
1536
FakeTensorProp(gm, fake_mode).propagate(value)
1537
u1 = nonzero_nodes[0].meta["val"].shape[0]
1538
# Test that this test is actually doing something in that the
1539
# FakeTensorProp actually triggered a reallocation. If this assert is
1540
# failing, it could be because we started memoizing the nnz count for
1541
# nonzero, which is nice in some sense (no reallocation) but not
1542
# helpful for this test, which is checking what we do when we have
1543
# to reallocate. If so, you need to make this example more
1544
# complicated (e.g., maybe have a nontrivial computation on the input
1545
# before feeding it into nonzero, or have some sort of randomness)
1546
self.assertIsNot(u0, u1)
1547
self.assertTrue(statically_known_true(u0 == u1))
1549
def test_torch_load_with_fake_mode(self):
1550
class TheModelClass(torch.nn.Module):
1551
def __init__(self) -> None:
1553
self.fc1 = torch.nn.Linear(5, 10)
1555
def forward(self, x):
1558
with TemporaryFileName() as state_dict_file:
1559
# Create state_dict to be loaded later
1560
model = TheModelClass()
1561
torch.save(model.state_dict(), state_dict_file)
1563
fake_mode = FakeTensorMode()
1565
torch.load(state_dict_file) # scenario 1
1566
torch.load(state_dict_file, map_location="cpu") # scenario 2
1569
make_propagate_real_tensors_cls(FakeTensorPropTest)
1572
class FakeTensorSerialization(TestCase):
1573
def test_serialization(self):
1574
x = torch.tensor([0], device="cpu")
1575
with FakeTensorMode():
1576
y = pickle.loads(pickle.dumps(x))
1577
self.assertEqual(type(y), FakeTensor)
1578
self.assertEqual(y.device.type, "meta")
1580
with unset_fake_temporarily():
1581
y = pickle.loads(pickle.dumps(x))
1582
self.assertEqual(x.device, y.device)
1584
def test_serialization_with_tracing(self):
1585
x = torch.tensor([0], device="cpu")
1586
with tracing(TracingContext(FakeTensorMode())):
1587
y = pickle.loads(pickle.dumps(x))
1588
self.assertEqual(x.device, y.device)
1591
class FakeTensorDispatchCache(TestCase):
1592
def test_shape_env_settings(self):
1594
Validation that any boolean settings in ShapeEnv are present in the
1595
ShapeEnvSettings. We hope to ensure that any new settings that might
1596
affect FakeTensor dispatch are included in the cache key calculation.
1597
If this test fails, consider updating ShapeEnvSettings or change this
1598
test to omit checking for the new field.
1600
init_sig = inspect.signature(ShapeEnv._init)
1603
for name, param in init_sig.parameters.items()
1604
if type(param.default) is bool
1607
settings = [f.name for f in dataclasses.fields(ShapeEnvSettings)]
1609
self.assertTrue(arg in settings)
1611
def _test_cache_key(self, fm, x, y, z):
1613
Helper for all test_cache_key_* tests below. Assert that the
1614
cache keys for inputs x and y are the same, but z is different.
1616
func = aten.add.Tensor
1617
state = _CacheKeyState()
1618
key_x = fm._cache_key(state, func, [x], {})
1619
key_y = fm._cache_key(state, func, [y], {})
1620
key_z = fm._cache_key(state, func, [z], {})
1622
self.assertEqual(key_x, key_y)
1623
self.assertNotEqual(key_x, key_z)
1625
def test_cache_key_dtype(self):
1626
with FakeTensorMode() as fm:
1627
x = torch.randn(4, 3, dtype=torch.float16)
1628
y = torch.randn(4, 3, dtype=torch.float16)
1629
z = x.to(dtype=torch.float32)
1630
self._test_cache_key(fm, x, y, z)
1632
def test_cache_key_shape(self):
1633
with FakeTensorMode() as fm:
1634
x = torch.randn(4, 3)
1635
y = torch.randn(4, 3)
1636
z = torch.randn(4, 2)
1637
self._test_cache_key(fm, x, y, z)
1639
def test_cache_key_stride(self):
1640
with FakeTensorMode() as fm:
1641
x = torch.randn(4, 2)
1642
y = torch.randn(4, 2)
1643
z = x.as_strided((4, 2), (1, 2))
1644
self._test_cache_key(fm, x, y, z)
1646
@unittest.skipIf(not RUN_CUDA, "requires cuda")
1647
def test_cache_key_device(self):
1648
with FakeTensorMode() as fm:
1649
x = torch.randn(4, 3)
1650
y = torch.randn(4, 3)
1651
z = x.to(device="cuda")
1652
self._test_cache_key(fm, x, y, z)
1654
def test_cache_key_memory_format(self):
1655
with FakeTensorMode() as fm:
1656
x = torch.randn(1, 2, 3, 4)
1657
y = torch.randn(1, 2, 3, 4)
1658
z = x.to(memory_format=torch.channels_last)
1659
self._test_cache_key(fm, x, y, z)
1661
def test_cache_key_storage_offset(self):
1662
with FakeTensorMode() as fm:
1663
x = torch.randn(3)[1:]
1664
y = torch.randn(3)[1:]
1666
self._test_cache_key(fm, x, y, z)
1668
def test_cache_key_requires_grad(self):
1669
with FakeTensorMode() as fm:
1670
x = torch.randn(4, 3)
1671
y = torch.randn(4, 3)
1672
z = torch.randn(4, 3, requires_grad=True)
1673
self._test_cache_key(fm, x, y, z)
1675
def test_cache_key_is_conj(self):
1676
with FakeTensorMode() as fm:
1677
x = torch.randn(4, 3, dtype=torch.complex64)
1678
y = torch.randn(4, 3, dtype=torch.complex64)
1679
z = torch.randn(4, 3, dtype=torch.complex64)
1680
torch._C._set_conj(z, not z.is_conj())
1681
self._test_cache_key(fm, x, y, z)
1683
def test_cache_key_is_neg(self):
1684
with FakeTensorMode() as fm:
1685
x = torch.randn(4, 3, dtype=torch.complex64)
1686
y = torch.randn(4, 3, dtype=torch.complex64)
1687
z = torch.randn(4, 3, dtype=torch.complex64)
1688
torch._C._set_neg(z, not z.is_neg())
1689
self._test_cache_key(fm, x, y, z)
1691
def test_cache_key_is_inference(self):
1692
with torch.inference_mode(True):
1693
t = torch.randn(4, 3)
1694
with FakeTensorMode() as fm:
1695
x = torch.randn(4, 3)
1696
y = torch.randn(4, 3)
1697
z = fm.from_tensor(t)
1698
self._test_cache_key(fm, x, y, z)
1700
def test_cache_key_constants(self):
1701
with FakeTensorMode() as fm:
1702
# Python hashes 1.0 to the same value as 1. Make sure the
1703
# cache key calculation differentiates them.
1704
self._test_cache_key(fm, 1.0, 1.0, 1)
1705
self._test_cache_key(fm, 0.0, 0.0, 0)
1707
def assertHitsMisses(self, hits, misses):
1709
Helper to assert on the number of recorded hits and misses.
1711
info = FakeTensorMode.cache_info()
1712
self.assertEqual(info.hits, hits)
1713
self.assertEqual(info.misses, misses)
1715
def assertBypasses(self, reason, count):
1717
Helper to assert on the number of recorded bypasses.
1719
info = FakeTensorMode.cache_info()
1721
self.assertIn(reason, info.bypasses)
1722
self.assertEqual(info.bypasses[reason], count)
1724
self.assertNotIn(reason, info.bypasses)
1726
def test_cache_hit(self):
1728
Test that cache hit/miss counters are updated correctly.
1730
with FakeTensorMode():
1731
x = torch.randn(4, 3)
1732
y = torch.randn(4, 3)
1734
FakeTensorMode.cache_clear()
1735
self.assertHitsMisses(0, 0)
1737
self.assertHitsMisses(0, 1)
1739
self.assertHitsMisses(1, 1)
1742
extract_tensor_metadata(res1),
1743
extract_tensor_metadata(res2),
1746
def test_cache_bypass(self):
1748
Test that cache bypass counters are updated correctly.
1750
with FakeTensorMode():
1751
x = torch.randn(1, 2)
1753
FakeTensorMode.cache_clear()
1754
self.assertBypasses("inplace view", 0)
1757
self.assertBypasses("inplace view", 1)
1759
def test_cache_default_dtype(self):
1761
Test that the default dtype is respected when serving cached results.
1763
with FakeTensorMode():
1764
x = torch.tensor([1, 2], dtype=torch.int32)
1765
torch.set_default_dtype(torch.float32)
1767
FakeTensorMode.cache_clear()
1768
self.assertHitsMisses(0, 0)
1771
self.assertEqual(y.dtype, torch.float32)
1772
self.assertHitsMisses(0, 1)
1774
torch.set_default_dtype(torch.float16)
1776
self.assertEqual(y.dtype, torch.float16)
1777
self.assertHitsMisses(0, 2)
1779
torch.set_default_dtype(torch.float32)
1781
self.assertEqual(y.dtype, torch.float32)
1782
self.assertHitsMisses(1, 2)
1784
@unittest.skipIf(not RUN_CUDA, "requires cuda")
1785
def test_cache_default_device(self):
1787
Test that the default device is respected when serving cached results.
1789
with FakeTensorMode():
1790
FakeTensorMode.cache_clear()
1791
self.assertHitsMisses(0, 0)
1793
torch.set_default_device("cpu")
1794
x = torch.tensor([1, 2])
1796
self.assertEqual(y.device.type, "cpu")
1797
self.assertHitsMisses(0, 1)
1799
torch.set_default_device("cuda")
1800
x = torch.tensor([1, 2])
1802
self.assertEqual(y.device.type, "cuda")
1803
self.assertHitsMisses(0, 2)
1805
torch.set_default_device("cpu")
1806
x = torch.tensor([1, 2])
1808
self.assertEqual(y.device.type, "cpu")
1809
self.assertHitsMisses(1, 2)
1811
def test_cache_inplace_op(self):
1813
Test that inplace ops served from the cache correctly reference the
1816
with FakeTensorMode():
1817
x = torch.randn(1, 2)
1818
y = torch.randn(1, 2)
1820
FakeTensorMode.cache_clear()
1821
self.assertHitsMisses(0, 0)
1824
self.assertHitsMisses(0, 1)
1825
self.assertEqual(id(x), id(z))
1828
self.assertHitsMisses(1, 1)
1829
self.assertEqual(id(x), id(w))
1831
def test_cache_view_op(self):
1833
Test that view ops are handled correctly when served from the cache.
1835
with FakeTensorMode():
1836
x1 = torch.ones(2, requires_grad=True).clone()
1837
x2 = torch.ones(2, requires_grad=True).clone()
1840
# Test operating on a non-view tensor, then the same operation
1841
# on a view tensor. Assert that the view property is set correctly.
1843
self.assertFalse(z1._is_view())
1846
self.assertTrue(z2._is_view())
1848
# Now the other way around: first operate on a view tensor, then
1849
# the same operation on a non-view tensor.
1851
self.assertTrue(z2._is_view())
1854
self.assertFalse(z1._is_view())
1856
def test_cache_dispatch_key_set(self):
1858
Test that operations that change the dispatch key set bypass caching.
1860
with FakeTensorMode():
1861
FakeTensorMode.cache_clear()
1862
self.assertBypasses("dispatch_key_set mismatch", 0)
1864
x = torch._efficientzerotensor(3)
1865
self.assertTrue(x._is_zerotensor())
1866
self.assertBypasses("dispatch_key_set mismatch", 1)
1868
y = torch._efficientzerotensor(3)
1869
self.assertTrue(y._is_zerotensor())
1870
self.assertBypasses("dispatch_key_set mismatch", 2)
1872
def test_inference_mode(self):
1874
Test that caching handles inference mode correctly.
1876
with FakeTensorMode():
1877
x = torch.randn(4, 3)
1878
y = torch.randn(4, 3)
1880
FakeTensorMode.cache_clear()
1881
self.assertHitsMisses(0, 0)
1883
# Expect a miss when the inference mode is different
1885
with torch.inference_mode():
1888
self.assertHitsMisses(0, 2)
1889
self.assertFalse(res1.is_inference())
1890
self.assertTrue(res2.is_inference())
1892
# Second tries should see hits
1895
self.assertHitsMisses(1, 2)
1896
self.assertFalse(res3.is_inference())
1898
extract_tensor_metadata(res1),
1899
extract_tensor_metadata(res3),
1902
with torch.inference_mode():
1905
self.assertHitsMisses(2, 2)
1906
self.assertTrue(res4.is_inference())
1908
extract_tensor_metadata(res2),
1909
extract_tensor_metadata(res4),
1913
if __name__ == "__main__":